ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240911__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
 - ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
 - ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
 - ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
 - ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
 - ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
 - ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
 - ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
 - ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +38 -22
 - ai_edge_torch/generative/layers/attention.py +60 -63
 - ai_edge_torch/generative/layers/kv_cache.py +160 -51
 - ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
 - ai_edge_torch/generative/test/test_model_conversion.py +71 -33
 - ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
 - ai_edge_torch/generative/test/utils.py +54 -0
 - ai_edge_torch/version.py +1 -1
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/METADATA +1 -1
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/RECORD +22 -32
 - ai_edge_torch/generative/examples/experimental/gemma/__init__.py +0 -14
 - ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +0 -88
 - ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
 - ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
 - ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
 - ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
 - ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
 - ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
 - ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
 - ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
 - ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
 - /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/LICENSE +0 -0
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/WHEEL +0 -0
 - {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240911.dist-info}/top_level.txt +0 -0
 
| 
         @@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116 
     | 
|
| 
       2 
2 
     | 
    
         
             
            ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
         
     | 
| 
       3 
3 
     | 
    
         
             
            ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
         
     | 
| 
       4 
4 
     | 
    
         
             
            ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
         
     | 
| 
       5 
     | 
    
         
            -
            ai_edge_torch/version.py,sha256= 
     | 
| 
      
 5 
     | 
    
         
            +
            ai_edge_torch/version.py,sha256=vCTKdj1Lc6r2UbJhIZpLdXauJSS0KfBLzgy9e3D16AA,706
         
     | 
| 
       6 
6 
     | 
    
         
             
            ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       7 
7 
     | 
    
         
             
            ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
         
     | 
| 
       8 
8 
     | 
    
         
             
            ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
         
     | 
| 
         @@ -39,24 +39,14 @@ ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOa 
     | 
|
| 
       39 
39 
     | 
    
         
             
            ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       40 
40 
     | 
    
         
             
            ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       41 
41 
     | 
    
         
             
            ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       42 
     | 
    
         
            -
            ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       43 
     | 
    
         
            -
            ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       44 
     | 
    
         
            -
            ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
         
     | 
| 
       45 
     | 
    
         
            -
            ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=aCoD86pf4nuquUMk7MOR-jsN5FqvySSEuMx9Psxjblk,7261
         
     | 
| 
       46 
     | 
    
         
            -
            ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       47 
     | 
    
         
            -
            ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
         
     | 
| 
       48 
     | 
    
         
            -
            ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=Jxf3ZyYDpS78l6uh4_LGGIcHawrOhZ1vHoHFVxRaK40,6789
         
     | 
| 
       49 
     | 
    
         
            -
            ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       50 
     | 
    
         
            -
            ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
         
     | 
| 
       51 
     | 
    
         
            -
            ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=nUm0SQbCTmNAc5u-C9gbQRFPt7GDvUt6UjH6doTvH-I,6817
         
     | 
| 
       52 
42 
     | 
    
         
             
            ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       53 
     | 
    
         
            -
            ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256= 
     | 
| 
       54 
     | 
    
         
            -
            ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256= 
     | 
| 
       55 
     | 
    
         
            -
            ai_edge_torch/generative/examples/gemma/gemma.py,sha256= 
     | 
| 
       56 
     | 
    
         
            -
            ai_edge_torch/generative/examples/gemma/gemma2.py,sha256= 
     | 
| 
       57 
     | 
    
         
            -
            ai_edge_torch/generative/examples/ 
     | 
| 
       58 
     | 
    
         
            -
            ai_edge_torch/generative/examples/ 
     | 
| 
       59 
     | 
    
         
            -
            ai_edge_torch/generative/examples/ 
     | 
| 
      
 43 
     | 
    
         
            +
            ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=ZJvw8uFVu7FEJ7eXfpzn-pPKgPELoxkGz4Zg7LKKMSI,3048
         
     | 
| 
      
 44 
     | 
    
         
            +
            ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=hM-fwjZG53p1UE_lkovLMmHRDHleJsb6_0ib0_k0v54,3040
         
     | 
| 
      
 45 
     | 
    
         
            +
            ai_edge_torch/generative/examples/gemma/gemma.py,sha256=oVV1lXgi9cMPES6JmiV8fJOgBQruRdHpyJL7MmXU09M,7283
         
     | 
| 
      
 46 
     | 
    
         
            +
            ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=X6WfUCDJDEqyyEAYGq1lmKtlDXcYLzy-n2moQPLJA_U,9769
         
     | 
| 
      
 47 
     | 
    
         
            +
            ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
      
 48 
     | 
    
         
            +
            ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=vqEpZVmB0_wMKcAl6RXm7W57DqPTzEdVVN6W2Z-QYzI,3011
         
     | 
| 
      
 49 
     | 
    
         
            +
            ai_edge_torch/generative/examples/phi/phi2.py,sha256=BzvUrClFx5HKf6PYzJc7ba2O3AwYUJE485u5GSOiPy4,6851
         
     | 
| 
       60 
50 
     | 
    
         
             
            ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       61 
51 
     | 
    
         
             
            ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
         
     | 
| 
       62 
52 
     | 
    
         
             
            ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
         
     | 
| 
         @@ -78,19 +68,18 @@ ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8v 
     | 
|
| 
       78 
68 
     | 
    
         
             
            ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
         
     | 
| 
       79 
69 
     | 
    
         
             
            ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       80 
70 
     | 
    
         
             
            ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
         
     | 
| 
       81 
     | 
    
         
            -
            ai_edge_torch/generative/examples/test_models/ 
     | 
| 
       82 
     | 
    
         
            -
            ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
         
     | 
| 
      
 71 
     | 
    
         
            +
            ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=PbWpfg3AOEZjI1FlnZCxRD-kIKtdkR9AOZ6l-9-TpRA,5664
         
     | 
| 
       83 
72 
     | 
    
         
             
            ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       84 
     | 
    
         
            -
            ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256= 
     | 
| 
       85 
     | 
    
         
            -
            ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256= 
     | 
| 
      
 73 
     | 
    
         
            +
            ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=y4LiWhwgflqrg4WWh3wq5ei3VOT_cV0A62x62qptQiM,3070
         
     | 
| 
      
 74 
     | 
    
         
            +
            ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=RK7oisSwIPqUWwwE1P-hDJlEnRJJ_V29UjUCxt4xETE,6780
         
     | 
| 
       86 
75 
     | 
    
         
             
            ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
         
     | 
| 
       87 
76 
     | 
    
         
             
            ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
         
     | 
| 
       88 
77 
     | 
    
         
             
            ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       89 
     | 
    
         
            -
            ai_edge_torch/generative/layers/attention.py,sha256= 
     | 
| 
      
 78 
     | 
    
         
            +
            ai_edge_torch/generative/layers/attention.py,sha256=ee0KHRakhjLjawP32FY2EntxOkyPvjiEZChLnBn_HPc,12601
         
     | 
| 
       90 
79 
     | 
    
         
             
            ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
         
     | 
| 
       91 
80 
     | 
    
         
             
            ai_edge_torch/generative/layers/builder.py,sha256=xb7rjADv3Jm4qfmlYtg6oLLe7ReDE9UjsEqiejPpDD8,4346
         
     | 
| 
       92 
81 
     | 
    
         
             
            ai_edge_torch/generative/layers/feed_forward.py,sha256=uto7xtwx6jPkk1GZ2x7pSTentQzRrPSKw4_PSE12ahA,3525
         
     | 
| 
       93 
     | 
    
         
            -
            ai_edge_torch/generative/layers/kv_cache.py,sha256= 
     | 
| 
      
 82 
     | 
    
         
            +
            ai_edge_torch/generative/layers/kv_cache.py,sha256=WDu03NQwkDCrrrT9Du_3ZOxlURZz3XDbS1PLzFozhMI,6013
         
     | 
| 
       94 
83 
     | 
    
         
             
            ai_edge_torch/generative/layers/model_config.py,sha256=WpZ9djUBAZddyeSODHDaVMG37EQqfzGGrlMPi8AA-Hc,5752
         
     | 
| 
       95 
84 
     | 
    
         
             
            ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQcfDK9JM2IaUQFQdn7xs,1860
         
     | 
| 
       96 
85 
     | 
    
         
             
            ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
         
     | 
| 
         @@ -107,11 +96,12 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC 
     | 
|
| 
       107 
96 
     | 
    
         
             
            ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
         
     | 
| 
       108 
97 
     | 
    
         
             
            ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
         
     | 
| 
       109 
98 
     | 
    
         
             
            ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       110 
     | 
    
         
            -
            ai_edge_torch/generative/test/ 
     | 
| 
      
 99 
     | 
    
         
            +
            ai_edge_torch/generative/test/test_kv_cache.py,sha256=FU2rmU03Lp-vZ5wWXXCao1WEw7xbpqebFMANL_O2chA,3713
         
     | 
| 
       111 
100 
     | 
    
         
             
            ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
         
     | 
| 
       112 
     | 
    
         
            -
            ai_edge_torch/generative/test/test_model_conversion.py,sha256= 
     | 
| 
       113 
     | 
    
         
            -
            ai_edge_torch/generative/test/test_model_conversion_large.py,sha256= 
     | 
| 
      
 101 
     | 
    
         
            +
            ai_edge_torch/generative/test/test_model_conversion.py,sha256=OmAHSGkxTNzDX5_kYjK7pxlPk0YZLqL9YiVIJQfuvPc,5889
         
     | 
| 
      
 102 
     | 
    
         
            +
            ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=F3q3K9ZgWBzlLy4WpE8-w6UWSuJ-UoJwMm3N6Zb3Y14,5016
         
     | 
| 
       114 
103 
     | 
    
         
             
            ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
         
     | 
| 
      
 104 
     | 
    
         
            +
            ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
         
     | 
| 
       115 
105 
     | 
    
         
             
            ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
         
     | 
| 
       116 
106 
     | 
    
         
             
            ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
         
     | 
| 
       117 
107 
     | 
    
         
             
            ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
         
     | 
| 
         @@ -161,8 +151,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9 
     | 
|
| 
       161 
151 
     | 
    
         
             
            ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       162 
152 
     | 
    
         
             
            ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
         
     | 
| 
       163 
153 
     | 
    
         
             
            ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
         
     | 
| 
       164 
     | 
    
         
            -
            ai_edge_torch_nightly-0.3.0. 
     | 
| 
       165 
     | 
    
         
            -
            ai_edge_torch_nightly-0.3.0. 
     | 
| 
       166 
     | 
    
         
            -
            ai_edge_torch_nightly-0.3.0. 
     | 
| 
       167 
     | 
    
         
            -
            ai_edge_torch_nightly-0.3.0. 
     | 
| 
       168 
     | 
    
         
            -
            ai_edge_torch_nightly-0.3.0. 
     | 
| 
      
 154 
     | 
    
         
            +
            ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
         
     | 
| 
      
 155 
     | 
    
         
            +
            ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/METADATA,sha256=caHeAQX6pxEskue_BvgwkTfZEsG55rXHFwPDcV9oCN8,1859
         
     | 
| 
      
 156 
     | 
    
         
            +
            ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
         
     | 
| 
      
 157 
     | 
    
         
            +
            ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
         
     | 
| 
      
 158 
     | 
    
         
            +
            ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/RECORD,,
         
     | 
| 
         @@ -1,14 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
       2 
     | 
    
         
            -
            #
         
     | 
| 
       3 
     | 
    
         
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       4 
     | 
    
         
            -
            # you may not use this file except in compliance with the License.
         
     | 
| 
       5 
     | 
    
         
            -
            # You may obtain a copy of the License at
         
     | 
| 
       6 
     | 
    
         
            -
            #
         
     | 
| 
       7 
     | 
    
         
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
       8 
     | 
    
         
            -
            #
         
     | 
| 
       9 
     | 
    
         
            -
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
       10 
     | 
    
         
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
       11 
     | 
    
         
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
       12 
     | 
    
         
            -
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
     | 
    
         
            -
            # limitations under the License.
         
     | 
| 
       14 
     | 
    
         
            -
            # ==============================================================================
         
     | 
| 
         @@ -1,88 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
       2 
     | 
    
         
            -
            #
         
     | 
| 
       3 
     | 
    
         
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       4 
     | 
    
         
            -
            # you may not use this file except in compliance with the License.
         
     | 
| 
       5 
     | 
    
         
            -
            # You may obtain a copy of the License at
         
     | 
| 
       6 
     | 
    
         
            -
            #
         
     | 
| 
       7 
     | 
    
         
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
       8 
     | 
    
         
            -
            #
         
     | 
| 
       9 
     | 
    
         
            -
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
       10 
     | 
    
         
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
       11 
     | 
    
         
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
       12 
     | 
    
         
            -
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
     | 
    
         
            -
            # limitations under the License.
         
     | 
| 
       14 
     | 
    
         
            -
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
            #
         
     | 
| 
       16 
     | 
    
         
            -
            # Note: This is an experimental version of Gemma with external KV cache.
         
     | 
| 
       17 
     | 
    
         
            -
            # Please use with caution.
         
     | 
| 
       18 
     | 
    
         
            -
             
     | 
| 
       19 
     | 
    
         
            -
             
     | 
| 
       20 
     | 
    
         
            -
            import os
         
     | 
| 
       21 
     | 
    
         
            -
            from pathlib import Path
         
     | 
| 
       22 
     | 
    
         
            -
             
     | 
| 
       23 
     | 
    
         
            -
            import ai_edge_torch
         
     | 
| 
       24 
     | 
    
         
            -
            from ai_edge_torch.generative.examples.experimental.gemma import gemma
         
     | 
| 
       25 
     | 
    
         
            -
            from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
         
     | 
| 
       26 
     | 
    
         
            -
            from ai_edge_torch.generative.quantize import quant_recipes
         
     | 
| 
       27 
     | 
    
         
            -
            import torch
         
     | 
| 
       28 
     | 
    
         
            -
             
     | 
| 
       29 
     | 
    
         
            -
             
     | 
| 
       30 
     | 
    
         
            -
            def convert_gemma_to_tflite(
         
     | 
| 
       31 
     | 
    
         
            -
                checkpoint_path: str,
         
     | 
| 
       32 
     | 
    
         
            -
                prefill_seq_len: int = 512,
         
     | 
| 
       33 
     | 
    
         
            -
                kv_cache_max_len: int = 1024,
         
     | 
| 
       34 
     | 
    
         
            -
                quantize: bool = True,
         
     | 
| 
       35 
     | 
    
         
            -
            ):
         
     | 
| 
       36 
     | 
    
         
            -
              """An example method for converting a Gemma 2B model to multi-signature
         
     | 
| 
       37 
     | 
    
         
            -
             
     | 
| 
       38 
     | 
    
         
            -
              tflite model.
         
     | 
| 
       39 
     | 
    
         
            -
              Args:
         
     | 
| 
       40 
     | 
    
         
            -
                  checkpoint_path (str): The filepath to the model checkpoint, or directory
         
     | 
| 
       41 
     | 
    
         
            -
                    holding the checkpoint.
         
     | 
| 
       42 
     | 
    
         
            -
                  prefill_seq_len (int, optional): The maximum size of prefill input tensor.
         
     | 
| 
       43 
     | 
    
         
            -
                    Defaults to 512.
         
     | 
| 
       44 
     | 
    
         
            -
                  kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
         
     | 
| 
       45 
     | 
    
         
            -
                    including both prefill and decode. Defaults to 1024.
         
     | 
| 
       46 
     | 
    
         
            -
                  quantize (bool, optional): Whether the model should be quanized. Defaults
         
     | 
| 
       47 
     | 
    
         
            -
                    to True.
         
     | 
| 
       48 
     | 
    
         
            -
              """
         
     | 
| 
       49 
     | 
    
         
            -
              pytorch_model = gemma.build_2b_model(
         
     | 
| 
       50 
     | 
    
         
            -
                  checkpoint_path, kv_cache_max_len=kv_cache_max_len
         
     | 
| 
       51 
     | 
    
         
            -
              )
         
     | 
| 
       52 
     | 
    
         
            -
              # Tensors used to trace the model graph during conversion.
         
     | 
| 
       53 
     | 
    
         
            -
              prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
         
     | 
| 
       54 
     | 
    
         
            -
              prefill_input_pos = torch.arange(0, prefill_seq_len)
         
     | 
| 
       55 
     | 
    
         
            -
              decode_token = torch.tensor([[0]], dtype=torch.long)
         
     | 
| 
       56 
     | 
    
         
            -
              decode_input_pos = torch.tensor([0], dtype=torch.int64)
         
     | 
| 
       57 
     | 
    
         
            -
              kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
         
     | 
| 
       58 
     | 
    
         
            -
             
     | 
| 
       59 
     | 
    
         
            -
              quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
         
     | 
| 
       60 
     | 
    
         
            -
              edge_model = (
         
     | 
| 
       61 
     | 
    
         
            -
                  ai_edge_torch.signature(
         
     | 
| 
       62 
     | 
    
         
            -
                      'prefill',
         
     | 
| 
       63 
     | 
    
         
            -
                      pytorch_model,
         
     | 
| 
       64 
     | 
    
         
            -
                      sample_kwargs={
         
     | 
| 
       65 
     | 
    
         
            -
                          'tokens': prefill_tokens,
         
     | 
| 
       66 
     | 
    
         
            -
                          'input_pos': prefill_input_pos,
         
     | 
| 
       67 
     | 
    
         
            -
                          'kv_cache': kv,
         
     | 
| 
       68 
     | 
    
         
            -
                      },
         
     | 
| 
       69 
     | 
    
         
            -
                  )
         
     | 
| 
       70 
     | 
    
         
            -
                  .signature(
         
     | 
| 
       71 
     | 
    
         
            -
                      'decode',
         
     | 
| 
       72 
     | 
    
         
            -
                      pytorch_model,
         
     | 
| 
       73 
     | 
    
         
            -
                      sample_kwargs={
         
     | 
| 
       74 
     | 
    
         
            -
                          'tokens': decode_token,
         
     | 
| 
       75 
     | 
    
         
            -
                          'input_pos': decode_input_pos,
         
     | 
| 
       76 
     | 
    
         
            -
                          'kv_cache': kv,
         
     | 
| 
       77 
     | 
    
         
            -
                      },
         
     | 
| 
       78 
     | 
    
         
            -
                  )
         
     | 
| 
       79 
     | 
    
         
            -
                  .convert(quant_config=quant_config)
         
     | 
| 
       80 
     | 
    
         
            -
              )
         
     | 
| 
       81 
     | 
    
         
            -
              edge_model.export(
         
     | 
| 
       82 
     | 
    
         
            -
                  f'/tmp/gemma_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
         
     | 
| 
       83 
     | 
    
         
            -
              )
         
     | 
| 
       84 
     | 
    
         
            -
             
     | 
| 
       85 
     | 
    
         
            -
             
     | 
| 
       86 
     | 
    
         
            -
            if __name__ == '__main__':
         
     | 
| 
       87 
     | 
    
         
            -
              checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
         
     | 
| 
       88 
     | 
    
         
            -
              convert_gemma_to_tflite(checkpoint_path)
         
     | 
| 
         @@ -1,219 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
       2 
     | 
    
         
            -
            #
         
     | 
| 
       3 
     | 
    
         
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       4 
     | 
    
         
            -
            # you may not use this file except in compliance with the License.
         
     | 
| 
       5 
     | 
    
         
            -
            # You may obtain a copy of the License at
         
     | 
| 
       6 
     | 
    
         
            -
            #
         
     | 
| 
       7 
     | 
    
         
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
       8 
     | 
    
         
            -
            #
         
     | 
| 
       9 
     | 
    
         
            -
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
       10 
     | 
    
         
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
       11 
     | 
    
         
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
       12 
     | 
    
         
            -
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
     | 
    
         
            -
            # limitations under the License.
         
     | 
| 
       14 
     | 
    
         
            -
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
            # Example of building a Gemma model.
         
     | 
| 
       16 
     | 
    
         
            -
            #
         
     | 
| 
       17 
     | 
    
         
            -
            # Note: This is an experimental version of Gemma with external KV cache.
         
     | 
| 
       18 
     | 
    
         
            -
            # Please use with caution.
         
     | 
| 
       19 
     | 
    
         
            -
             
     | 
| 
       20 
     | 
    
         
            -
            import os
         
     | 
| 
       21 
     | 
    
         
            -
            from pathlib import Path
         
     | 
| 
       22 
     | 
    
         
            -
            from typing import Tuple
         
     | 
| 
       23 
     | 
    
         
            -
             
     | 
| 
       24 
     | 
    
         
            -
            from ai_edge_torch.generative.layers import builder
         
     | 
| 
       25 
     | 
    
         
            -
            import ai_edge_torch.generative.layers.attention_utils as attn_utils
         
     | 
| 
       26 
     | 
    
         
            -
            from ai_edge_torch.generative.layers.experimental import attention
         
     | 
| 
       27 
     | 
    
         
            -
            from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
         
     | 
| 
       28 
     | 
    
         
            -
            import ai_edge_torch.generative.layers.model_config as cfg
         
     | 
| 
       29 
     | 
    
         
            -
            import ai_edge_torch.generative.utilities.loader as loading_utils
         
     | 
| 
       30 
     | 
    
         
            -
            import numpy as np
         
     | 
| 
       31 
     | 
    
         
            -
            import torch
         
     | 
| 
       32 
     | 
    
         
            -
            from torch import nn
         
     | 
| 
       33 
     | 
    
         
            -
             
     | 
| 
       34 
     | 
    
         
            -
             
     | 
| 
       35 
     | 
    
         
            -
            TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
         
     | 
| 
       36 
     | 
    
         
            -
                ff_up_proj="model.layers.{}.mlp.up_proj",
         
     | 
| 
       37 
     | 
    
         
            -
                ff_down_proj="model.layers.{}.mlp.down_proj",
         
     | 
| 
       38 
     | 
    
         
            -
                ff_gate_proj="model.layers.{}.mlp.gate_proj",
         
     | 
| 
       39 
     | 
    
         
            -
                attn_query_proj="model.layers.{}.self_attn.q_proj",
         
     | 
| 
       40 
     | 
    
         
            -
                attn_key_proj="model.layers.{}.self_attn.k_proj",
         
     | 
| 
       41 
     | 
    
         
            -
                attn_value_proj="model.layers.{}.self_attn.v_proj",
         
     | 
| 
       42 
     | 
    
         
            -
                attn_output_proj="model.layers.{}.self_attn.o_proj",
         
     | 
| 
       43 
     | 
    
         
            -
                pre_attn_norm="model.layers.{}.input_layernorm",
         
     | 
| 
       44 
     | 
    
         
            -
                post_attn_norm="model.layers.{}.post_attention_layernorm",
         
     | 
| 
       45 
     | 
    
         
            -
                embedding="model.embed_tokens",
         
     | 
| 
       46 
     | 
    
         
            -
                final_norm="model.norm",
         
     | 
| 
       47 
     | 
    
         
            -
                lm_head=None,
         
     | 
| 
       48 
     | 
    
         
            -
            )
         
     | 
| 
       49 
     | 
    
         
            -
             
     | 
| 
       50 
     | 
    
         
            -
             
     | 
| 
       51 
     | 
    
         
            -
            class Gemma(nn.Module):
         
     | 
| 
       52 
     | 
    
         
            -
              """A Gemma model built from the Edge Generative API layers."""
         
     | 
| 
       53 
     | 
    
         
            -
             
     | 
| 
       54 
     | 
    
         
            -
              def __init__(self, config: cfg.ModelConfig):
         
     | 
| 
       55 
     | 
    
         
            -
                super().__init__()
         
     | 
| 
       56 
     | 
    
         
            -
             
     | 
| 
       57 
     | 
    
         
            -
                self.config = config
         
     | 
| 
       58 
     | 
    
         
            -
                # Construct model layers.
         
     | 
| 
       59 
     | 
    
         
            -
                self.tok_embedding = nn.Embedding(
         
     | 
| 
       60 
     | 
    
         
            -
                    config.vocab_size, config.embedding_dim, padding_idx=0
         
     | 
| 
       61 
     | 
    
         
            -
                )
         
     | 
| 
       62 
     | 
    
         
            -
                self.lm_head = nn.Linear(
         
     | 
| 
       63 
     | 
    
         
            -
                    config.embedding_dim,
         
     | 
| 
       64 
     | 
    
         
            -
                    config.vocab_size,
         
     | 
| 
       65 
     | 
    
         
            -
                    bias=config.lm_head_use_bias,
         
     | 
| 
       66 
     | 
    
         
            -
                )
         
     | 
| 
       67 
     | 
    
         
            -
                # Gemma re-uses the embedding as the head projection layer.
         
     | 
| 
       68 
     | 
    
         
            -
                self.lm_head.weight.data = self.tok_embedding.weight.data
         
     | 
| 
       69 
     | 
    
         
            -
                self.transformer_blocks = nn.ModuleList(
         
     | 
| 
       70 
     | 
    
         
            -
                    attention.TransformerBlock(config) for _ in range(config.num_layers)
         
     | 
| 
       71 
     | 
    
         
            -
                )
         
     | 
| 
       72 
     | 
    
         
            -
                self.final_norm = builder.build_norm(
         
     | 
| 
       73 
     | 
    
         
            -
                    config.embedding_dim,
         
     | 
| 
       74 
     | 
    
         
            -
                    config.final_norm_config,
         
     | 
| 
       75 
     | 
    
         
            -
                )
         
     | 
| 
       76 
     | 
    
         
            -
                self.rope_cache = attn_utils.build_rope_cache(
         
     | 
| 
       77 
     | 
    
         
            -
                    size=config.kv_cache_max,
         
     | 
| 
       78 
     | 
    
         
            -
                    dim=int(
         
     | 
| 
       79 
     | 
    
         
            -
                        config.attn_config.rotary_percentage * config.attn_config.head_dim
         
     | 
| 
       80 
     | 
    
         
            -
                    ),
         
     | 
| 
       81 
     | 
    
         
            -
                    base=10_000,
         
     | 
| 
       82 
     | 
    
         
            -
                    condense_ratio=1,
         
     | 
| 
       83 
     | 
    
         
            -
                    dtype=torch.float32,
         
     | 
| 
       84 
     | 
    
         
            -
                    device=torch.device("cpu"),
         
     | 
| 
       85 
     | 
    
         
            -
                )
         
     | 
| 
       86 
     | 
    
         
            -
                self.mask_cache = attn_utils.build_causal_mask_cache(
         
     | 
| 
       87 
     | 
    
         
            -
                    size=config.kv_cache_max,
         
     | 
| 
       88 
     | 
    
         
            -
                    dtype=torch.float32,
         
     | 
| 
       89 
     | 
    
         
            -
                    device=torch.device("cpu"),
         
     | 
| 
       90 
     | 
    
         
            -
                )
         
     | 
| 
       91 
     | 
    
         
            -
                self.config = config
         
     | 
| 
       92 
     | 
    
         
            -
             
     | 
| 
       93 
     | 
    
         
            -
              @torch.inference_mode
         
     | 
| 
       94 
     | 
    
         
            -
              def forward(
         
     | 
| 
       95 
     | 
    
         
            -
                  self,
         
     | 
| 
       96 
     | 
    
         
            -
                  tokens: torch.Tensor,
         
     | 
| 
       97 
     | 
    
         
            -
                  input_pos: torch.Tensor,
         
     | 
| 
       98 
     | 
    
         
            -
                  kv_cache: kv_utils.EKVCache,
         
     | 
| 
       99 
     | 
    
         
            -
              ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
         
     | 
| 
       100 
     | 
    
         
            -
                _, seq_len = tokens.size()
         
     | 
| 
       101 
     | 
    
         
            -
                assert self.config.max_seq_len >= seq_len, (
         
     | 
| 
       102 
     | 
    
         
            -
                    f"Cannot forward sequence of length {seq_len}, max seq length is only"
         
     | 
| 
       103 
     | 
    
         
            -
                    f" {self.config.max_seq_len}"
         
     | 
| 
       104 
     | 
    
         
            -
                )
         
     | 
| 
       105 
     | 
    
         
            -
             
     | 
| 
       106 
     | 
    
         
            -
                cos, sin = self.rope_cache
         
     | 
| 
       107 
     | 
    
         
            -
                cos = cos.index_select(0, input_pos)
         
     | 
| 
       108 
     | 
    
         
            -
                sin = sin.index_select(0, input_pos)
         
     | 
| 
       109 
     | 
    
         
            -
                mask = self.mask_cache.index_select(2, input_pos)
         
     | 
| 
       110 
     | 
    
         
            -
                mask = mask[:, :, :, : self.config.kv_cache_max]
         
     | 
| 
       111 
     | 
    
         
            -
             
     | 
| 
       112 
     | 
    
         
            -
                # token embeddings of shape (b, t, n_embd)
         
     | 
| 
       113 
     | 
    
         
            -
                x = self.tok_embedding(tokens)
         
     | 
| 
       114 
     | 
    
         
            -
                x = x * (self.config.embedding_dim**0.5)
         
     | 
| 
       115 
     | 
    
         
            -
             
     | 
| 
       116 
     | 
    
         
            -
                updated_kv_entires = []
         
     | 
| 
       117 
     | 
    
         
            -
                for i, block in enumerate(self.transformer_blocks):
         
     | 
| 
       118 
     | 
    
         
            -
                  kv_entry = kv_cache.caches[i] if kv_cache else None
         
     | 
| 
       119 
     | 
    
         
            -
                  x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
         
     | 
| 
       120 
     | 
    
         
            -
                  if kv_entry:
         
     | 
| 
       121 
     | 
    
         
            -
                    updated_kv_entires.append(kv_entry)
         
     | 
| 
       122 
     | 
    
         
            -
                updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
         
     | 
| 
       123 
     | 
    
         
            -
             
     | 
| 
       124 
     | 
    
         
            -
                x = self.final_norm(x)
         
     | 
| 
       125 
     | 
    
         
            -
                res = self.lm_head(x)  # (b, t, vocab_size)
         
     | 
| 
       126 
     | 
    
         
            -
                return res, updated_kv_cache
         
     | 
| 
       127 
     | 
    
         
            -
             
     | 
| 
       128 
     | 
    
         
            -
             
     | 
| 
       129 
     | 
    
         
            -
            def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
         
     | 
| 
       130 
     | 
    
         
            -
              """Returns the model config for a Gemma 2B model.
         
     | 
| 
       131 
     | 
    
         
            -
             
     | 
| 
       132 
     | 
    
         
            -
              Args:
         
     | 
| 
       133 
     | 
    
         
            -
                kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
         
     | 
| 
       134 
     | 
    
         
            -
                  is 1024.
         
     | 
| 
       135 
     | 
    
         
            -
             
     | 
| 
       136 
     | 
    
         
            -
              Returns:
         
     | 
| 
       137 
     | 
    
         
            -
                The model config for a Gemma 2B model.
         
     | 
| 
       138 
     | 
    
         
            -
              """
         
     | 
| 
       139 
     | 
    
         
            -
              attn_config = cfg.AttentionConfig(
         
     | 
| 
       140 
     | 
    
         
            -
                  num_heads=8,
         
     | 
| 
       141 
     | 
    
         
            -
                  head_dim=256,
         
     | 
| 
       142 
     | 
    
         
            -
                  num_query_groups=1,
         
     | 
| 
       143 
     | 
    
         
            -
                  rotary_percentage=1.0,
         
     | 
| 
       144 
     | 
    
         
            -
              )
         
     | 
| 
       145 
     | 
    
         
            -
              ff_config = cfg.FeedForwardConfig(
         
     | 
| 
       146 
     | 
    
         
            -
                  type=cfg.FeedForwardType.GATED,
         
     | 
| 
       147 
     | 
    
         
            -
                  activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
         
     | 
| 
       148 
     | 
    
         
            -
                  intermediate_size=16384,
         
     | 
| 
       149 
     | 
    
         
            -
              )
         
     | 
| 
       150 
     | 
    
         
            -
              norm_config = cfg.NormalizationConfig(
         
     | 
| 
       151 
     | 
    
         
            -
                  type=cfg.NormalizationType.RMS_NORM,
         
     | 
| 
       152 
     | 
    
         
            -
                  epsilon=1e-6,
         
     | 
| 
       153 
     | 
    
         
            -
                  zero_centered=True,
         
     | 
| 
       154 
     | 
    
         
            -
              )
         
     | 
| 
       155 
     | 
    
         
            -
              config = cfg.ModelConfig(
         
     | 
| 
       156 
     | 
    
         
            -
                  vocab_size=256000,
         
     | 
| 
       157 
     | 
    
         
            -
                  num_layers=18,
         
     | 
| 
       158 
     | 
    
         
            -
                  max_seq_len=8192,
         
     | 
| 
       159 
     | 
    
         
            -
                  embedding_dim=2048,
         
     | 
| 
       160 
     | 
    
         
            -
                  kv_cache_max_len=kv_cache_max_len,
         
     | 
| 
       161 
     | 
    
         
            -
                  attn_config=attn_config,
         
     | 
| 
       162 
     | 
    
         
            -
                  ff_config=ff_config,
         
     | 
| 
       163 
     | 
    
         
            -
                  pre_attention_norm_config=norm_config,
         
     | 
| 
       164 
     | 
    
         
            -
                  post_attention_norm_config=norm_config,
         
     | 
| 
       165 
     | 
    
         
            -
                  final_norm_config=norm_config,
         
     | 
| 
       166 
     | 
    
         
            -
                  parallel_residual=False,
         
     | 
| 
       167 
     | 
    
         
            -
                  lm_head_use_bias=False,
         
     | 
| 
       168 
     | 
    
         
            -
                  enable_hlfb=True,
         
     | 
| 
       169 
     | 
    
         
            -
              )
         
     | 
| 
       170 
     | 
    
         
            -
              return config
         
     | 
| 
       171 
     | 
    
         
            -
             
     | 
| 
       172 
     | 
    
         
            -
             
     | 
| 
       173 
     | 
    
         
            -
            def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
         
     | 
| 
       174 
     | 
    
         
            -
              config = get_model_config_2b(kv_cache_max_len)
         
     | 
| 
       175 
     | 
    
         
            -
              config.ff_config.intermediate_size = 128
         
     | 
| 
       176 
     | 
    
         
            -
              config.vocab_size = 128
         
     | 
| 
       177 
     | 
    
         
            -
              config.num_layers = 2
         
     | 
| 
       178 
     | 
    
         
            -
              config.max_seq_len = 2 * kv_cache_max_len
         
     | 
| 
       179 
     | 
    
         
            -
              return config
         
     | 
| 
       180 
     | 
    
         
            -
             
     | 
| 
       181 
     | 
    
         
            -
             
     | 
| 
       182 
     | 
    
         
            -
            def build_2b_model(
         
     | 
| 
       183 
     | 
    
         
            -
                checkpoint_path: str, test_model: bool = False, **kwargs
         
     | 
| 
       184 
     | 
    
         
            -
            ) -> nn.Module:
         
     | 
| 
       185 
     | 
    
         
            -
              """Instantiates the model instance and load checkpoint if provided."""
         
     | 
| 
       186 
     | 
    
         
            -
              config = (
         
     | 
| 
       187 
     | 
    
         
            -
                  get_fake_model_config(**kwargs)
         
     | 
| 
       188 
     | 
    
         
            -
                  if test_model
         
     | 
| 
       189 
     | 
    
         
            -
                  else get_model_config_2b(**kwargs)
         
     | 
| 
       190 
     | 
    
         
            -
              )
         
     | 
| 
       191 
     | 
    
         
            -
              model = Gemma(config)
         
     | 
| 
       192 
     | 
    
         
            -
              if checkpoint_path is not None:
         
     | 
| 
       193 
     | 
    
         
            -
                loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
         
     | 
| 
       194 
     | 
    
         
            -
                # since embedding and lm-head use the same weight, we need to set strict
         
     | 
| 
       195 
     | 
    
         
            -
                # to False.
         
     | 
| 
       196 
     | 
    
         
            -
                loader.load(model, strict=False)
         
     | 
| 
       197 
     | 
    
         
            -
              model.eval()
         
     | 
| 
       198 
     | 
    
         
            -
              return model
         
     | 
| 
       199 
     | 
    
         
            -
             
     | 
| 
       200 
     | 
    
         
            -
             
     | 
| 
       201 
     | 
    
         
            -
            def define_and_run_2b(checkpoint_path: str, test_model: bool = False) -> None:
         
     | 
| 
       202 
     | 
    
         
            -
              """Instantiates and runs a Gemma 2B model."""
         
     | 
| 
       203 
     | 
    
         
            -
             
     | 
| 
       204 
     | 
    
         
            -
              kv_cache_max_len = 1024
         
     | 
| 
       205 
     | 
    
         
            -
              model = build_2b_model(
         
     | 
| 
       206 
     | 
    
         
            -
                  checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
         
     | 
| 
       207 
     | 
    
         
            -
              )
         
     | 
| 
       208 
     | 
    
         
            -
              idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
         
     | 
| 
       209 
     | 
    
         
            -
              tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
         
     | 
| 
       210 
     | 
    
         
            -
              tokens[0, :4] = idx
         
     | 
| 
       211 
     | 
    
         
            -
              input_pos = torch.arange(0, kv_cache_max_len)
         
     | 
| 
       212 
     | 
    
         
            -
              kv = kv_utils.EKVCache.from_model_config(model.config)
         
     | 
| 
       213 
     | 
    
         
            -
              print("running an inference")
         
     | 
| 
       214 
     | 
    
         
            -
              print(model.forward(tokens, input_pos, kv))
         
     | 
| 
       215 
     | 
    
         
            -
             
     | 
| 
       216 
     | 
    
         
            -
             
     | 
| 
       217 
     | 
    
         
            -
            if __name__ == "__main__":
         
     | 
| 
       218 
     | 
    
         
            -
              input_checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
         
     | 
| 
       219 
     | 
    
         
            -
              define_and_run_2b(input_checkpoint_path)
         
     | 
| 
         @@ -1,14 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
       2 
     | 
    
         
            -
            #
         
     | 
| 
       3 
     | 
    
         
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       4 
     | 
    
         
            -
            # you may not use this file except in compliance with the License.
         
     | 
| 
       5 
     | 
    
         
            -
            # You may obtain a copy of the License at
         
     | 
| 
       6 
     | 
    
         
            -
            #
         
     | 
| 
       7 
     | 
    
         
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
       8 
     | 
    
         
            -
            #
         
     | 
| 
       9 
     | 
    
         
            -
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
       10 
     | 
    
         
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
       11 
     | 
    
         
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
       12 
     | 
    
         
            -
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
     | 
    
         
            -
            # limitations under the License.
         
     | 
| 
       14 
     | 
    
         
            -
            # ==============================================================================
         
     | 
| 
         @@ -1,14 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
       2 
     | 
    
         
            -
            #
         
     | 
| 
       3 
     | 
    
         
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       4 
     | 
    
         
            -
            # you may not use this file except in compliance with the License.
         
     | 
| 
       5 
     | 
    
         
            -
            # You may obtain a copy of the License at
         
     | 
| 
       6 
     | 
    
         
            -
            #
         
     | 
| 
       7 
     | 
    
         
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
       8 
     | 
    
         
            -
            #
         
     | 
| 
       9 
     | 
    
         
            -
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
       10 
     | 
    
         
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
       11 
     | 
    
         
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
       12 
     | 
    
         
            -
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
     | 
    
         
            -
            # limitations under the License.
         
     | 
| 
       14 
     | 
    
         
            -
            # ==============================================================================
         
     | 
| 
         @@ -1,87 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
       2 
     | 
    
         
            -
            #
         
     | 
| 
       3 
     | 
    
         
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       4 
     | 
    
         
            -
            # you may not use this file except in compliance with the License.
         
     | 
| 
       5 
     | 
    
         
            -
            # You may obtain a copy of the License at
         
     | 
| 
       6 
     | 
    
         
            -
            #
         
     | 
| 
       7 
     | 
    
         
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
       8 
     | 
    
         
            -
            #
         
     | 
| 
       9 
     | 
    
         
            -
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
       10 
     | 
    
         
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
       11 
     | 
    
         
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
       12 
     | 
    
         
            -
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
     | 
    
         
            -
            # limitations under the License.
         
     | 
| 
       14 
     | 
    
         
            -
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
            #
         
     | 
| 
       16 
     | 
    
         
            -
            # Note: This is an experimental version of TinyLlama with external KV cache.
         
     | 
| 
       17 
     | 
    
         
            -
            # Please use with caution.
         
     | 
| 
       18 
     | 
    
         
            -
             
     | 
| 
       19 
     | 
    
         
            -
             
     | 
| 
       20 
     | 
    
         
            -
            import os
         
     | 
| 
       21 
     | 
    
         
            -
            from pathlib import Path
         
     | 
| 
       22 
     | 
    
         
            -
             
     | 
| 
       23 
     | 
    
         
            -
            import ai_edge_torch
         
     | 
| 
       24 
     | 
    
         
            -
            from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama
         
     | 
| 
       25 
     | 
    
         
            -
            from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
         
     | 
| 
       26 
     | 
    
         
            -
            from ai_edge_torch.generative.quantize import quant_recipes
         
     | 
| 
       27 
     | 
    
         
            -
            import torch
         
     | 
| 
       28 
     | 
    
         
            -
             
     | 
| 
       29 
     | 
    
         
            -
             
     | 
| 
       30 
     | 
    
         
            -
            def convert_tiny_llama_to_tflite(
         
     | 
| 
       31 
     | 
    
         
            -
                checkpoint_path: str,
         
     | 
| 
       32 
     | 
    
         
            -
                prefill_seq_len: int = 512,
         
     | 
| 
       33 
     | 
    
         
            -
                kv_cache_max_len: int = 1024,
         
     | 
| 
       34 
     | 
    
         
            -
                quantize: bool = True,
         
     | 
| 
       35 
     | 
    
         
            -
            ):
         
     | 
| 
       36 
     | 
    
         
            -
              """An example for converting TinyLlama model to multi-signature tflite model.
         
     | 
| 
       37 
     | 
    
         
            -
             
     | 
| 
       38 
     | 
    
         
            -
              Args:
         
     | 
| 
       39 
     | 
    
         
            -
                  checkpoint_path (str): The filepath to the model checkpoint, or directory
         
     | 
| 
       40 
     | 
    
         
            -
                    holding the checkpoint.
         
     | 
| 
       41 
     | 
    
         
            -
                  prefill_seq_len (int, optional): The maximum size of prefill input tensor.
         
     | 
| 
       42 
     | 
    
         
            -
                    Defaults to 512.
         
     | 
| 
       43 
     | 
    
         
            -
                  kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
         
     | 
| 
       44 
     | 
    
         
            -
                    including both prefill and decode. Defaults to 1024.
         
     | 
| 
       45 
     | 
    
         
            -
                  quantize (bool, optional): Whether the model should be quantized. Defaults
         
     | 
| 
       46 
     | 
    
         
            -
                    to True.
         
     | 
| 
       47 
     | 
    
         
            -
              """
         
     | 
| 
       48 
     | 
    
         
            -
              pytorch_model = tiny_llama.build_model(
         
     | 
| 
       49 
     | 
    
         
            -
                  checkpoint_path, kv_cache_max_len=kv_cache_max_len
         
     | 
| 
       50 
     | 
    
         
            -
              )
         
     | 
| 
       51 
     | 
    
         
            -
              # Tensors used to trace the model graph during conversion.
         
     | 
| 
       52 
     | 
    
         
            -
              prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
         
     | 
| 
       53 
     | 
    
         
            -
              prefill_input_pos = torch.arange(0, prefill_seq_len)
         
     | 
| 
       54 
     | 
    
         
            -
              decode_token = torch.tensor([[0]], dtype=torch.long)
         
     | 
| 
       55 
     | 
    
         
            -
              decode_input_pos = torch.tensor([0], dtype=torch.int64)
         
     | 
| 
       56 
     | 
    
         
            -
              kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
         
     | 
| 
       57 
     | 
    
         
            -
             
     | 
| 
       58 
     | 
    
         
            -
              quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
         
     | 
| 
       59 
     | 
    
         
            -
              edge_model = (
         
     | 
| 
       60 
     | 
    
         
            -
                  ai_edge_torch.signature(
         
     | 
| 
       61 
     | 
    
         
            -
                      'prefill',
         
     | 
| 
       62 
     | 
    
         
            -
                      pytorch_model,
         
     | 
| 
       63 
     | 
    
         
            -
                      sample_kwargs={
         
     | 
| 
       64 
     | 
    
         
            -
                          'tokens': prefill_tokens,
         
     | 
| 
       65 
     | 
    
         
            -
                          'input_pos': prefill_input_pos,
         
     | 
| 
       66 
     | 
    
         
            -
                          'kv_cache': kv,
         
     | 
| 
       67 
     | 
    
         
            -
                      },
         
     | 
| 
       68 
     | 
    
         
            -
                  )
         
     | 
| 
       69 
     | 
    
         
            -
                  .signature(
         
     | 
| 
       70 
     | 
    
         
            -
                      'decode',
         
     | 
| 
       71 
     | 
    
         
            -
                      pytorch_model,
         
     | 
| 
       72 
     | 
    
         
            -
                      sample_kwargs={
         
     | 
| 
       73 
     | 
    
         
            -
                          'tokens': decode_token,
         
     | 
| 
       74 
     | 
    
         
            -
                          'input_pos': decode_input_pos,
         
     | 
| 
       75 
     | 
    
         
            -
                          'kv_cache': kv,
         
     | 
| 
       76 
     | 
    
         
            -
                      },
         
     | 
| 
       77 
     | 
    
         
            -
                  )
         
     | 
| 
       78 
     | 
    
         
            -
                  .convert(quant_config=quant_config)
         
     | 
| 
       79 
     | 
    
         
            -
              )
         
     | 
| 
       80 
     | 
    
         
            -
              edge_model.export(
         
     | 
| 
       81 
     | 
    
         
            -
                  f'/tmp/tiny_llama_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
         
     | 
| 
       82 
     | 
    
         
            -
              )
         
     | 
| 
       83 
     | 
    
         
            -
             
     | 
| 
       84 
     | 
    
         
            -
             
     | 
| 
       85 
     | 
    
         
            -
            if __name__ == '__main__':
         
     | 
| 
       86 
     | 
    
         
            -
              checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
         
     | 
| 
       87 
     | 
    
         
            -
              convert_tiny_llama_to_tflite(checkpoint_path)
         
     |