ai-edge-torch-nightly 0.3.0.dev20241220-py3-none-any.whl → 0.3.0.dev20241224-py3-none-any.whl

Files changed (23)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +3 -2
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +32 -13
  3. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -3
  4. ai_edge_torch/generative/examples/paligemma/decoder.py +14 -5
  5. ai_edge_torch/generative/examples/paligemma/decoder2.py +174 -0
  6. ai_edge_torch/generative/examples/paligemma/paligemma.py +30 -15
  7. ai_edge_torch/generative/examples/paligemma/verify.py +36 -9
  8. ai_edge_torch/generative/examples/paligemma/verify_decoder2.py +72 -0
  9. ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +24 -7
  10. ai_edge_torch/generative/test/test_model_conversion_large.py +28 -9
  11. ai_edge_torch/generative/utilities/model_builder.py +5 -4
  12. ai_edge_torch/generative/utilities/verifier.py +22 -22
  13. ai_edge_torch/odml_torch/export.py +6 -1
  14. ai_edge_torch/odml_torch/jax_bridge/__init__.py +4 -1
  15. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  16. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -2
  17. ai_edge_torch/odml_torch/lowerings/_rand.py +142 -0
  18. ai_edge_torch/version.py +1 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20241220.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20241220.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/RECORD +23 -20
  21. {ai_edge_torch_nightly-0.3.0.dev20241220.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/LICENSE +0 -0
  22. {ai_edge_torch_nightly-0.3.0.dev20241220.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/WHEEL +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20241220.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma1.py CHANGED
@@ -72,12 +72,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  embedding_dim = 2048
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=18,
       max_seq_len=8192,
-      embedding_dim=2048,
-      embedding_scale=2048**0.5,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -15,7 +15,7 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -136,29 +136,45 @@ class Gemma2(nn.Module):
           f"Cannot forward sequence of length {seq_len}, max seq length is only"
           f" {self.config.max_seq_len}"
       )
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
 
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
     # RoPE parameters are the same for all blocks. Use the first layer.
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
     rope = rotary_pos_emb.build_rope(
         input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )
+    mask = [self.get_attention_mask(
+        self.config.block_config(i).attn_config.attn_type, input_pos
+    ) for i in range(self.config.num_layers)]
 
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
     updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
         updated_kv_entries.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
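
The refactor above precomputes one attention mask per transformer block before the block loop, since Gemma2 alternates global and local sliding-window attention by layer index. A minimal standalone sketch of the pattern; the mask builder and window size here are illustrative stand-ins, not the library's internals:

import torch

def build_mask(attn_type: str, input_pos: torch.Tensor) -> torch.Tensor:
  # Illustrative stand-in for Gemma2.get_attention_mask(): a causal mask,
  # additionally windowed for local sliding-window layers.
  n = input_pos.numel()
  mask = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
  if attn_type == "LOCAL_SLIDING":
    # Toy window size; the real config uses sliding_window_size=4096.
    mask = mask + torch.tril(torch.full((n, n), float("-inf")), diagonal=-4)
  return mask

input_pos = torch.arange(8)
num_layers = 4
# One mask per layer (GLOBAL on even indices, LOCAL_SLIDING on odd), consumed
# by the block loop as mask[i], matching the diff above.
mask = [
    build_mask("GLOBAL" if i % 2 == 0 else "LOCAL_SLIDING", input_pos)
    for i in range(num_layers)
]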
@@ -227,11 +243,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=2304,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -248,6 +266,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
@@ -29,9 +29,15 @@ from ai_edge_torch.generative.utilities import converter
29
29
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
30
30
  import torch
31
31
 
32
+ _VERSION = flags.DEFINE_enum(
33
+ 'version',
34
+ '2',
35
+ ['1', '2'],
36
+ 'The version of PaliGemma model to verify.',
37
+ )
32
38
  _CHECKPOINT_PATH = flags.DEFINE_string(
33
39
  'checkpoint_path',
34
- os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
40
+ os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
35
41
  'The path to the model checkpoint, or directory holding the checkpoint.',
36
42
  )
37
43
  _TFLITE_PATH = flags.DEFINE_string(
@@ -63,10 +69,12 @@ _QUANTIZE = flags.DEFINE_bool(
63
69
 
64
70
  def main(_):
65
71
  pytorch_model = paligemma.build_model(
66
- _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
72
+ _CHECKPOINT_PATH.value,
73
+ version=int(_VERSION.value),
74
+ kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
67
75
  )
68
76
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
69
- output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
77
+ output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
70
78
  converter.convert_to_tflite(
71
79
  pytorch_model,
72
80
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
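
For reference, the conversion flow above reduces to a few library calls; a hedged sketch of equivalent programmatic use, where the checkpoint path is a placeholder:

from ai_edge_torch.generative.examples.paligemma import paligemma

# version=1 selects decoder.Decoder, version=2 (the default) selects
# decoder2.Decoder2; the checkpoint path below is a placeholder.
pytorch_model = paligemma.build_model(
    '/path/to/paligemma2-3b-224',
    version=2,
    kv_cache_max_len=1024,
)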
ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -19,6 +19,7 @@ from typing import Optional
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -54,6 +55,7 @@ class Decoder(model_builder.DecoderOnlyModel):
       kv_cache: kv_utils.KVCache,
       input_embeds: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if input_embeds is None:
       return super().forward(tokens, input_pos, kv_cache)
@@ -61,8 +63,12 @@ class Decoder(model_builder.DecoderOnlyModel):
     assert input_embeds is not None
 
     repo_pos = input_pos + 1  # PaliGemma position is 1-based.
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
 
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
@@ -70,7 +76,7 @@ class Decoder(model_builder.DecoderOnlyModel):
     mask = torch.zeros(embeds_len, self.config.kv_cache_max)
     mask[:, embeds_len:] = float("-inf")
 
-    return self.forward_with_embeds(
+    return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache
     )
 
@@ -108,12 +114,13 @@ def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  embedding_dim = 2048
   config = cfg.ModelConfig(
       vocab_size=257216,
       num_layers=18,
       max_seq_len=8192,
-      embedding_dim=2048,
-      embedding_scale=2048**0.5,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
@@ -130,6 +137,8 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
   return config
 
 
ai_edge_torch/generative/examples/paligemma/decoder2.py ADDED
@@ -0,0 +1,174 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a decoder of PaliGemma2 3B model which is Gemma2."""
+
+from typing import Optional
+
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
+    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
+    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
+    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
+    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
+    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
+    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
+    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
+    pre_ff_norm="language_model.model.layers.{}.pre_feedforward_layernorm",
+    post_ff_norm="language_model.model.layers.{}.post_feedforward_layernorm",
+    embedding="language_model.model.embed_tokens",
+    final_norm="language_model.model.norm",
+    lm_head=None,
+)
+
+
+class Decoder2(gemma2.Gemma2):
+  """A decoder of PaliGemma2 3B model which is Gemma2.
+
+  Besides a tensor of text token IDs, forward() can also take a tensor of
+  embeddings which may include text or image or both.
+  """
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      input_embeds: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    if input_embeds is None:
+      return super().forward(tokens, input_pos, kv_cache)
+
+    assert input_embeds is not None
+
+    repo_pos = input_pos + 1  # PaliGemma2 position is 1-based.
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+
+    if called_by_generate:
+      # PaliGemma2 generate() uses a diagonal causal mask even with image embeds.
+      mask = [self.get_attention_mask(
+          self.config.block_config(i).attn_config.attn_type, input_pos
+      ) for i in range(self.config.num_layers)]
+    else:
+      # By default, don't mask image embeds with a diagonal causal mask.
+      embeds_len = input_embeds.shape[1]
+      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+      mask[:, embeds_len:] = float("-inf")
+      mask = [mask] * self.config.num_layers
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+
+def get_decoder2_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for the decoder of a PaliGemma2 3B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for the decoder of a PaliGemma2 3B model.
+  """
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=9216,
+      pre_ff_norm_config=norm_config,
+      post_ff_norm_config=norm_config,
+  )
+
+  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+    attn_config = cfg.AttentionConfig(
+        num_heads=8,
+        head_dim=256,
+        num_query_groups=4,
+        rotary_base=10000,
+        rotary_percentage=1.0,
+        logit_softcap=50.0,
+        sliding_window_size=4096,
+        attn_type=(
+            cfg.AttentionType.GLOBAL
+            if idx % 2 == 0
+            else cfg.AttentionType.LOCAL_SLIDING
+        ),
+    )
+    return cfg.TransformerBlockConfig(
+        attn_config=attn_config,
+        ff_config=ff_config,
+        pre_attention_norm_config=norm_config,
+        post_attention_norm_config=norm_config,
+    )
+
+  num_layers = 26
+  embedding_dim = 2304
+  config = cfg.ModelConfig(
+      vocab_size=257216,
+      num_layers=num_layers,
+      max_seq_len=8192,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=[get_block_config(i) for i in range(num_layers)],
+      final_norm_config=norm_config,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+      final_logit_softcap=30.0,
+  )
+  return config
+
+
+def get_fake_decoder2_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_decoder2_config(kv_cache_max_len)
+  # PaliGemma2 decoder has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
+  config.embedding_scale = 128**0.5
+  return config
+
+
+def build_decoder2(checkpoint_path: str, **kwargs) -> torch.nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_decoder2_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Decoder2,
+  )
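
Decoder2 inherits the Gemma2 forward path (including _forward_with_embeds), so only the mask selection and the 1-based RoPE positions differ from the text-only model. A minimal sketch of standalone use; the checkpoint path is a placeholder and must hold weights matching TENSOR_NAMES above:

from ai_edge_torch.generative.examples.paligemma import decoder2

# build_decoder2 loads the language-model weights via TENSOR_NAMES and
# returns a Decoder2 instance ready for verification or conversion.
model = decoder2.build_decoder2('/path/to/paligemma2-3b-pt-224')
model.eval()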
ai_edge_torch/generative/examples/paligemma/paligemma.py CHANGED
@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import Optional
 
 from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import image_encoder
 import ai_edge_torch.generative.layers.kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
@@ -38,13 +39,14 @@ class PaliGemmaConfig:
   decoder_config: cfg.ModelConfig
 
   image_token_id: int
+  image_projection_scale: float
   image_projection_use_bias: bool = False
 
 
 class PaliGemma(nn.Module):
   """PaliGemma model from the Edge Generative API."""
 
-  def __init__(self, config: PaliGemmaConfig):
+  def __init__(self, config: PaliGemmaConfig, decoder_class: nn.Module):
     super().__init__()
 
     self.image_encoder = image_encoder.SiglipVisionEncoder(
@@ -55,7 +57,7 @@ class PaliGemma(nn.Module):
         config.decoder_config.embedding_dim,
         bias=config.image_projection_use_bias,
     )
-    self.decoder = decoder.Decoder(config.decoder_config)
+    self.decoder = decoder_class(config.decoder_config)
     image_embedding_config = config.image_encoder_config.image_embedding
     self.num_patches = (
         image_embedding_config.image_size // image_embedding_config.patch_size
@@ -70,6 +72,7 @@ class PaliGemma(nn.Module):
       kv_cache: kv_utils.KVCache,
       pixel_values: torch.Tensor = None,
       export_config: Optional[model_builder.ExportConfig] = None,
+      called_by_generate: bool = True,
   ) -> dict[torch.Tensor, kv_utils.KVCache]:
     if pixel_values is None:
       return self.decoder(
@@ -77,15 +80,15 @@ class PaliGemma(nn.Module):
           input_pos=input_pos,
           kv_cache=kv_cache,
           input_embeds=None,
-          export_config=export_config
+          export_config=export_config,
+          called_by_generate=called_by_generate,
       )
 
     input_embeds = self.decoder.tok_embedding(tokens)
 
     image_encoded = self.image_encoder(pixel_values=pixel_values)
     image_embeds = self.image_projection(image_encoded)
-    if self.config.decoder_config.embedding_scale is not None:
-      image_embeds = image_embeds / self.config.decoder_config.embedding_scale
+    image_embeds = image_embeds / self.config.image_projection_scale
 
     # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
     # can be done like:
@@ -110,10 +113,11 @@ class PaliGemma(nn.Module):
         kv_cache=kv_cache,
         input_embeds=input_embeds,
         export_config=export_config,
+        called_by_generate=called_by_generate,
     )
 
 
-def get_model_config(**kwargs) -> PaliGemmaConfig:
+def get_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   """Returns the model config for a PaliGemma 3B-224 model.
 
   Returns:
@@ -121,31 +125,42 @@ def get_model_config(**kwargs) -> PaliGemmaConfig:
   """
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=decoder.get_decoder_config(**kwargs),
-      image_projection_use_bias=True,
+      decoder_config=get_decoder_config(**kwargs),
       image_token_id=257152,
+      image_projection_scale=2048**0.5,
+      image_projection_use_bias=True,
   )
 
 
-def get_fake_model_config() -> PaliGemmaConfig:
+def get_fake_model_config(get_decoder_config, **kwargs) -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-      decoder_config=decoder.get_fake_decoder_config(),
+      decoder_config=get_decoder_config(**kwargs),
+      image_token_id=127,
+      image_projection_scale=128**0.5,
       image_projection_use_bias=True,
-      image_token_id=257152,
   )
 
 
-def build_model(checkpoint_path: str, **kwargs) -> PaliGemma:
-  config = get_model_config(**kwargs)
-  model = PaliGemma(config)
+def build_model(checkpoint_path: str, version: int = 2, **kwargs) -> PaliGemma:
+  if version == 1:
+    decoder_class = decoder.Decoder
+    decoder_tensor_names = decoder.TENSOR_NAMES
+    get_decoder_config = decoder.get_decoder_config
+  else:
+    decoder_class = decoder2.Decoder2
+    decoder_tensor_names = decoder2.TENSOR_NAMES
+    get_decoder_config = decoder2.get_decoder2_config
+
+  config = get_model_config(get_decoder_config, **kwargs)
+  model = PaliGemma(config, decoder_class)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
       checkpoint_path, image_encoder.TENSOR_NAMES
   )
   loader.load(model.image_encoder, strict=False)
   # Load the parameters of decoder.
-  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
+  loader = loading_utils.ModelLoader(checkpoint_path, decoder_tensor_names)
   loader.load(model.decoder, strict=False)
 
   # Load the parameters of image projection.
ai_edge_torch/generative/examples/paligemma/verify.py CHANGED
@@ -22,11 +22,18 @@ from absl import flags
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import verifier
+import kagglehub
 from PIL import Image
 import requests
 import torch
 import transformers
 
+_VERSION = flags.DEFINE_enum(
+    "version",
+    "2",
+    ["1", "2"],
+    "The version of PaliGemma model to verify.",
+)
 _IMAGE_URL = flags.DEFINE_string(
     "image_url",
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
@@ -34,7 +41,7 @@ _IMAGE_URL = flags.DEFINE_string(
 )
 _PROMPTS = flags.DEFINE_string(
     "prompts",
-    "Caption en",
+    "describe en",
     "The input prompts to generate answers.",
 )
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
@@ -43,28 +50,47 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "The maximum size of the generated tokens.",
 )
 
+_CHECKPOINT = {
+    "1": "google/paligemma-3b-mix-224",
+    "2": "google/paligemma-2/transformers/paligemma2-3b-pt-224",
+}
+
 
 class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
   """Reauthored PaliGemma model wrapper."""
 
+  def __init__(self, model: torch.nn.Module):
+    super().__init__(model)
+    self.forward_called_by_generate = False
+
   def _init_kv_cache(self):
     return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
 
+  def _get_extra_args_for_forward(self):
+    return {"called_by_generate": self.forward_called_by_generate}
+
 
 def main(_):
-  checkpoint = "google/paligemma-3b-mix-224"
+  if _VERSION.value == "1":
+    checkpoint = _CHECKPOINT[_VERSION.value]
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    checkpoint = kagglehub.model_download(_CHECKPOINT[_VERSION.value])
+    reauthored_checkpoint = checkpoint
+
   logging.info("Loading the original model from: %s", checkpoint)
   original_model = (
       transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
   )
 
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
-  reauthored_model = paligemma.build_model(reauthored_checkpoint)
+  reauthored_model = paligemma.build_model(
+      reauthored_checkpoint, version=int(_VERSION.value)
+  )
 
   logging.info("Loading the processor from: %s", checkpoint)
   # It works only when GemmaTokenizerFast is available. In some environments,
@@ -93,7 +119,7 @@ def main(_):
   logging.info("outputs_reauthored: %s", outputs_reauthored)
 
   try:
-    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-03)
+    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-02)
   except AssertionError as e:
     logging.error("*** FAILED *** verify with forward()")
     raise e
@@ -111,6 +137,7 @@ def main(_):
   logging.info("outputs_from_original_model: [[%s]]", response_original)
 
   logging.info("Generating answer with the reauthored model...")
+  wrapped_reauthored_model.forward_called_by_generate = True
   outputs_reauthored = wrapped_reauthored_model.generate(
       prompts=inputs["input_ids"],
       pixel_values=inputs["pixel_values"],
ai_edge_torch/generative/examples/paligemma/verify_decoder2.py ADDED
@@ -0,0 +1,72 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored decoder of PaliGemma2 3B model."""
+
+import logging
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import decoder2
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download(
+      "google/paligemma-2/transformers/paligemma2-3b-pt-224"
+  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_full_model = (
+      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
+  )
+  original_language_model = original_full_model.eval().language_model
+
+  logging.info("Building the reauthored model from: %s", checkpoint)
+  reauthored_model = decoder2.build_decoder2(checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  # It works only when GemmaTokenizerFast is available. In some environments,
+  # use_fast=False doesn't work either if the tokenizer cannot load the
+  # sentencepiece model file properly.
+  processor = transformers.AutoProcessor.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_language_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(processor.tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py CHANGED
@@ -20,31 +20,48 @@ import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.paligemma import image_encoder
+import kagglehub
 from PIL import Image
 import requests
 import torch
 import transformers
 
+_VERSION = flags.DEFINE_enum(
+    "version",
+    "2",
+    ["1", "2"],
+    "The version of PaliGemma vision model to verify.",
+)
 _IMAGE_URL = flags.DEFINE_string(
     "image_url",
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
     "The image URI to encode.",
 )
 
+_CHECKPOINT = {
+    "1": "google/paligemma-3b-mix-224",
+    "2": "google/paligemma-2/transformers/paligemma2-3b-pt-224",
+}
+
 
 def main(_):
-  checkpoint = "google/paligemma-3b-mix-224"
+  if _VERSION.value == "1":
+    checkpoint = _CHECKPOINT[_VERSION.value]
+    # Locate the cached dir.
+    cached_config_file = transformers.utils.cached_file(
+        checkpoint, transformers.utils.CONFIG_NAME
+    )
+    reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  else:
+    checkpoint = kagglehub.model_download(_CHECKPOINT[_VERSION.value])
+    reauthored_checkpoint = checkpoint
+
   logging.info("Loading the original model from: %s", checkpoint)
   original_full_model = (
       transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
   )
   original_vision_model = original_full_model.eval().vision_tower
 
-  # Locate the cached dir.
-  cached_config_file = transformers.utils.cached_file(
-      checkpoint, transformers.utils.CONFIG_NAME
-  )
-  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
   logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = image_encoder.build_image_encoder(reauthored_checkpoint)
 
@@ -69,7 +86,7 @@ def main(_):
 
   try:
     assert torch.allclose(
-        outputs_original, outputs_reauthored, atol=1e-04, rtol=1e-04
+        outputs_original, outputs_reauthored, atol=1e-03, rtol=1e-04
     )
   except AssertionError as e:
     logging.error("*** FAILED *** verify with an image")
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -21,6 +21,8 @@ from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
@@ -171,13 +173,9 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
-  def disabled_test_paligemma(self):
-    config = paligemma.get_fake_model_config()
-    pytorch_model = paligemma.PaliGemma(config).eval()
+  def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
+    config = paligemma.get_fake_model_config(decoder_config)
+    pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()
 
     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (
@@ -215,11 +213,32 @@ class TestModelConversion(googletest.TestCase):
             kv,
             pixel_values=pixel_values,
             signature_name="prefill_pixel",
-            atol=1e-3,
-            rtol=1e-5,
+            atol=atol,
+            rtol=rtol,
         )
     )
 
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def disabled_test_paligemma1(self):
+    self._test_paligemma_model(
+        decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
+    )
+
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def disabled_test_paligemma2(self):
+    self._test_paligemma_model(
+        decoder2.Decoder2,
+        decoder2.get_fake_decoder2_config,
+        atol=1e-3,
+        rtol=1e-5,
+    )
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -107,8 +107,6 @@ class DecoderOnlyModel(nn.Module):
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
 
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
@@ -117,11 +115,14 @@ class DecoderOnlyModel(nn.Module):
         input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )
 
-    return self.forward_with_embeds(
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
 
-  def forward_with_embeds(
+  def _forward_with_embeds(
      self,
      input_embeds: torch.Tensor,
      rope: Tuple[torch.Tensor, torch.Tensor],
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -16,7 +16,7 @@
 """Common utility functions to verify the reauthored models."""
 
 import logging
-from typing import List
+from typing import Any, List
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -87,6 +87,10 @@ class ReauthoredModelWrapper(ModelWrapper):
     """Returns an initialized KV cache."""
     return kv_utils.KVCache.from_model_config(self.model.config)
 
+  def _get_extra_args_for_forward(self) -> dict[str, Any]:
+    """Returns extra arguments for the forward() method."""
+    return {}
+
   def _forward_with_kv_cache(
       self,
       tokens: torch.Tensor,
@@ -105,26 +109,15 @@ class ReauthoredModelWrapper(ModelWrapper):
     Returns:
       The output logits and the updated KV cache.
     """
-    # Verification requires logit outputs on prefill for comparison.
-    if (
-        self.export_config is not None
-        and not self.export_config.output_logits_on_prefill
-    ):
-      raise ValueError("Verifier requires logit output on prefill.")
-    # Since the reauthored model doesn't include keyword arguments, pass
-    # pixel_values only when it is not None. Otherwise, it may raise an error.
-    if pixel_values is None:
-      output = self.model.forward(
-          tokens, input_pos, kv_cache, export_config=self.export_config
-      )
-    else:
-      output = self.model.forward(
-          tokens,
-          input_pos,
-          kv_cache,
-          pixel_values=pixel_values,
-          export_config=self.export_config,
-      )
+    extra_args = self._get_extra_args_for_forward()
+    if self.export_config is not None:
+      # Verification requires logit outputs on prefill for comparison.
+      if not self.export_config.output_logits_on_prefill:
+        raise ValueError("Verifier requires logit output on prefill.")
+      extra_args["export_config"] = self.export_config
+    if pixel_values is not None:
+      extra_args["pixel_values"] = pixel_values
+    output = self.model.forward(tokens, input_pos, kv_cache, **extra_args)
     return output["logits"], output["kv_cache"]
 
   def forward(
@@ -141,6 +134,7 @@ class ReauthoredModelWrapper(ModelWrapper):
       prompts: torch.Tensor,
       max_new_tokens: int,
       pixel_values: torch.Tensor = None,
+      eos_token_id: int = 1,
   ) -> torch.IntTensor:
     input_ids = prompts[0].int().tolist()
     tokens = torch.tensor([input_ids])
@@ -152,6 +146,8 @@ class ReauthoredModelWrapper(ModelWrapper):
       )
       generated_token = logits[0][-1].argmax().item()
      input_ids.append(generated_token)
+      if generated_token == eos_token_id:
+        break
       tokens = torch.tensor([[generated_token]])
       input_pos = torch.tensor([len(input_ids) - 1])
       pixel_values = None  # Pass only for the first time.
@@ -254,7 +250,11 @@ def verify_model_with_prompts(
   logging.info("outputs_from_original_model: [[%s]]", response_original)
 
   logging.info("Generating answer with the reauthored model...")
-  outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
+  outputs_reauthored = reauthored_model.generate(
+      prompt_tokens,
+      max_new_tokens,
+      eos_token_id=tokenizer.tokenizer.eos_token_id,
+  )
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
   logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
 
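The new _get_extra_args_for_forward() hook is the extension point that lets a wrapper thread model-specific keyword arguments (such as PaliGemma's called_by_generate, shown in verify.py above) through the shared forward path. A minimal sketch of a custom wrapper:

from ai_edge_torch.generative.utilities import verifier

class MyModelWrapper(verifier.ReauthoredModelWrapper):

  def _get_extra_args_for_forward(self):
    # Whatever dict is returned here is splatted into
    # self.model.forward(tokens, input_pos, kv_cache, **extra_args).
    return {"called_by_generate": False}  # illustrative extra kwarg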
ai_edge_torch/odml_torch/export.py CHANGED
@@ -198,7 +198,12 @@ class MlirLowered:
     # build, which may not have the same StableHLO version as what used in
     # TFLite converter. Therefore we always serialize MLIR module in VHLO.
     # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
-    target_version = stablehlo.get_minimum_version()
+    if stablehlo.get_api_version() < 9:
+      target_version = stablehlo.get_minimum_version()
+    else:
+      target_version = stablehlo.get_version_from_compatibility_requirement(
+          stablehlo.StablehloCompatibilityRequirement.WEEK_4
+      )
     module_bytecode = xla_extension.mlir.serialize_portable_artifact(
         self.module_bytecode, target_version
     )
ai_edge_torch/odml_torch/jax_bridge/__init__.py CHANGED
@@ -12,4 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch.odml_torch.jax_bridge._wrap import wrap
+from ai_edge_torch.odml_torch.jax_bridge import _wrap
+from ai_edge_torch.odml_torch.jax_bridge import utils
+
+wrap = _wrap.wrap
ai_edge_torch/odml_torch/lowerings/__init__.py CHANGED
@@ -18,6 +18,7 @@ from . import _convolution
 from . import _jax_lowerings
 from . import _layer_norm
 from . import _quantized_decomposed
+from . import _rand
 from . import context
 from . import registry
 from . import utils
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -26,6 +26,7 @@ import torch_xla2.ops.ops_registry  # Import to load torch_xla2 ops
 
 LoweringContext = context.LoweringContext
 
+
 @functools.cache
 def _log_usage(op):
   logging.warning("Use jax lowering: %s", str(op))
@@ -184,8 +185,6 @@ lower_by_torch_xla2(torch.ops.aten.permute_copy)
 lower_by_torch_xla2(torch.ops.aten.pixel_shuffle)
 lower_by_torch_xla2(torch.ops.aten.pow)
 lower_by_torch_xla2(torch.ops.aten.prod)
-lower_by_torch_xla2(torch.ops.aten.rand)
-lower_by_torch_xla2(torch.ops.aten.randn)
 lower_by_torch_xla2(torch.ops.aten.reciprocal)
 lower_by_torch_xla2(torch.ops.aten.reflection_pad1d)
 lower_by_torch_xla2(torch.ops.aten.relu)
ai_edge_torch/odml_torch/lowerings/_rand.py ADDED
@@ -0,0 +1,142 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import uuid
+
+from ai_edge_torch.odml_torch import export_utils
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import registry
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import numpy as np
+import torch
+import torch.utils._pytree as pytree
+
+LoweringContext = context.LoweringContext
+lower = registry.lower
+
+
+def _random_lowering(
+    lctx: LoweringContext,
+    size: list[int],
+    generator,
+    dtype: torch.dtype,
+    rand_tensor,
+    composite_name: str,
+):
+  if dtype is None:
+    dtype = torch.float32
+
+  rand_tensor = rand_tensor.type(dtype)
+  data = rand_tensor.detach().numpy()
+
+  shape, _ = pytree.tree_flatten(size)
+  elty = export_utils.torch_dtype_to_ir_element_type(dtype)
+
+  decomp_name = f"{composite_name}.impl_{uuid.uuid4().hex[:8]}"
+
+  with ir.InsertionPoint(lctx.ir_module.body):
+
+    @func.FuncOp.from_py_func(
+        ir.RankedTensorType.get(
+            [len(shape)],
+            ir.IntegerType.get_signless(32),
+        ),
+        name=decomp_name,
+    )
+    def _rand_impl(_):
+      return [stablehlo.constant(ir.DenseElementsAttr.get(data))]
+
+  seed, seed2 = (
+      torch.randint(
+          torch.iinfo(torch.int64).min,
+          torch.iinfo(torch.int64).max,
+          (2,),
+          dtype=torch.int64,
+          generator=generator,
+      )
+      .detach()
+      .numpy()
+  )
+
+  shape_ = stablehlo.constant(
+      ir.DenseElementsAttr.get(np.array(shape, dtype=np.int32))
+  )
+  return stablehlo.CompositeOp(
+      result=[ir.RankedTensorType.get(shape, elty)],
+      inputs=[shape_],
+      name=composite_name,
+      composite_attributes=ir.DictAttr.get({
+          "seed": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed),
+          "seed2": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed2),
+      }),
+      decomposition=decomp_name,
+  ).results[0]
+
+
+# Schema:
+#   - aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
+#       Device? device=None, bool? pin_memory=None) -> Tensor
+#   - aten::rand.generator(SymInt[] size, *, Generator? generator,
+#       ScalarType? dtype=None, Layout? layout=None, Device? device=None,
+#       bool? pin_memory=None) -> Tensor
+@registry.lower(torch.ops.aten.rand)
+def _aten_rand(
+    lctx: LoweringContext,
+    size,
+    generator=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    pin_memory=False,
+):
+  return _random_lowering(
+      lctx,
+      size,
+      generator,
+      dtype,
+      rand_tensor=torch.ops.aten.rand.generator(
+          size, generator=generator, dtype=dtype
+      ),
+      composite_name="odml.random_uniform",
+  )
+
+
+# Schema:
+#   - aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
+#       Device? device=None, bool? pin_memory=None) -> Tensor
+#   - aten::randn.generator(SymInt[] size, *, Generator? generator,
+#       ScalarType? dtype=None, Layout? layout=None, Device? device=None,
+#       bool? pin_memory=None) -> Tensor
+@registry.lower(torch.ops.aten.randn)
+def _aten_randn(
+    lctx: LoweringContext,
+    size,
+    generator=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    pin_memory=False,
+):
+  return _random_lowering(
+      lctx,
+      size,
+      generator,
+      dtype,
+      rand_tensor=torch.ops.aten.randn.generator(
+          size, generator=generator, dtype=dtype
+      ),
+      composite_name="odml.random_standard_normal",
+  )
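
With this lowering in place, torch.rand/torch.randn are no longer routed through torch_xla2 (see the _jax_lowerings.py change above); each call is materialized as an odml.random_uniform or odml.random_standard_normal StableHLO composite carrying seed/seed2 attributes. A minimal sketch of a module that would exercise it, assuming the usual ai_edge_torch.convert entry point:

import torch
import ai_edge_torch

class NoisyIdentity(torch.nn.Module):

  def forward(self, x):
    # torch.randn here should lower to an odml.random_standard_normal
    # composite under the new _rand.py lowering.
    return x + torch.randn(8)

edge_model = ai_edge_torch.convert(NoisyIdentity().eval(), (torch.zeros(1, 8),))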
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241220"
+__version__ = "0.3.0.dev20241224"
{ai_edge_torch_nightly-0.3.0.dev20241220.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241220
+Version: 0.3.0.dev20241224
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241220.dist-info → ai_edge_torch_nightly-0.3.0.dev20241224.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=xD-MWAEa1ROHhyF3rY7MaL28xsuON0aJwaiXbJ04qfc,706
+ai_edge_torch/version.py,sha256=TkfJYt2lJC8A_AcieO1xVmMQ2xdnoTOwF8CZ5dZeaqc,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -46,8 +46,8 @@ ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=GhwtQZ1xuMyKJl8qdxU6uKavQnlm5US9xhKJvdmgACc,2309
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=hsy4Gd7Inchi0p_Cc5yecH6vr9A7X4MvmQNfTt8N2sQ,2311
-ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=N0jKVZA3qWKOaHVbIM3WmQh3u0Sq7MTw_oO3Zo16wCw,3456
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=roEwWVXASbk5BFj7jojjEJpHui6gCelT51l-TtN_ZaQ,9367
+ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=VTM2nO3TqK2d1DyEb2MiHc-Tyw2lMcUXyOhvg0H5ENY,10147
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
@@ -63,13 +63,15 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=nji1oDgf6x
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=rPFqcsv8RHvjmgfBW9OL6EKxMtVX-ySjBsMP4N8FErk,2816
-ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=eICKQkJsJuEUkuvn5ymUsI9CGB-oNbgV7VH7BlmklfQ,4961
+ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=LFCcnkmOksySDa_5bLBzoGMijYdFVjXIMidUlyzAbNk,2996
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=amN96oBMTPolOFvGa47vG92AZ-BNLm8j0bBYd-IrMvI,5407
+ai_edge_torch/generative/examples/paligemma/decoder2.py,sha256=0V_CX0Pn5Fj_-koOGjc_Av2KMSAaVjAlD-G8P6FBGyY,6385
 ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
-ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nDyI-wUFJSawu57uLbFENei5l4cciqZ8lM5S5beN0FU,5604
-ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
+ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=FwGlFHl9zktGDxnoOpEtbS6NYN5RyzcOXH7lvNUCwEU,6257
+ai_edge_torch/generative/examples/paligemma/verify.py,sha256=KT3Ruy40tSESxQuy-Sw01NAI3zId1BZr6Bp7FZj1wZk,5622
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
-ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
+ai_edge_torch/generative/examples/paligemma/verify_decoder2.py,sha256=tm-UfLr0YeBRVcQsWLBOMWI9JUzHmtPEbYK2vpITpqY,2534
+ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=vNm-wTT8BD6zbX6GocfP1QrVoHl0zSvuVxoXN36eeiU,3540
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=cD8rtwgYeGrXB9sYVV_D1AB8Up1AWNS-1XtrRlyzE5o,2296
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=G1i_ybDCTBaOD1OOCTk6jqOf__xYYZvhXcxY8MXhPHw,2294
@@ -140,19 +142,19 @@ ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1e
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOrypy4IM5YjC4p-6dgCMM,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mVuax3MPRmuNjnDRKXqtc9YmswCy7MnhD1CHADK-3nk,11501
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=NctnggTSFh0XEQbTu55diZ35rFD2QIARO-8PzLktRWg,12165
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=q82-1E2zYlzpbFW6Vw-MWrJivRXHKpRh8jUxpR-w0sY,6349
+ai_edge_torch/generative/utilities/model_builder.py,sha256=S08WNqVKCmxd2QjtMlwETd7J97UnlME_bTKdz5LMkGU,6352
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
-ai_edge_torch/generative/utilities/verifier.py,sha256=ESSA8W1EYNsd4ntwmXbr-dn-BcIS27hf53XL5RTwjEU,11941
+ai_edge_torch/generative/utilities/verifier.py,sha256=awO-sQrEpsFxIkZw72ysWZenYEmkLOLOuj62o2c7XeQ,11994
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -169,7 +171,7 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=Wc_JM7U2IjZeBmXA6t1AZxREGOWjZ6EB-PIhEevWWeU,13207
+ai_edge_torch/odml_torch/export.py,sha256=QzOPmcNPB7R-KhhPEP0oGVbDRgGPptIxRSoz3S8py9I,13405
 ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
 ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -178,16 +180,17 @@ ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_g
 ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu2YRyGlMZZqVPWUH6g,762
 ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
-ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
+ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
-ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T4QBg3UZuRJcGHwGyA,993
+ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GWFl7WWgExLXu6FEYxnig5_g6hd_Sfnl8690uFg2-CU,1013
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=CJHWkmY4aAVQ5dmFsVc3Ox9TPkoLSNOfa96psD4CLRo,11561
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
 ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
+ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
@@ -200,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/METADATA,sha256=PfyYhqbf7VEibw2TEDRb8tBOIPG9dfXhT9tNNou_iZg,1966
-ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241224.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241224.dist-info/METADATA,sha256=J8nAtbMNmSIcHuThVv0omkhpldggz91pIIYy-6ATJgM,1966
+ai_edge_torch_nightly-0.3.0.dev20241224.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20241224.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241224.dist-info/RECORD,,