ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240914__py3-none-any.whl
This diff shows the content changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +5 -41
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -4
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +4 -5
- ai_edge_torch/config.py +4 -1
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +35 -16
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/gemma/gemma.py +52 -32
- ai_edge_torch/generative/examples/gemma/gemma2.py +87 -60
- ai_edge_torch/generative/examples/{experimental/gemma → openelm}/convert_to_tflite.py +16 -18
- ai_edge_torch/generative/examples/openelm/openelm.py +237 -0
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +15 -16
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +48 -45
- ai_edge_torch/generative/examples/{experimental/tiny_llama → smollm}/convert_to_tflite.py +16 -17
- ai_edge_torch/generative/examples/smollm/smollm.py +131 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -6
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +20 -20
- ai_edge_torch/generative/examples/t5/t5.py +43 -30
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +75 -34
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +29 -10
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +57 -36
- ai_edge_torch/generative/fx_passes/__init__.py +4 -4
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -4
- ai_edge_torch/generative/layers/attention.py +84 -73
- ai_edge_torch/generative/layers/builder.py +38 -14
- ai_edge_torch/generative/layers/feed_forward.py +26 -8
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +61 -33
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/quantize/example.py +2 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +77 -62
- ai_edge_torch/generative/test/test_model_conversion_large.py +61 -68
- ai_edge_torch/generative/test/test_quantize.py +5 -5
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +28 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/export.py +40 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +44 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD +59 -63
- ai_edge_torch/_convert/fx_passes/_pass_base.py +0 -53
- ai_edge_torch/_convert/fx_passes/canonicalize_pass.py +0 -35
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → openelm}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/phi → smollm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/top_level.txt +0 -0
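Taken together, the renames above promote the external-KV-cache examples out of the experimental namespace: `experimental/phi` becomes `phi` (replacing the removed legacy `phi2` example package), `experimental/tiny_llama`'s conversion script seeds the new `smollm` example, and `experimental/gemma`'s seeds the new `openelm` example. A hypothetical downstream migration sketch, using only module paths from the rename list:

```python
# Before (0.3.0.dev20240910), the external-KV-cache Phi-2 example was
# experimental, and a separate legacy phi2 example package also existed:
#   from ai_edge_torch.generative.examples.experimental.phi import phi2
#   from ai_edge_torch.generative.examples.phi2 import phi2  # removed in this release
# After (0.3.0.dev20240914), there is a single stable path:
from ai_edge_torch.generative.examples.phi import phi2
```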
{ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240914.dist-info}/RECORD

@@ -1,26 +1,25 @@
 ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
-ai_edge_torch/config.py,sha256=
+ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
+ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=rrWwWO1VLdM1khgk2URt5vN4icTeaTqw8CEIsnJRM0E,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=
+ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=ezmaATnQi7NWDo37LUb-hEXtZSmT7_AT6vqXC6Fcq1o,8615
 ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
-ai_edge_torch/_convert/fx_passes/__init__.py,sha256=
-ai_edge_torch/_convert/fx_passes/
-ai_edge_torch/_convert/fx_passes/
-ai_edge_torch/_convert/fx_passes/
-ai_edge_torch/_convert/fx_passes/canonicalize_pass.py,sha256=8jcKqWzG7p5r3Cu7DXNP-4o4X2bqLaoXY7N6W8QsZXo,1582
-ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=WKI8V9-V50agkiNVpBFWWp0BEpUfemdENuN1cEaGD-g,2370
+ai_edge_torch/_convert/fx_passes/__init__.py,sha256=xuRI0WehbUlxLHvuYjj8MeyIKBtcCp10D3E1uD1MRdw,1168
+ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
+ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=vd2phg5j3Exn6BuGpASe5cU_wY4JV_YcNTssM6Q9k2c,4169
+ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=NW37V6QYdPOZOVhqLcmssVk-VAeO4ECk_CrbEBh4B0E,12740
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=t94Am3iPbYQekg-rrtc-jS_aDWtEgAAj7pAKHrG0-9U,10563
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=D8VX8SbCzfoyvPgMFHK7yxD7R-bzLxp2gfdKxgrWekA,742
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
@@ -39,28 +38,24 @@ ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOa
 ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
-ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=aCoD86pf4nuquUMk7MOR-jsN5FqvySSEuMx9Psxjblk,7261
-ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
-ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=Jxf3ZyYDpS78l6uh4_LGGIcHawrOhZ1vHoHFVxRaK40,6789
-ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
-ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=nUm0SQbCTmNAc5u-C9gbQRFPt7GDvUt6UjH6doTvH-I,6817
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/gemma/gemma.py,sha256=
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
-ai_edge_torch/generative/examples/
-ai_edge_torch/generative/examples/
-ai_edge_torch/generative/examples/
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=bN_dtqi5C_dHpLsvXJ9vCb9OnZ0frLeyYoWBXZYJEqA,3061
+ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=fiFKkEe3TgOdpLnzsCZzIdwvEz0ikxDavQcRGQhlkBY,3053
+ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2PSTnm30Mp0ajYYtDivo,7489
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
+ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=DgBuR1uq4YQWfWiENBxrx7UCVr4Jc5kWCyoi6ii5DTE,3058
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=kQTJlCDz_DHLRLlVWE0JEpbOjIGAKtxH1fTSc-jn1nU,8498
+ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=_tP5ArL0FKiBNoOqN2rG351IzmhNKQmWUfewlcSdKDs,3024
+ai_edge_torch/generative/examples/phi/phi2.py,sha256=mGyBI-nORoI-LhZkI4MFAonkUflIX9iimAer_K8jpck,7088
+ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=66APmBId5UayZ7SWSO1zxcLiM8TucOMA-fFEHhm61qs,3049
+ai_edge_torch/generative/examples/smollm/smollm.py,sha256=_nK2DAOiSuxv5o8ip0i-gmhvvjwF5e7Dm3m5VTcsR2M,4276
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
-ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=
-ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
+ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
 ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
 ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7oUIJ6HO0vmlhFdkXpqGm9KTB-eM4Ob9VrHSDlIGFOg,30926
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
@@ -73,49 +68,49 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py
 ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=ZE6HyOoBJrmTh54KVFf7DjNBnBS0pT4cgviYaq8HGMU,2801
 ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=RxR5rw0wFFm_5CfAY-3-EIz83vhM9EKye8Bb5zBb0Ok,1341
 ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/t5/t5.py,sha256=
-ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=
+ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=HHtZTtUh3QgE4F74-ru_8n1pt6cqfbObw12xoaMJ7NQ,4596
+ai_edge_torch/generative/examples/t5/t5.py,sha256=OZ67knK-UB1dBjxydG-Jwkp0Z3FzOCqGPTdg5aBFu4w,21328
+ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=
-ai_edge_torch/generative/examples/test_models/
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
+ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
-ai_edge_torch/generative/fx_passes/__init__.py,sha256=
-ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=5u6aOiCVahHNCgax5k9a8uhJn9eMzLa19ldscFKNyWo,3083
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=Upo8jjqR0VKvkdczTI-Lr-1GDg0R2g4SUUGEMTUZ5uY,7023
+ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
+ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=37Fua94dQSiBA9Y5XvHxGb5IfN8p8UgNgu5YwM1Rmrw,13057
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
-ai_edge_torch/generative/layers/builder.py,sha256=
-ai_edge_torch/generative/layers/feed_forward.py,sha256=
-ai_edge_torch/generative/layers/kv_cache.py,sha256=
-ai_edge_torch/generative/layers/model_config.py,sha256=
-ai_edge_torch/generative/layers/normalization.py,sha256=
+ai_edge_torch/generative/layers/builder.py,sha256=iuAv8D7HY-azBDy7-UBILMdjuKjpe38rE2gK4H3erwE,5092
+ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
+ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
+ai_edge_torch/generative/layers/model_config.py,sha256=zV3pA7giuKPrQdH81dpZz8D6LfGD-1YHuXuhIlypKc0,6784
+ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvCU5a88ESd_s5WY34ErKA,6129
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=cpygyJccLq6KHKxV7oz4YKh529YLjC9isupnsVmPi0A,27190
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
 ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/quantize/example.py,sha256=
+ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
 ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1UHAwdbChkgPShiVaz4CE,5156
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
 ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/test/
-ai_edge_torch/generative/test/test_loader.py,sha256=
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
-ai_edge_torch/generative/test/test_quantize.py,sha256=
+ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
+ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=DBlqxW2IT-dZYzEfOMAp86Wtqiu6kgSWZ9BKZR1Clrw,5467
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=TD7dELN5cVw5z9dvspFKO74Y_qIJ_VK0MYUoPdRf82Y,4498
+ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
+ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/loader.py,sha256=
+ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
-ai_edge_torch/generative/utilities/t5_loader.py,sha256=
+ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -133,7 +128,7 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGP
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=
+ai_edge_torch/odml_torch/export.py,sha256=q6cvB3RAKL7hlQFNGqsE3u-NqWWPSzj-8M38u8loSpk,11544
 ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
 ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -145,11 +140,12 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
-ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=
-ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
+ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
+ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=E5j_xHuyDmA9fcgoi6p04zLGV9mFleyXzx6jSBi2wD0,8529
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=RN6BwMHuFj_rFgLCZ6Tu32XHbS2HGjPJeir2nROQ2rA,10517
+ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
 ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
@@ -161,8 +157,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20240914.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240914.dist-info/METADATA,sha256=6NayY4sdwm5Z4jmaIhk17MIQ3_plQOiWX_gGnL3KwPQ,1859
+ai_edge_torch_nightly-0.3.0.dev20240914.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240914.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240914.dist-info/RECORD,,
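Each RECORD entry above follows the wheel RECORD format, `path,sha256=<digest>,<size>`, where the digest is the unpadded urlsafe-base64 SHA-256 of the file contents (PEP 376/PEP 627). A small sketch for verifying one entry against an unpacked wheel; `check_record_entry` is a hypothetical helper, not part of this package:

```python
import base64
import hashlib
from pathlib import Path


def check_record_entry(entry: str, wheel_root: str = ".") -> bool:
  """Checks one RECORD line of the form path,sha256=<digest>,<size>."""
  path, hash_field, size = entry.rsplit(",", 2)
  algo, _, expected = hash_field.partition("=")
  assert algo == "sha256", f"unexpected hash algorithm: {algo}"
  data = (Path(wheel_root) / path).read_bytes()
  # RECORD digests are urlsafe base64 with the trailing '=' padding removed.
  digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
  return digest.decode() == expected and len(data) == int(size)


# E.g., against an unpacked 0.3.0.dev20240914 wheel:
# check_record_entry("ai_edge_torch/version.py,sha256=rrWwWO1VLdM1khgk2URt5vN4icTeaTqw8CEIsnJRM0E,706")
```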
ai_edge_torch/_convert/fx_passes/_pass_base.py (deleted)

@@ -1,53 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-import abc
-from collections import namedtuple
-
-import torch
-from torch.export import ExportedProgram
-from torch.fx.passes.infra.pass_base import PassBase as FxPassBase
-from torch.fx.passes.infra.pass_base import PassResult as FxPassResult
-
-
-class ExportedProgramPassResult(
-    namedtuple("ExportedProgramPassResult", ["exported_program", "modified"])
-):
-
-  def __new__(cls, exported_program, modified):
-    return super().__new__(cls, exported_program, modified)
-
-
-class ExportedProgramPassBase(abc.ABC):
-
-  def __call__(
-      self, exported_program: ExportedProgram
-  ) -> ExportedProgramPassResult:
-    self.requires(exported_program)
-    res = self.call(exported_program)
-    self.ensures(exported_program)
-    return res
-
-  @abc.abstractmethod
-  def call(
-      self, exported_program: ExportedProgram
-  ) -> ExportedProgramPassResult:
-    pass
-
-  def requires(self, exported_program: ExportedProgram) -> None:
-    pass
-
-  def ensures(self, exported_program: ExportedProgram) -> None:
-    pass
ai_edge_torch/_convert/fx_passes/canonicalize_pass.py (deleted)

@@ -1,35 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from ai_edge_torch._convert.fx_passes import _pass_base
-import torch
-from torch.export import ExportedProgram
-
-# A dummy decomp table for running ExportedProgram.run_decompositions without
-# any op decompositions but just aot_export_module. Due to the check in
-# run_decompositions, if None or an empty dict is passed as decomp_table,
-# it will run the default aten-coreaten decompositions. Therefore a non-empty
-# dummy decomp table is needed.
-# Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
-_dummy_decomp_table = {
-    torch._ops.OperatorBase(): lambda: None,
-}
-
-
-class CanonicalizePass(_pass_base.ExportedProgramPassBase):
-
-  def call(self, exported_program: ExportedProgram):
-    exported_program = exported_program.run_decompositions(_dummy_decomp_table)
-    return _pass_base.ExportedProgramPassResult(exported_program, True)
ai_edge_torch/generative/examples/experimental/gemma/gemma.py (deleted)

@@ -1,219 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# Example of building a Gemma model.
-#
-# Note: This is an experimental version of Gemma with external KV cache.
-# Please use with caution.
-
-import os
-from pathlib import Path
-from typing import Tuple
-
-from ai_edge_torch.generative.layers import builder
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
-import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
-import torch
-from torch import nn
-
-
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="model.layers.{}.mlp.up_proj",
-    ff_down_proj="model.layers.{}.mlp.down_proj",
-    ff_gate_proj="model.layers.{}.mlp.gate_proj",
-    attn_query_proj="model.layers.{}.self_attn.q_proj",
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
-    attn_output_proj="model.layers.{}.self_attn.o_proj",
-    pre_attn_norm="model.layers.{}.input_layernorm",
-    post_attn_norm="model.layers.{}.post_attention_layernorm",
-    embedding="model.embed_tokens",
-    final_norm="model.norm",
-    lm_head=None,
-)
-
-
-class Gemma(nn.Module):
-  """A Gemma model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    self.config = config
-    # Construct model layers.
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    self.lm_head = nn.Linear(
-        config.embedding_dim,
-        config.vocab_size,
-        bias=config.lm_head_use_bias,
-    )
-    # Gemma re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.EKVCache,
-  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res, updated_kv_cache
-
-
-def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Gemma 2B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Gemma 2B model.
-  """
-  attn_config = cfg.AttentionConfig(
-      num_heads=8,
-      head_dim=256,
-      num_query_groups=1,
-      rotary_percentage=1.0,
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-      intermediate_size=16384,
-  )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-  )
-  config = cfg.ModelConfig(
-      vocab_size=256000,
-      num_layers=18,
-      max_seq_len=8192,
-      embedding_dim=2048,
-      kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
-      final_norm_config=norm_config,
-      parallel_residual=False,
-      lm_head_use_bias=False,
-      enable_hlfb=True,
-  )
-  return config
-
-
-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config_2b(kv_cache_max_len)
-  config.ff_config.intermediate_size = 128
-  config.vocab_size = 128
-  config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
-  return config
-
-
-def build_2b_model(
-    checkpoint_path: str, test_model: bool = False, **kwargs
-) -> nn.Module:
-  """Instantiates the model instance and load checkpoint if provided."""
-  config = (
-      get_fake_model_config(**kwargs)
-      if test_model
-      else get_model_config_2b(**kwargs)
-  )
-  model = Gemma(config)
-  if checkpoint_path is not None:
-    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-    # since embedding and lm-head use the same weight, we need to set strict
-    # to False.
-    loader.load(model, strict=False)
-  model.eval()
-  return model
-
-
-def define_and_run_2b(checkpoint_path: str, test_model: bool = False) -> None:
-  """Instantiates and runs a Gemma 2B model."""
-
-  kv_cache_max_len = 1024
-  model = build_2b_model(
-      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-  )
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
-  kv = kv_utils.EKVCache.from_model_config(model.config)
-  print("running an inference")
-  print(model.forward(tokens, input_pos, kv))
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
-  define_and_run_2b(input_checkpoint_path)
One of the deleted license-only __init__.py files:

@@ -1,14 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================