ai-edge-torch-nightly 0.3.0.dev20241002__py3-none-any.whl → 0.3.0.dev20241005__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (27) hide show
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +10 -93
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -1
  3. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +13 -2
  4. ai_edge_torch/generative/examples/llama/llama.py +19 -24
  5. ai_edge_torch/generative/examples/llama/verify.py +18 -3
  6. ai_edge_torch/generative/examples/openelm/openelm.py +9 -90
  7. ai_edge_torch/generative/examples/phi/phi2.py +10 -86
  8. ai_edge_torch/generative/examples/phi/phi3.py +9 -69
  9. ai_edge_torch/generative/examples/qwen/qwen.py +26 -36
  10. ai_edge_torch/generative/examples/smollm/smollm.py +10 -30
  11. ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -3
  12. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +40 -32
  13. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +11 -101
  14. ai_edge_torch/generative/layers/model_config.py +6 -0
  15. ai_edge_torch/generative/test/test_loader.py +2 -1
  16. ai_edge_torch/generative/test/test_model_conversion.py +39 -17
  17. ai_edge_torch/generative/test/test_model_conversion_large.py +6 -5
  18. ai_edge_torch/generative/utilities/model_builder.py +141 -0
  19. ai_edge_torch/lowertools/translate_recipe.py +2 -2
  20. ai_edge_torch/version.py +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/METADATA +1 -1
  22. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/RECORD +25 -26
  23. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +0 -68
  24. ai_edge_torch/generative/examples/llama/verify_3b.py +0 -73
  25. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/LICENSE +0 -0
  26. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/WHEEL +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20241002.dist-info → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,141 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utilities to be used for re-authoring transformer models."""
17
+
18
+ import copy
19
+
20
+ from ai_edge_torch.generative.layers import attention
21
+ from ai_edge_torch.generative.layers import builder
22
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
23
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
24
+ import ai_edge_torch.generative.layers.model_config as cfg
25
+ import ai_edge_torch.generative.utilities.loader as loading_utils
26
+ import torch
27
+ from torch import nn
28
+
29
# Checkpoint tensor names for the decoder-only layout used by this module.
# Per-layer names carry a "{}" placeholder (presumably filled with the layer
# index by the loader -- confirm against loading_utils.ModelLoader).
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    # Model-level tensors.
    embedding="model.embed_tokens",
    final_norm="model.norm",
    # Per-layer normalization.
    pre_attn_norm="model.layers.{}.input_layernorm",
    post_attn_norm="model.layers.{}.post_attention_layernorm",
    # Per-layer attention projections.
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    # Per-layer feed-forward projections.
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
)

# Same names as TENSOR_NAMES, with an additional top-level "lm_head" tensor.
# The copy keeps TENSOR_NAMES itself unmodified.
TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
45
+
46
+
47
class DecoderOnlyModel(nn.Module):
  """A simple decoder-only transformer model built from the Edge Generative API.

  This model is used for re-authoring. model_config is used to specify the
  details of model architecture and parameters.

  It assumes that the attention configs for ROPE, i.e. head_dim, rotary_base,
  and rotary_percentage are the same for all layers.
  """

  def __init__(self, config: cfg.ModelConfig):
    """Builds the layers and the precomputed RoPE / causal-mask caches.

    Args:
      config: Model configuration. It supplies the vocabulary and embedding
        sizes, the number of layers, the per-layer block configs, the final
        norm config and the KV cache size.
    """
    super().__init__()

    # Construct model layers.
    self.tok_embedding = nn.Embedding(
        config.vocab_size, config.embedding_dim, padding_idx=0
    )
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )
    if config.lm_head_share_weight_with_embedding:
      # Tie the LM head to the token embedding: both weights refer to the same
      # underlying tensor data.
      self.lm_head.weight.data = self.tok_embedding.weight.data
    self.transformer_blocks = nn.ModuleList(
        attention.TransformerBlock(config.block_config(idx), config)
        for idx in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    # ROPE parameters for all attn_configs are the same. Take the first one.
    attn_config = config.block_config(0).attn_config
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.kv_cache_max,
        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
        base=attn_config.rotary_base,
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max,
    )
    self.config = config

  # Called with parentheses so the decorator also works on PyTorch versions
  # that do not accept the bare `@torch.inference_mode` form.
  @torch.inference_mode()
  def forward(
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
  ) -> dict[str, "torch.Tensor | kv_utils.KVCache"]:
    """Runs the decoder on `tokens` and returns logits and the updated cache.

    Args:
      tokens: 2-D tensor of token ids; the second dimension is the sequence
        length, which must not exceed `config.max_seq_len`.
      input_pos: Index tensor of the positions being processed. It selects the
        matching entries of the RoPE cache and of the causal mask cache.
      kv_cache: KV cache holding one entry per transformer block.

    Returns:
      A dict with two entries: "logits", the LM head output, and "kv_cache",
      a new KVCache built from the entries returned by the blocks.

    Raises:
      AssertionError: If the sequence is longer than `config.max_seq_len`, or
        the number of cache entries differs from the number of blocks.
    """
    _, seq_len = tokens.size()
    assert self.config.max_seq_len >= seq_len, (
        f"Cannot forward sequence of length {seq_len}, max seq length is only"
        f" {self.config.max_seq_len}"
    )
    assert len(self.transformer_blocks) == len(kv_cache.caches), (
        "The number of transformer blocks and the number of KV cache entries"
        " must be the same."
    )

    # Select the rotary embedding and the mask for the requested positions.
    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)
    sin = sin.index_select(0, input_pos)
    mask = self.mask_cache.index_select(2, input_pos)
    mask = mask[:, :, :, : self.config.kv_cache_max]

    # token embeddings of shape (b, t, n_embd)
    x = self.tok_embedding(tokens)
    if self.config.embedding_scale is not None:
      x = x * self.config.embedding_scale

    updated_kv_entries = []
    for i, block in enumerate(self.transformer_blocks):
      kv_entry = kv_cache.caches[i] if kv_cache else None
      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
      # Only entries the block actually returned are collected.
      if kv_entry:
        updated_kv_entries.append(kv_entry)
    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

    x = self.final_norm(x)
    logits = self.lm_head(x)  # (b, t, vocab_size)
    return {"logits": logits, "kv_cache": updated_kv_cache}
128
+
129
+
130
def build_decoder_only_model(
    checkpoint_path: str,
    config: cfg.ModelConfig,
    tensor_names: loading_utils.ModelLoader.TensorNames,
) -> DecoderOnlyModel:
  """Creates a DecoderOnlyModel and loads checkpoint weights into it.

  Args:
    checkpoint_path: Location of the checkpoint handed to the model loader.
    config: Model configuration used to construct the model.
    tensor_names: Checkpoint tensor names the loader should look up.

  Returns:
    The loaded model, switched to evaluation mode.
  """
  model = DecoderOnlyModel(config)
  checkpoint_loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
  # Loading is strict unless the LM head shares its weight with the embedding
  # (presumably because such checkpoints carry no separate LM head tensor --
  # confirm against the loader).
  strict_loading = not config.lm_head_share_weight_with_embedding
  checkpoint_loader.load(model, strict=strict_loading)
  model.eval()
  return model
@@ -156,8 +156,8 @@ def translate_to_ai_edge_recipe(
156
156
 
157
157
 
158
158
  def quantize_model(
159
- model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
159
+ model: bytes, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
160
160
  ) -> bytearray:
161
- qt = quantizer.Quantizer(bytearray(model), recipe)
161
+ qt = quantizer.Quantizer(model, recipe)
162
162
  result = qt.quantize()
163
163
  return result.quantized_model
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241002"
16
+ __version__ = "0.3.0.dev20241005"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241002
3
+ Version: 0.3.0.dev20241005
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=ODx8CRsxZZYlliSx6vnHxxTorI9c0WPgrVvwGY5KAQI,706
6
+ ai_edge_torch/version.py,sha256=y5TOP0Z8qFsjIuJuJtSmzOUpHyTa9UH46RdJjtRWYQA,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -41,40 +41,38 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
41
41
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
42
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
43
43
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
44
- ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=kxWmmoVvtLP5auB3UXA2vsvZmSnpBs4SBixzYeAXzVA,6255
45
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=7VF5RYJ8QhROQNIlx-QovO-y6-jFp_EHgAkBNChZaqE,9066
44
+ ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
45
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
46
46
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
47
47
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=K77k-JpdhIwm3tbBnzpw8HQsFRwAVyszxRo82fR6-q4,1762
48
48
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=sqltZbnyKemNvKqqi9d09i74gP-PPQFodRYfDfnhycQ,4933
49
49
  ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
50
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py,sha256=_OrerrTA6tvP9Tnwj601QO95Cm8PlOiYP-mxvtmBmb4,2186
51
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=GGo6Kxiwqva4JfurGx3InU3nROW70XtYvxUwEf_6mBQ,2180
52
- ai_edge_torch/generative/examples/llama/llama.py,sha256=5vlh2Z8vEPH8Z4LoHoFYCcuOQynx4mbVE37v3yMl1hE,7162
53
- ai_edge_torch/generative/examples/llama/verify.py,sha256=7xwKM_yzLCrmFsYj1UbsjW58ZG8Yic0xw1GFkdydrCU,2525
54
- ai_edge_torch/generative/examples/llama/verify_3b.py,sha256=IijBWqLXINOfwayM-8EIpc7OcC6Nj5CnberStx-vDSk,2528
50
+ ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=P0-pByTM5tslE23ILgo7nd0nOGE25ciBRG5wKJj0bBk,2411
51
+ ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
52
+ ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
55
53
  ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
56
54
  ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
57
- ai_edge_torch/generative/examples/openelm/openelm.py,sha256=hxbpvk0fNswzbqZfGteflqKMmkH7yzeMuW6r29s_xnQ,7374
55
+ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=JsrtuUY4q1Rovxsht2cGCuANUj1sUKnah6bAoSe8AoU,4387
58
56
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
59
57
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
60
58
  ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
61
59
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
62
- ai_edge_torch/generative/examples/phi/phi2.py,sha256=82SEKRwtKfT9VcNQaykGmemiov_XaXWLi4Zyw9Vtmj0,6075
63
- ai_edge_torch/generative/examples/phi/phi3.py,sha256=Xh-l7TQdXYZJ9PViRVk2_y91Ec7Yntn0UpkuzRIG3T8,9231
60
+ ai_edge_torch/generative/examples/phi/phi2.py,sha256=CQ55KfOdoOM43CxF7yNQsgq8b-j0S50bXpxYzgq-keM,3418
61
+ ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
64
62
  ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
65
63
  ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
66
64
  ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
67
65
  ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=QAAVoSKDVf2rHAChzumGloVCWIU0Oe5UYKgv3T192Iw,2496
68
- ai_edge_torch/generative/examples/qwen/qwen.py,sha256=b03q1On6JzPhJzTs1dQwT_tJjO7C9NYmyzrzV2kQ_yo,4579
66
+ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=oYm9hhALUQ4uOn-PO1bF7fCIGP8EWRNK4zClkx2RQs8,4070
69
67
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
70
68
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
71
69
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
72
- ai_edge_torch/generative/examples/smollm/smollm.py,sha256=dal8vnZjQd6vR7sc76-FYGDKUlVjOlfUALV-pwbXJGc,3264
70
+ ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZku9kgvmlFCyIBar3kF2XEk,2570
73
71
  ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
74
72
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
75
73
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
76
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=lwWrKY1NpnbvHQRenpltVN65QlzjWmSScl5CLSipBkc,6110
77
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
74
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
75
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=i9mcBITt4jJqKLA4Qdt3uFotCrglv14tPg8VnqsVnaI,5004
78
76
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ClXNntmh0PF3s6U3C7SW3tyVrsSSrV2kyz-_RF4BcqA,15715
79
77
  ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=mBEAUYjV1qDJy9ZAsHtm9RGce0Mbzv0VoPZpdcQl1mk,33730
80
78
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
@@ -96,7 +94,7 @@ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYo
96
94
  ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=ZpjSIiayjTEVwg5Q1vI9Iy5tq1YSF5zaVDF4HTp_Z2s,4353
97
95
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
98
96
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
99
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=aSNHOAar5yPnGAeKsv8zrqYhOq9RR_7hwqHUMBb2mkM,5930
97
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=10X8HwPx4akzclnIMOBNItKQemhRbvxBbTo7nwZtWjM,2650
100
98
  ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
101
99
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
102
100
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
@@ -106,7 +104,7 @@ ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHif
106
104
  ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
107
105
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
108
106
  ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
109
- ai_edge_torch/generative/layers/model_config.py,sha256=Fa0eFCMlyfdwd3cM1drhP9vlXRhIguDrglsHn4ax2_w,6948
107
+ ai_edge_torch/generative/layers/model_config.py,sha256=xZt4xaNZJPvtdy4hfbnRencEENr689zO0WnZbhpNTIs,7137
110
108
  ai_edge_torch/generative/layers/normalization.py,sha256=cpo88JUXbF9j3sJTU4JuwOap9ryGV05C1QkPij-YQwU,6999
111
109
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
112
110
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
@@ -123,14 +121,15 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
123
121
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
124
122
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
125
123
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
126
- ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
127
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
128
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=ASXTeO9TxjhqcNwXwbyMUP07aqye7wD6JU6OGZCEmR4,8907
124
+ ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
125
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=a4TzSw8KMxEafirxqkykZi-WgTs5Z7wHp-J1AfjRDzA,6353
126
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bVCm_mubuGszCBON6oRjQXcBgPZqlVmmOaLWwhZJLio,9060
129
127
  ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
130
128
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
131
129
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
132
130
  ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
133
131
  ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
132
+ ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HBavWeuNJuYPXym9fiKCY1Tk,5278
134
133
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
135
134
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
136
135
  ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
@@ -148,7 +147,7 @@ ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjF
148
147
  ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
149
148
  ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
150
149
  ai_edge_torch/lowertools/torch_xla_utils.py,sha256=S7RWzauts-15xP6VYuM3aAd9cyAGHstYD2A4dlv3d30,9059
151
- ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
150
+ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
152
151
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
153
152
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
154
153
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
@@ -181,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
181
180
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
182
181
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
183
182
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
184
- ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
185
- ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/METADATA,sha256=l2x0NhvSM0VtobvX6i8hXWKYdfjaRUizk42xaJrQXtw,1897
186
- ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
187
- ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
188
- ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/RECORD,,
183
+ ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
+ ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/METADATA,sha256=O3P5ofz2aERMO1xbvIC7Z4RWsUNLJOZgn4pxEH3ftRc,1897
185
+ ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
186
+ ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
187
+ ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/RECORD,,
@@ -1,68 +0,0 @@
1
- # Copyright 2024 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Example of converting Llama 3.2 3B model to multi-signature tflite model."""
17
-
18
- import os
19
- import pathlib
20
-
21
- from absl import app
22
- from absl import flags
23
- from ai_edge_torch.generative.examples.llama import llama
24
- from ai_edge_torch.generative.utilities import converter
25
-
26
- _CHECKPOINT_PATH = flags.DEFINE_string(
27
- 'checkpoint_path',
28
- os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
29
- 'The path to the model checkpoint, or directory holding the checkpoint.',
30
- )
31
- _TFLITE_PATH = flags.DEFINE_string(
32
- 'tflite_path',
33
- '/tmp/',
34
- 'The tflite file path to export.',
35
- )
36
- _PREFILL_SEQ_LEN = flags.DEFINE_integer(
37
- 'prefill_seq_len',
38
- 1024,
39
- 'The maximum size of prefill input tensor.',
40
- )
41
- _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
42
- 'kv_cache_max_len',
43
- 1280,
44
- 'The maximum size of KV cache buffer, including both prefill and decode.',
45
- )
46
- _QUANTIZE = flags.DEFINE_bool(
47
- 'quantize',
48
- True,
49
- 'Whether the model should be quantized.',
50
- )
51
-
52
-
53
- def main(_):
54
- pytorch_model = llama.build_3b_model(
55
- _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
56
- )
57
- quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
58
- output_filename = f'llama_3b_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
59
- converter.convert_to_tflite(
60
- pytorch_model,
61
- tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
- prefill_seq_len=_PREFILL_SEQ_LEN.value,
63
- quantize=_QUANTIZE.value,
64
- )
65
-
66
-
67
- if __name__ == '__main__':
68
- app.run(main)
@@ -1,73 +0,0 @@
1
- # Copyright 2024 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Verifies the reauthored Llama 3.2-3B model."""
17
-
18
- import logging
19
- import pathlib
20
-
21
- from absl import app
22
- from absl import flags
23
- from ai_edge_torch.generative.examples.llama import llama
24
- from ai_edge_torch.generative.utilities import transformers_verifier
25
- from ai_edge_torch.generative.utilities import verifier
26
- import transformers
27
-
28
-
29
- _PROMPTS = flags.DEFINE_multi_string(
30
- "prompts",
31
- "What is the meaning of life?",
32
- "The input prompts to generate answers.",
33
- )
34
- _MAX_NEW_TOKENS = flags.DEFINE_integer(
35
- "max_new_tokens",
36
- 30,
37
- "The maximum size of the generated tokens.",
38
- )
39
-
40
-
41
- def main(_):
42
- checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
43
- logging.info("Loading the original model from: %s", checkpoint)
44
- original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
45
-
46
- # Locate the cached dir.
47
- cached_config_file = transformers.utils.cached_file(
48
- checkpoint, transformers.utils.CONFIG_NAME
49
- )
50
- reauthored_checkpoint = pathlib.Path(cached_config_file).parent
51
- logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
52
- reauthored_model = llama.build_3b_model(reauthored_checkpoint)
53
-
54
- logging.info("Loading the tokenizer from: %s", checkpoint)
55
- # Llama tokenizer_config.json sets a fast tokenizer class explicitly,
56
- # "PreTrainedTokenizerFast". It works only when the fast tokenizer is
57
- # available.
58
- tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
59
-
60
- verifier.verify_reauthored_model(
61
- original_model=transformers_verifier.TransformersModelWrapper(
62
- original_model
63
- ),
64
- reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
65
- tokenizer=verifier.TokenizerWrapper(tokenizer),
66
- generate_prompts=_PROMPTS.value,
67
- max_new_tokens=_MAX_NEW_TOKENS.value,
68
- atol=1e-04,
69
- )
70
-
71
-
72
- if __name__ == "__main__":
73
- app.run(main)