ai-edge-torch-nightly 0.3.0.dev20241221__py3-none-any.whl → 0.3.0.dev20241222__py3-none-any.whl

ai_edge_torch/generative/examples/gemma/gemma1.py CHANGED
@@ -72,12 +72,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  embedding_dim = 2048
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=18,
       max_seq_len=8192,
-      embedding_dim=2048,
-      embedding_scale=2048**0.5,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
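Note: `embedding_scale = embedding_dim**0.5` is the Gemma-style sqrt(d) scaling applied to token embeddings; hoisting the dimension into a local variable just keeps the two values in sync. A minimal sketch of the effect (tensor shapes and names are illustrative, not the library API):

    import torch

    embedding_dim = 2048
    tok_embedding = torch.nn.Embedding(256000, embedding_dim)
    tokens = torch.tensor([[1, 2, 3]])
    # Embeddings are multiplied by sqrt(embedding_dim), matching
    # embedding_scale=embedding_dim**0.5 in the config above.
    x = tok_embedding(tokens) * embedding_dim**0.5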
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -15,7 +15,7 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -136,29 +136,45 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
 
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
     # RoPE parameters are the same for all blocks. Use the first layer.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
     rope = rotary_pos_emb.build_rope(
         input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )
+    mask = [self.get_attention_mask(
+        self.config.block_config(i).attn_config.attn_type, input_pos
+    ) for i in range(self.config.num_layers)]
 
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
     updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
         updated_kv_entries.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
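This refactor precomputes one attention mask per layer (Gemma2 alternates GLOBAL and LOCAL_SLIDING attention, so masks can differ per block) and moves the block loop into `_forward_with_embeds`, letting callers that build their own embeddings, such as the PaliGemma2 decoder below, reuse the text decoder. A hedged sketch of the resulting call pattern, assuming `model` is a constructed `Gemma2` and `rope` has been built with `rotary_pos_emb.build_rope` as above:

    # Sketch only: mirrors the forward() path shown in the diff above.
    input_embeds = model.tok_embedding(tokens)  # (b, t, n_embd)
    mask = [
        model.get_attention_mask(
            model.config.block_config(i).attn_config.attn_type, input_pos
        )
        for i in range(model.config.num_layers)
    ]
    output = model._forward_with_embeds(
        input_embeds, rope, mask, input_pos, kv_cache
    )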
@@ -227,11 +243,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=2304,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -248,6 +266,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -55,6 +55,7 @@ class Decoder(model_builder.DecoderOnlyModel):
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
       return super().forward(tokens, input_pos, kv_cache)
@@ -75,7 +76,7 @@ class Decoder(model_builder.DecoderOnlyModel):
     mask = torch.zeros(embeds_len, self.config.kv_cache_max)
     mask[:, embeds_len:] = float("-inf")
 
-    return self.forward_with_embeds(
+    return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache
     )
 
@@ -113,12 +114,13 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  embedding_dim = 2048
   config = cfg.ModelConfig(
       vocab_size=257216,
       num_layers=18,
       max_seq_len=8192,
-      embedding_dim=2048,
-      embedding_scale=2048**0.5,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
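The prefix mask built in `Decoder.forward` above is worth unpacking: scores of 0.0 let every prefix token (image plus text) attend to every other prefix position, while -inf blocks attention to the still-empty KV-cache slots beyond the prefix. A tiny worked example, assuming embeds_len=3 and kv_cache_max=5:

    import torch

    embeds_len, kv_cache_max = 3, 5
    mask = torch.zeros(embeds_len, kv_cache_max)
    mask[:, embeds_len:] = float("-inf")
    # Each of the 3 prefix tokens attends to all 3 prefix positions
    # (score 0.0) but not to the 2 unused cache slots (-inf).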
ai_edge_torch/generative/examples/paligemma/decoder2.py ADDED
@@ -0,0 +1,172 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a decoder of PaliGemma2 3B model which is Gemma2."""
+
+from typing import Optional
+
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
+    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
+    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
+    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
+    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
+    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
+    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
+    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
+    pre_ff_norm="language_model.model.layers.{}.pre_feedforward_layernorm",
+    post_ff_norm="language_model.model.layers.{}.post_feedforward_layernorm",
+    embedding="language_model.model.embed_tokens",
+    final_norm="language_model.model.norm",
+    lm_head=None,
+)
+
+
+class Decoder2(gemma2.Gemma2):
+  """A decoder of PaliGemma2 3B model which is Gemma2.
+
+  Besides a tensor of text token IDs, forward() can also take a tensor of
+  embeddings which may include text or image or both.
+  """
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      input_embeds: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if input_embeds is None:
+      return super().forward(tokens, input_pos, kv_cache)
+
+    assert input_embeds is not None
+
+    repo_pos = input_pos + 1  # PaliGemma2 position is 1-based.
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+
+    if called_by_generate:
+      # PaliGemma2 generate() use a diagonal causal mask even with image embeds.
+      mask = [self.get_attention_mask(
+          self.config.block_config(i).attn_config.attn_type, input_pos
+      ) for i in range(self.config.num_layers)]
+    else:
+      # By default, don't mask image embeds with a diagonal causal mask.
+      embeds_len = input_embeds.shape[1]
+      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+      mask[:, embeds_len:] = float("-inf")
+      mask = [mask] * self.config.num_layers
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+
+def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for the decoder of a PaliGemma 3B model.
+  """
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=9216,
+      pre_ff_norm_config=norm_config,
+      post_ff_norm_config=norm_config,
+  )
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    attn_config = cfg.AttentionConfig(
+        num_heads=8,
+        head_dim=256,
+        num_query_groups=4,
+        rotary_base=10000,
+        rotary_percentage=1.0,
+        logit_softcap=50.0,
+        sliding_window_size=4096,
+        attn_type=(
+            cfg.AttentionType.GLOBAL
+            if idx % 2 == 0
+            else cfg.AttentionType.LOCAL_SLIDING
+        ),
+    )
+    return cfg.TransformerBlockConfig(
+        attn_config=attn_config,
+        ff_config=ff_config,
+        pre_attention_norm_config=norm_config,
+        post_attention_norm_config=norm_config,
+    )
+
+  num_layers = 26
+  embedding_dim = 2304
+  config = cfg.ModelConfig(
+      vocab_size=257216,
+      num_layers=num_layers,
+      max_seq_len=8192,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
+      final_norm_config=norm_config,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+      final_logit_softcap=30.0,
+  )
+  return config
+
+
+def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_decoder2_config(kv_cache_max_len)
+  # PaliGemma2 decoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  return config
+
+
+def build_decoder2(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_decoder2_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Decoder2,
+  )
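For context, a hedged usage sketch of the new module; the checkpoint path below is hypothetical:

    from ai_edge_torch.generative.examples.paligemma import decoder2

    # Build the Gemma2-based PaliGemma2 decoder from a local checkpoint,
    # shrinking the KV cache to 1024 positions.
    model = decoder2.build_decoder2(
        "/path/to/paligemma2-3b-pt-224", kv_cache_max_len=1024
    )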
ai_edge_torch/generative/examples/paligemma/paligemma.py CHANGED
@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import Optional
 
 from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import image_encoder
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
@@ -38,13 +39,14 @@ class PaliGemmaConfig:
   decoder_config: cfg.ModelConfig
 
   image_token_id: int
+  image_projection_scale: float
   image_projection_use_bias: bool = False
 
 
 class PaliGemma(nn.Module):
   """PaliGemma model from the Edge Generative API."""
 
-  def __init__(self, config: PaliGemmaConfig):
+  def __init__(self, config: PaliGemmaConfig, decoder_class: nn.Module):
     super().__init__()
 
     self.image_encoder = image_encoder.SiglipVisionEncoder(
@@ -55,7 +57,7 @@ class PaliGemma(nn.Module):
         config.decoder_config.embedding_dim,
         bias=config.image_projection_use_bias,
     )
-    self.decoder = decoder.Decoder(config.decoder_config)
+    self.decoder = decoder_class(config.decoder_config)
     image_embedding_config = config.image_encoder_config.image_embedding
     self.num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
@@ -70,6 +72,7 @@ class PaliGemma(nn.Module):
       kv_cache: kv_utils.KVCache,
       pixel_values: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if pixel_values is None:
       return self.decoder(
@@ -77,15 +80,15 @@ class PaliGemma(nn.Module):
           input_pos=input_pos,
           kv_cache=kv_cache,
           input_embeds=None,
-          export_config=export_config
+          export_config=export_config,
+          called_by_generate=called_by_generate,
       )
 
     input_embeds = self.decoder.tok_embedding(tokens)
 
     image_encoded = self.image_encoder(pixel_values=pixel_values)
     image_embeds = self.image_projection(image_encoded)
-    if self.config.decoder_config.embedding_scale is not None:
-      image_embeds = image_embeds / self.config.decoder_config.embedding_scale
+    image_embeds = image_embeds / self.config.image_projection_scale
 
     # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
     # can be done like:
@@ -110,10 +113,11 @@ class PaliGemma(nn.Module):
         kv_cache=kv_cache,
         input_embeds=input_embeds,
         export_config=export_config,
+        called_by_generate=called_by_generate,
     )
 
 
-def get_model_config(**kwargs) -> PaliGemmaConfig:
+def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   """Returns the model config for a PaliGemma 3B-224 model.
 
   Returns:
@@ -121,31 +125,42 @@ def get_model_config(**kwargs) -> PaliGemmaConfig:
   """
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=decoder.get_decoder_config(**kwargs),
-      image_projection_use_bias=True,
+      decoder_config=get_decoder_config(**kwargs),
       image_token_id=257152,
+      image_projection_scale=2048**0.5,
+      image_projection_use_bias=True,
   )
 
 
-def get_fake_model_config() -> PaliGemmaConfig:
+def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-      decoder_config=decoder.get_fake_decoder_config(),
-      image_projection_use_bias=True,
+      decoder_config=get_decoder_config(**kwargs),
      image_token_id=257152,
+      image_projection_scale=2048**0.5,
+      image_projection_use_bias=True,
   )
 
 
-def build_model(checkpoint_path: str, **kwargs) -> PaliGemma:
-  config = get_model_config(**kwargs)
-  model = PaliGemma(config)
+def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
+  if version == 1:
+    decoder_class = decoder.Decoder
+    decoder_tensor_names = decoder.TENSOR_NAMES
+    get_decoder_config = decoder.get_decoder_config
+  else:
+    decoder_class = decoder2.Decoder2
+    decoder_tensor_names = decoder2.TENSOR_NAMES
+    get_decoder_config = decoder2.get_decoder2_config
+
+  config = get_model_config(get_decoder_config, **kwargs)
+  model = PaliGemma(config, decoder_class)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
       checkpoint_path, image_encoder.TENSOR_NAMES
   )
   loader.load(model.image_encoder, strict=False)
   # Load the parameters of decoder.
-  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
+  loader = loading_utils.ModelLoader(checkpoint_path, decoder_tensor_names)
   loader.load(model.decoder, strict=False)
 
   # Load the parameters of image projection.
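With the version switch above, callers choose the decoder family at build time. A minimal sketch (the checkpoint path is hypothetical):

    from ai_edge_torch.generative.examples.paligemma import paligemma

    # version=1 selects the original Gemma1-based Decoder;
    # version=2 (the default) selects the Gemma2-based Decoder2.
    model = paligemma.build_model("/path/to/checkpoint", version=2)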
ai_edge_torch/generative/examples/paligemma/verify.py CHANGED
@@ -22,11 +22,18 @@ from absl import flags
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import verifier
+import kagglehub
 from PIL import Image
 import requests
 import torch
 import transformers
 
+_VERSION = flags.DEFINE_enum(
+    "version",
+    "1",
+    ["1", "2"],
+    "The version of PaliGemma model to verify.",
+)
 _IMAGE_URL = flags.DEFINE_string(
     "image_url",
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
@@ -34,7 +41,7 @@ _IMAGE_URL = flags.DEFINE_string(
 )
 _PROMPTS = flags.DEFINE_string(
     "prompts",
-    "Caption en",
+    "describe en",
     "The input prompts to generate answers.",
 )
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
@@ -43,28 +50,47 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "The maximum size of the generated tokens.",
 )
 
+_CHECKPOINT = {
+    "1": "google/paligemma-3b-mix-224",
+    "2": "google/paligemma-2/transformers/paligemma2-3b-pt-224",
+}
+
 
 class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
   """Reauthored PaliGemma model wrapper."""
 
+  def __init__(self, model: torch.nn.Module):
+    super().__init__(model)
+    self.forward_called_by_generate = False
+
   def _init_kv_cache(self):
     return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
 
+  def _get_extra_args_for_forward(self):
+    return {"called_by_generate": self.forward_called_by_generate}
+
 
 def main(_):
-  checkpoint = "google/paligemma-3b-mix-224"
+  if _VERSION.value == "1":
+    checkpoint = _CHECKPOINT[_VERSION.value]
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    checkpoint = kagglehub.model_download(_CHECKPOINT[_VERSION.value])
+    reauthored_checkpoint = checkpoint
+
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = (
       transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
   )
 
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = paligemma.build_model(reauthored_checkpoint)
+  reauthored_model = paligemma.build_model(
+      reauthored_checkpoint, version=int(_VERSION.value)
+  )
 
   logging.info("Loading the processor from: %s", checkpoint)
   # It works only when GemmaTokenizerFast is available. In some environments,
@@ -93,7 +119,7 @@ def main(_):
   logging.info("outputs_reauthored: %s", outputs_reauthored)
 
   try:
-    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-03)
+    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-02)
   except AssertionError as e:
     logging.error("*** FAILED *** verify with forward()")
     raise e
@@ -111,6 +137,7 @@ def main(_):
   logging.info("outputs_from_original_model: [[%s]]", response_original)
 
   logging.info("Generating answer with the reauthored model...")
+  wrapped_reauthored_model.forward_called_by_generate = True
   outputs_reauthored = wrapped_reauthored_model.generate(
       prompts=inputs["input_ids"],
       pixel_values=inputs["pixel_values"],
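The wrapper's `forward_called_by_generate` flag exists because PaliGemma masks differently in the two verification phases: the forward() comparison uses the non-causal prefix mask, while generate() uses the diagonal causal mask. A sketch of the flow, with names taken from the script above:

    wrapper = ReauthoredPaliGemmaWrapper(reauthored_model)
    # forward() verification runs with the flag at its default of False;
    # flip it before generate() so the decoder builds causal masks.
    wrapper.forward_called_by_generate = True
    outputs = wrapper.generate(
        prompts=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=_MAX_NEW_TOKENS.value,
    )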
ai_edge_torch/generative/examples/paligemma/verify_decoder2.py ADDED
@@ -0,0 +1,72 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored decoder of PaliGemma2 3B model."""
+
+import logging
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import decoder2
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download(
+      "google/paligemma-2/transformers/paligemma2-3b-pt-224"
+  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_full_model = (
+      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+  )
+  original_language_model = original_full_model.eval().language_model
+
+  logging.info("Building the reauthored model from: %s", checkpoint)
+  reauthored_model = decoder2.build_decoder2(checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # It works only when GemmaTokenizerFast is available. In some environments,
+  # use_fast=False doesn't work either if the tokenizer cannot load the
+  # sentencepiece model file properly.
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_language_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py CHANGED
@@ -20,31 +20,48 @@ import pathlib
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.paligemma import image_encoder
+import kagglehub
 from PIL import Image
 import requests
 import torch
 import transformers
 
+_VERSION = flags.DEFINE_enum(
+    "version",
+    "1",
+    ["1", "2"],
+    "The version of PaliGemma vision model to verify.",
+)
 _IMAGE_URL = flags.DEFINE_string(
     "image_url",
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
     "The image URI to encode.",
 )
 
+_CHECKPOINT = {
+    "1": "google/paligemma-3b-mix-224",
+    "2": "google/paligemma-2/transformers/paligemma2-3b-pt-224",
+}
+
 
 def main(_):
-  checkpoint = "google/paligemma-3b-mix-224"
+  if _VERSION.value == "1":
+    checkpoint = _CHECKPOINT[_VERSION.value]
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    checkpoint = kagglehub.model_download(_CHECKPOINT[_VERSION.value])
+    reauthored_checkpoint = checkpoint
+
   logging.info("Loading the original model from: %s", checkpoint)
   original_full_model = (
       transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
   )
   original_vision_model = original_full_model.eval().vision_tower
 
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = image_encoder.build_image_encoder(reauthored_checkpoint)
 
@@ -69,7 +86,7 @@ def main(_):
 
   try:
     assert torch.allclose(
-        outputs_original, outputs_reauthored, atol=1e-04, rtol=1e-04
+        outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04
     )
   except AssertionError as e:
     logging.error("*** FAILED *** verify with an image")
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -118,11 +118,11 @@ class DecoderOnlyModel(nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
-    return self.forward_with_embeds(
+    return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
 
-  def forward_with_embeds(
+  def _forward_with_embeds(
      self,
       input_embeds: torch.Tensor,
       rope: Tuple[torch.Tensor, torch.Tensor],
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -16,7 +16,7 @@
 """Common utility functions to verify the reauthored models."""
 
 import logging
-from typing import List
+from typing import Any, List
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -87,6 +87,10 @@ class ReauthoredModelWrapper(ModelWrapper):
     """Returns an initialized KV cache."""
     return kv_utils.KVCache.from_model_config(self.model.config)
 
+  def _get_extra_args_for_forward(self) -> dict[str, Any]:
+    """Returns extra arguments for the forward() method."""
+    return {}
+
   def _forward_with_kv_cache(
       self,
       tokens: torch.Tensor,
@@ -105,26 +109,15 @@ class ReauthoredModelWrapper(ModelWrapper):
     Returns:
       The output logits and the updated KV cache.
     """
-    # Verification requires logit outputs on prefill for comparison.
-    if (
-        self.export_config is not None
-        and not self.export_config.output_logits_on_prefill
-    ):
-      raise ValueError("Verifier requires logit output on prefill.")
-    # Since the reauthored model doesn't include keyword arguments, pass
-    # pixel_values only when it is not None. Otherwise, it may raise an error.
-    if pixel_values is None:
-      output = self.model.forward(
-          tokens, input_pos, kv_cache, export_config=self.export_config
-      )
-    else:
-      output = self.model.forward(
-          tokens,
-          input_pos,
-          kv_cache,
-          pixel_values=pixel_values,
-          export_config=self.export_config,
-      )
+    extra_args = self._get_extra_args_for_forward()
+    if self.export_config is not None:
+      # Verification requires logit outputs on prefill for comparison.
+      if not self.export_config.output_logits_on_prefill:
+        raise ValueError("Verifier requires logit output on prefill.")
+      extra_args["export_config"] = self.export_config
+    if pixel_values is not None:
+      extra_args["pixel_values"] = pixel_values
+    output = self.model.forward(tokens, input_pos, kv_cache, **extra_args)
     return output["logits"], output["kv_cache"]
 
   def forward(
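The new hook lets subclasses inject model-specific keyword arguments into every forward() call without reimplementing `_forward_with_kv_cache`; the `ReauthoredPaliGemmaWrapper` in verify.py above does exactly this. A minimal sketch:

    from typing import Any

    class MyWrapper(ReauthoredModelWrapper):

      def _get_extra_args_for_forward(self) -> dict[str, Any]:
        # Hypothetical extra kwarg merged into model.forward() calls.
        return {"called_by_generate": False}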
@@ -141,6 +134,7 @@ class ReauthoredModelWrapper(ModelWrapper):
       prompts: torch.Tensor,
       max_new_tokens: int,
       pixel_values: torch.Tensor = None,
+      eos_token_id: int = 1,
   ) -> torch.IntTensor:
     input_ids = prompts[0].int().tolist()
     tokens = torch.tensor([input_ids])
@@ -152,6 +146,8 @@ class ReauthoredModelWrapper(ModelWrapper):
       )
       generated_token = logits[0][-1].argmax().item()
       input_ids.append(generated_token)
+      if generated_token == eos_token_id:
+        break
       tokens = torch.tensor([[generated_token]])
       input_pos = torch.tensor([len(input_ids) - 1])
       pixel_values = None  # Pass only for the first time.
@@ -254,7 +250,11 @@ def verify_model_with_prompts(
   logging.info("outputs_from_original_model: [[%s]]", response_original)
 
   logging.info("Generating answer with the reauthored model...")
-  outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
+  outputs_reauthored = reauthored_model.generate(
+      prompt_tokens,
+      max_new_tokens,
+      eos_token_id=tokenizer.tokenizer.eos_token_id,
+  )
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
   logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241221"
+__version__ = "0.3.0.dev20241222"
{ai_edge_torch_nightly-0.3.0.dev20241221.dist-info → ai_edge_torch_nightly-0.3.0.dev20241222.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241221
+Version: 0.3.0.dev20241222
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241221.dist-info → ai_edge_torch_nightly-0.3.0.dev20241222.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=4pSrONNJgkt6DeTfleRz5DpcHts3SW-iInT2ibr1t9A,706
+ai_edge_torch/version.py,sha256=PKEPravHVUIDugudfMDzqU57wXbpvrsY94puBM6FS-c,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -46,8 +46,8 @@ ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=GhwtQZ1xuMyKJl8qdxU6uKavQnlm5US9xhKJvdmgACc,2309
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=hsy4Gd7Inchi0p_Cc5yecH6vr9A7X4MvmQNfTt8N2sQ,2311
-ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=N0jKVZA3qWKOaHVbIM3WmQh3u0Sq7MTw_oO3Zo16wCw,3456
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=roEwWVXASbk5BFj7jojjEJpHui6gCelT51l-TtN_ZaQ,9367
+ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=VTM2nO3TqK2d1DyEb2MiHc-Tyw2lMcUXyOhvg0H5ENY,10147
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
@@ -64,12 +64,14 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2u
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=rPFqcsv8RHvjmgfBW9OL6EKxMtVX-ySjBsMP4N8FErk,2816
-ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=HDDTd4F0kOurhXyqikP5umdY0gVm-FHA1ysaKcz88CM,5261
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=DDVFHGqRbJgnLT4XJRYJ-MAp2-xPnI4fAUGSYVNMprc,5342
+ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=EjNbZwXM_T_0FXgHUAtLupihPsNlhPWeOop3IJ10Wzg,6320
 ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
-ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nDyI-wUFJSawu57uLbFENei5l4cciqZ8lM5S5beN0FU,5604
-ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
+ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nDBFu_uVzdARH06BU6xRerVdjahSCm39nQcYigJVoHE,6261
+ai_edge_torch/generative/examples/paligemma/verify.py,sha256=__RUyh0L5Td2jbm1xGnSldbfpKHtxyXAh2h06KVGxLA,5622
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
-ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
+ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
+ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=qaROQSjgs0DtVOX4KS5kPmlDrBFn0yJr83_kWIN8NzM,3540
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=cD8rtwgYeGrXB9sYVV_D1AB8Up1AWNS-1XtrRlyzE5o,2296
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=G1i_ybDCTBaOD1OOCTk6jqOf__xYYZvhXcxY8MXhPHw,2294
@@ -147,12 +149,12 @@ ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5l
 ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=plKHp5csnZpx3GQ1SYTqFpdoaxTVcwXgCmzO5N6ya6I,6350
+ai_edge_torch/generative/utilities/model_builder.py,sha256=S08WNqVKCmxd2QjtMlwETd7J97UnlME_bTKdz5LMkGU,6352
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
-ai_edge_torch/generative/utilities/verifier.py,sha256=ESSA8W1EYNsd4ntwmXbr-dn-BcIS27hf53XL5RTwjEU,11941
+ai_edge_torch/generative/utilities/verifier.py,sha256=awO-sQrEpsFxIkZw72ysWZenYEmkLOLOuj62o2c7XeQ,11994
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -201,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/METADATA,sha256=_mQiElLiIpig6KWylK15amdyQP57haDyWH4Xaqqt_Ls,1966
-ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/METADATA,sha256=0-7zfD8burp8x7iTlCrOe2JO8BZuV1zcRmMrcsGFjVk,1966
+ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241222.dist-info/RECORD,,