ai-edge-torch-nightly 0.3.0.dev20241115__py3-none-any.whl → 0.3.0.dev20241116__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -15,9 +15,11 @@
15
15
 
16
16
  """Example of building a decoder of PaliGemma 3B model which is Gemma1."""
17
17
 
18
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
18
19
  import ai_edge_torch.generative.layers.model_config as cfg
19
20
  from ai_edge_torch.generative.utilities import model_builder
20
21
  import ai_edge_torch.generative.utilities.loader as loading_utils
22
+ import torch
21
23
 
22
24
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
23
25
  ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
@@ -35,6 +37,41 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
35
37
  )
36
38
 
37
39
 
40
+ class Decoder(model_builder.DecoderOnlyModel):
41
+ """A decoder of PaliGemma 3B model which is Gemma1.
42
+
43
+ Besides a tensor of text token IDs, forward() can also take a tensor of
44
+ embeddings which may include text or image or both.
45
+ """
46
+
47
+ @torch.inference_mode
48
+ def forward(
49
+ self,
50
+ tokens: torch.Tensor,
51
+ input_pos: torch.Tensor,
52
+ kv_cache: kv_utils.KVCache,
53
+ input_embeds: torch.Tensor = None,
54
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
55
+ if input_embeds is None:
56
+ return super().forward(tokens, input_pos, kv_cache)
57
+
58
+ assert input_embeds is not None
59
+
60
+ repo_pos = input_pos + 1 # PaliGemma position is 1-based.
61
+ cos, sin = self.rope_cache
62
+ rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
63
+
64
+ # The first part of input_embeds are image embeddings. Diagonal causal mask
65
+ # doesn't work here.
66
+ embeds_len = input_embeds.shape[1]
67
+ mask = torch.zeros(embeds_len, self.config.kv_cache_max)
68
+ mask[:, embeds_len:] = float("-inf")
69
+
70
+ return self.forward_with_embeds(
71
+ input_embeds, rope, mask, input_pos, kv_cache
72
+ )
73
+
74
+
38
75
  def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
39
76
  """Returns the model config for the decoder of a PaliGemma 3B model.
40
77
 
@@ -96,8 +133,9 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
96
133
  def build_decoder(
97
134
  checkpoint_path: str, **kwargs
98
135
  ) -> model_builder.DecoderOnlyModel:
99
- return model_builder.build_decoder_only_model(
100
- checkpoint_path=checkpoint_path,
101
- config=get_decoder_config(**kwargs),
102
- tensor_names=TENSOR_NAMES,
103
- )
136
+ decoder = Decoder(get_decoder_config(**kwargs))
137
+ loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
138
+ # Loosen the strictness because only the decoder is being loaded.
139
+ loader.load(decoder, strict=False)
140
+ decoder.eval()
141
+ return decoder
@@ -0,0 +1,135 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of building a full-stack of PaliGemma model."""
17
+
18
+ from dataclasses import dataclass
19
+
20
+ from ai_edge_torch.generative.examples.paligemma import decoder
21
+ from ai_edge_torch.generative.examples.paligemma import image_encoder
22
+ import ai_edge_torch.generative.layers.kv_cache as kv_utils
23
+ import ai_edge_torch.generative.layers.model_config as cfg
24
+ import ai_edge_torch.generative.utilities.loader as loading_utils
25
+ import torch
26
+ from torch import nn
27
+
28
+ PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
29
+
30
+
31
+ @dataclass
32
+ class PaliGemmaConfig:
33
+ """PaliGemma model configurations."""
34
+
35
+ image_encoder_config: cfg.ModelConfig
36
+ decoder_config: cfg.ModelConfig
37
+
38
+ image_token_id: int
39
+ image_projection_use_bias: bool = False
40
+
41
+
42
+ class PaliGemma(nn.Module):
43
+ """PaliGemma model from the Edge Generative API."""
44
+
45
+ def __init__(self, config: PaliGemmaConfig):
46
+ super().__init__()
47
+
48
+ self.image_encoder = image_encoder.SiglipVisionEncoder(
49
+ config.image_encoder_config
50
+ )
51
+ self.image_projection = nn.Linear(
52
+ config.image_encoder_config.embedding_dim,
53
+ config.decoder_config.embedding_dim,
54
+ bias=config.image_projection_use_bias,
55
+ )
56
+ self.decoder = decoder.Decoder(config.decoder_config)
57
+ self.config = config
58
+
59
+ @torch.inference_mode
60
+ def forward(
61
+ self,
62
+ tokens: torch.Tensor,
63
+ input_pos: torch.Tensor,
64
+ kv_cache: kv_utils.KVCache,
65
+ pixel_values: torch.Tensor = None,
66
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
67
+ if pixel_values is None:
68
+ return self.decoder(tokens, input_pos, kv_cache)
69
+
70
+ input_embeds = self.decoder.tok_embedding(tokens)
71
+
72
+ image_encoded = self.image_encoder(pixel_values=pixel_values)
73
+ image_embeds = self.image_projection(image_encoded)
74
+ if self.config.decoder_config.embedding_scale is not None:
75
+ image_embeds = image_embeds / self.config.decoder_config.embedding_scale
76
+
77
+ # Merge image_embeds into text_embeds as PaliGemmaForConditionalGeneration.
78
+ image_mask = tokens == self.config.image_token_id
79
+ image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
80
+ input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
81
+
82
+ return self.decoder(
83
+ tokens=None,
84
+ input_pos=input_pos,
85
+ kv_cache=kv_cache,
86
+ input_embeds=input_embeds,
87
+ )
88
+
89
+
90
+ def get_model_config() -> PaliGemmaConfig:
91
+ """Returns the model config for a PaliGemma 3B-224 model.
92
+
93
+ Returns:
94
+ The model config for a PaliGemma 3B model.
95
+ """
96
+ return PaliGemmaConfig(
97
+ image_encoder_config=image_encoder.get_image_encoder_config(),
98
+ decoder_config=decoder.get_decoder_config(),
99
+ image_projection_use_bias=True,
100
+ image_token_id=257152,
101
+ )
102
+
103
+
104
+ def get_fake_image_encoder_config() -> PaliGemmaConfig:
105
+ return PaliGemmaConfig(
106
+ image_encoder_config=image_encoder.get_fake_image_encoder_config(),
107
+ decoder_config=decoder.get_fake_decoder_config(),
108
+ image_projection_use_bias=True,
109
+ image_token_id=257152,
110
+ )
111
+
112
+
113
+ def build_model(checkpoint_path: str) -> PaliGemma:
114
+ config = get_model_config()
115
+ model = PaliGemma(config)
116
+ # Load the parameters of image encoder.
117
+ loader = loading_utils.ModelLoader(
118
+ checkpoint_path, image_encoder.TENSOR_NAMES
119
+ )
120
+ loader.load(model.image_encoder, strict=False)
121
+ # Load the parameters of decoder.
122
+ loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
123
+ loader.load(model.decoder, strict=False)
124
+
125
+ # Load the parameters of image projection.
126
+ loader = loading_utils.ModelLoader(checkpoint_path, None)
127
+ state = loader.get_state()
128
+ converted_state = dict()
129
+ converted_state["weight"] = state.pop(f"{PROJECTION_TENSOR_NAME}.weight")
130
+ if config.image_projection_use_bias:
131
+ converted_state["bias"] = state.pop(f"{PROJECTION_TENSOR_NAME}.bias")
132
+ model.image_projection.load_state_dict(converted_state)
133
+
134
+ model.eval()
135
+ return model
@@ -0,0 +1,134 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Verifies the reauthored PaliGemma 3B model."""
17
+
18
+ import logging
19
+ import pathlib
20
+ from absl import app
21
+ from absl import flags
22
+ from ai_edge_torch.generative.examples.paligemma import paligemma
23
+ from ai_edge_torch.generative.layers import kv_cache
24
+ from ai_edge_torch.generative.utilities import verifier
25
+ from PIL import Image
26
+ import requests
27
+ import torch
28
+ import transformers
29
+
30
+ _IMAGE_URL = flags.DEFINE_string(
31
+ "image_url",
32
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
33
+ "The image URI to encode.",
34
+ )
35
+ _PROMPTS = flags.DEFINE_string(
36
+ "prompts",
37
+ "Caption en",
38
+ "The input prompts to generate answers.",
39
+ )
40
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
41
+ "max_new_tokens",
42
+ 30,
43
+ "The maximum size of the generated tokens.",
44
+ )
45
+
46
+
47
+ class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
48
+ """Reauthored PaliGemma model wrapper."""
49
+
50
+ def _init_kv_cache(self):
51
+ return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
52
+
53
+
54
+ def main(_):
55
+ checkpoint = "google/paligemma-3b-mix-224"
56
+ logging.info("Loading the original model from: %s", checkpoint)
57
+ original_model = (
58
+ transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
59
+ )
60
+
61
+ # Locate the cached dir.
62
+ cached_config_file = transformers.utils.cached_file(
63
+ checkpoint, transformers.utils.CONFIG_NAME
64
+ )
65
+ reauthored_checkpoint = pathlib.Path(cached_config_file).parent
66
+ logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
67
+ reauthored_model = paligemma.build_model(reauthored_checkpoint)
68
+
69
+ logging.info("Loading the processor from: %s", checkpoint)
70
+ # It works only when GemmaTokenizerFast is available. In some environments,
71
+ # use_fast=False doesn't work either if the tokenizer cannot load the
72
+ # sentencepiece model file properly.
73
+ processor = transformers.AutoProcessor.from_pretrained(checkpoint)
74
+
75
+ logging.info("Loading the image from: %s", _IMAGE_URL.value)
76
+ image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
77
+ inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
78
+
79
+ logging.info("Verifying the reauthored model with model.forward()...")
80
+ logging.info("Forwarding the original model...")
81
+ outputs_original = original_model.forward(
82
+ input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"]
83
+ )
84
+ outputs_original = outputs_original.logits
85
+ logging.info("outputs_original: %s", outputs_original)
86
+
87
+ logging.info("Forwarding the reauthored model...")
88
+ wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)
89
+ outputs_reauthored = wrapped_reauthored_model.forward(
90
+ tokens=inputs["input_ids"],
91
+ pixel_values=inputs["pixel_values"],
92
+ )
93
+ logging.info("outputs_reauthored: %s", outputs_reauthored)
94
+
95
+ try:
96
+ assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-03)
97
+ except AssertionError as e:
98
+ logging.error("*** FAILED *** verify with forward()")
99
+ raise e
100
+ else:
101
+ logging.info("*** PASSED *** verify with forward()")
102
+
103
+ logging.info("Verifying the reauthored model with model.generate()...")
104
+ logging.info("Generating answer with the original model...")
105
+ outputs_original = original_model.generate(
106
+ **inputs, max_new_tokens=_MAX_NEW_TOKENS.value, do_sample=False
107
+ )
108
+ response_original = processor.decode(
109
+ outputs_original[0], skip_special_tokens=True
110
+ )
111
+ logging.info("outputs_from_original_model: [[%s]]", response_original)
112
+
113
+ logging.info("Generating answer with the reauthored model...")
114
+ outputs_reauthored = wrapped_reauthored_model.generate(
115
+ prompts=inputs["input_ids"],
116
+ pixel_values=inputs["pixel_values"],
117
+ max_new_tokens=_MAX_NEW_TOKENS.value,
118
+ )
119
+ response_reauthored = processor.decode(
120
+ outputs_reauthored[0], skip_special_tokens=True
121
+ )
122
+ logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
123
+
124
+ try:
125
+ assert response_original == response_reauthored
126
+ except AssertionError as e:
127
+ logging.error("*** FAILED *** verify with generate()")
128
+ raise e
129
+ else:
130
+ logging.info("*** PASSED *** verify with generate()")
131
+
132
+
133
+ if __name__ == "__main__":
134
+ app.run(main)
@@ -131,6 +131,9 @@ class ModelLoader:
131
131
  self._names = names
132
132
  self._loader = self._get_loader()
133
133
 
134
+ def get_state(self) -> Dict[str, torch.Tensor]:
135
+ return self._loader(self._file_name)
136
+
134
137
  def load(
135
138
  self, model: torch.nn.Module, strict: bool = True
136
139
  ) -> Tuple[List[str], List[str]]:
@@ -150,7 +153,7 @@ class ModelLoader:
150
153
  ValueError: If conversion results in unmapped tensors and strict mode is
151
154
  enabled.
152
155
  """
153
- state = self._loader(self._file_name)
156
+ state = self.get_state()
154
157
  state = state["model_state_dict"] if "model_state_dict" in state else state
155
158
  converted_state = dict()
156
159
  if self._names.embedding is not None:
@@ -16,6 +16,7 @@
16
16
  """Utilities to be used for re-authoring transformer models."""
17
17
 
18
18
  import copy
19
+ from typing import Tuple
19
20
 
20
21
  from ai_edge_torch.generative.layers import attention
21
22
  from ai_edge_torch.generative.layers import builder
@@ -98,26 +99,40 @@ class DecoderOnlyModel(nn.Module):
98
99
  f"Cannot forward sequence of length {seq_len}, max seq length is only"
99
100
  f" {self.config.max_seq_len}"
100
101
  )
101
- assert len(self.transformer_blocks) == len(kv_cache.caches), (
102
- "The number of transformer blocks and the number of KV cache entries"
103
- " must be the same."
104
- )
105
102
 
103
+ # token embeddings of shape (b, t, n_embd)
104
+ input_embeds = self.tok_embedding(tokens)
106
105
  cos, sin = self.rope_cache
107
- cos = cos.index_select(0, input_pos)
108
- sin = sin.index_select(0, input_pos)
106
+ rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
109
107
  mask = self.mask_cache.index_select(2, input_pos)
110
108
  mask = mask[:, :, :, : self.config.kv_cache_max]
111
109
 
112
- # token embeddings of shape (b, t, n_embd)
113
- x = self.tok_embedding(tokens)
110
+ return self.forward_with_embeds(
111
+ input_embeds, rope, mask, input_pos, kv_cache
112
+ )
113
+
114
+ def forward_with_embeds(
115
+ self,
116
+ input_embeds: torch.Tensor,
117
+ rope: Tuple[torch.Tensor, torch.Tensor],
118
+ mask: torch.Tensor,
119
+ input_pos: torch.Tensor,
120
+ kv_cache: kv_utils.KVCache,
121
+ ) -> dict[torch.Tensor, kv_utils.KVCache]:
122
+ """Forwards the model with input embeddings."""
123
+ assert len(self.transformer_blocks) == len(kv_cache.caches), (
124
+ "The number of transformer blocks and the number of KV cache entries"
125
+ " must be the same."
126
+ )
127
+
128
+ x = input_embeds
114
129
  if self.config.embedding_scale is not None:
115
130
  x = x * self.config.embedding_scale
116
131
 
117
132
  updated_kv_entires = []
118
133
  for i, block in enumerate(self.transformer_blocks):
119
134
  kv_entry = kv_cache.caches[i] if kv_cache else None
120
- x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
135
+ x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
121
136
  if kv_entry:
122
137
  updated_kv_entires.append(kv_entry)
123
138
  updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
@@ -41,7 +41,9 @@ class ModelWrapper(torch.nn.Module):
41
41
  super().__init__()
42
42
  self.model = model
43
43
 
44
- def forward(self, tokens: torch.Tensor) -> torch.Tensor:
44
+ def forward(
45
+ self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
46
+ ) -> torch.Tensor:
45
47
  """Gets output logits by forwarding the input tokens.
46
48
 
47
49
  Args:
@@ -54,7 +56,10 @@ class ModelWrapper(torch.nn.Module):
54
56
  raise NotImplementedError("forward() is not implemented.")
55
57
 
56
58
  def generate(
57
- self, prompts: torch.Tensor, max_new_tokens: int
59
+ self,
60
+ prompts: torch.Tensor,
61
+ max_new_tokens: int,
62
+ pixel_values: torch.Tensor = None,
58
63
  ) -> torch.IntTensor:
59
64
  """Returns the response token IDs to the given prompts tensor.
60
65
 
@@ -83,35 +88,59 @@ class ReauthoredModelWrapper(ModelWrapper):
83
88
  def _forward_with_kv_cache(
84
89
  self,
85
90
  tokens: torch.Tensor,
91
+ input_pos: torch.Tensor,
86
92
  kv_cache: kv_utils.KVCache,
93
+ pixel_values: torch.Tensor,
87
94
  ) -> tuple[torch.Tensor, kv_utils.KVCache]:
88
95
  """Forwards the model and updates an external KV cache.
89
96
 
90
97
  Args:
91
98
  tokens (torch.Tensor): The input tokens to forward.
99
+ input_pos (torch.Tensor): The input positions to forward.
92
100
  kv_cache (KVCache): The KV cache to forward.
101
+ pixel_values (torch.Tensor): The input pixel values to forward.
93
102
 
94
103
  Returns:
95
104
  The output logits and the updated KV cache.
96
105
  """
97
- input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
98
- output = self.model.forward(tokens, input_pos, kv_cache)
106
+ # Since the reauthored model doesn't include keyword arguments, pass
107
+ # pixel_values only when it is not None. Otherwise, it may raise an error.
108
+ if pixel_values is None:
109
+ output = self.model.forward(tokens, input_pos, kv_cache)
110
+ else:
111
+ output = self.model.forward(
112
+ tokens, input_pos, kv_cache, pixel_values=pixel_values
113
+ )
99
114
  return output["logits"], output["kv_cache"]
100
115
 
101
- def forward(self, tokens: torch.Tensor) -> torch.Tensor:
102
- logits, _ = self._forward_with_kv_cache(tokens, self._init_kv_cache())
116
+ def forward(
117
+ self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
118
+ ) -> torch.Tensor:
119
+ input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
120
+ logits, _ = self._forward_with_kv_cache(
121
+ tokens, input_pos, self._init_kv_cache(), pixel_values
122
+ )
103
123
  return logits
104
124
 
105
125
  def generate(
106
- self, prompts: torch.Tensor, max_new_tokens: int
126
+ self,
127
+ prompts: torch.Tensor,
128
+ max_new_tokens: int,
129
+ pixel_values: torch.Tensor = None,
107
130
  ) -> torch.IntTensor:
108
131
  input_ids = prompts[0].int().tolist()
132
+ tokens = torch.tensor([input_ids])
133
+ input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
109
134
  kv_cache = self._init_kv_cache()
110
135
  for _ in range(max_new_tokens):
111
- tokens = torch.tensor([input_ids])
112
- logits, kv_cache = self._forward_with_kv_cache(tokens, kv_cache)
136
+ logits, kv_cache = self._forward_with_kv_cache(
137
+ tokens, input_pos, kv_cache, pixel_values
138
+ )
113
139
  generated_token = logits[0][-1].argmax().item()
114
140
  input_ids.append(generated_token)
141
+ tokens = torch.tensor([[generated_token]])
142
+ input_pos = torch.tensor([len(input_ids) - 1])
143
+ pixel_values = None # Pass only for the first time.
115
144
  return torch.tensor([input_ids])
116
145
 
117
146
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241115"
16
+ __version__ = "0.3.0.dev20241116"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241115
3
+ Version: 0.3.0.dev20241116
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=pp4KVtq0a8ju4UB5nOeiv7QDkmgpHmz5XUokSR86qfI,706
6
+ ai_edge_torch/version.py,sha256=VA2R7z515pfD79tg2AjlwXASYb6LSz0-kch5NJzdj3k,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -61,8 +61,10 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKF
61
61
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
62
62
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
63
63
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
64
- ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=JSb9h3gcIh5oYrbLU6rI8OU8FzfWeTCFJT5XRWu4btE,3675
64
+ ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=XMeznGBbjRJidv725L6_7XzkYskS2cDjf8NGB18FNhg,4944
65
65
  ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=v19_EKALhAP9FjkINKqpv8JsVaQ6iH_7X5FpnhE6abw,5500
66
+ ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=mbq9CBp2znXPIQdzIQTiQGRh4Ql3bn9kyX-k_LXKTms,4537
67
+ ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
66
68
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
67
69
  ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
68
70
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -139,12 +141,12 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
139
141
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
140
142
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
141
143
  ai_edge_torch/generative/utilities/converter.py,sha256=17O83wVifH1vQJCI4WC3DaNiCIOyK2gys1GzohbLrRs,5554
142
- ai_edge_torch/generative/utilities/loader.py,sha256=k5fjCokNomte4ymy9IJrEWAuCSMhsPCJfmv1y5s0ZEc,13452
143
- ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HBavWeuNJuYPXym9fiKCY1Tk,5278
144
+ ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
145
+ ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
144
146
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
145
147
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
146
148
  ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
147
- ai_edge_torch/generative/utilities/verifier.py,sha256=h5hGyIpYGyPZwvelbzpdkjy99Kpd4JkvhqWtQN9cm-M,10413
149
+ ai_edge_torch/generative/utilities/verifier.py,sha256=GLh7h8pcpSKtCKoPyxJhv3TmvENd2h6ek_cnbe2s3Ak,11418
148
150
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
149
151
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
150
152
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -191,8 +193,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
191
193
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
192
194
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
193
195
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
194
- ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
195
- ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/METADATA,sha256=epuuYZFnqVvLzIS0X27XMCFQpnc-dO8JJQ8DXVNv5IE,1897
196
- ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
197
- ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
198
- ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/RECORD,,
196
+ ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
197
+ ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/METADATA,sha256=OyMmJ6EACAhEKbHNgLaGAogbjR8DwCLHYfIDdKW7iMI,1897
198
+ ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
199
+ ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
200
+ ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/RECORD,,