ai-edge-torch-nightly 0.3.0.dev20241114__py3-none-any.whl → 0.3.0.dev20241116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/paligemma/decoder.py +43 -5
- ai_edge_torch/generative/examples/paligemma/paligemma.py +135 -0
- ai_edge_torch/generative/examples/paligemma/verify.py +134 -0
- ai_edge_torch/generative/utilities/loader.py +4 -1
- ai_edge_torch/generative/utilities/model_builder.py +24 -9
- ai_edge_torch/generative/utilities/verifier.py +38 -9
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241114.dist-info → ai_edge_torch_nightly-0.3.0.dev20241116.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241114.dist-info → ai_edge_torch_nightly-0.3.0.dev20241116.dist-info}/RECORD +12 -10
- {ai_edge_torch_nightly-0.3.0.dev20241114.dist-info → ai_edge_torch_nightly-0.3.0.dev20241116.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241114.dist-info → ai_edge_torch_nightly-0.3.0.dev20241116.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241114.dist-info → ai_edge_torch_nightly-0.3.0.dev20241116.dist-info}/top_level.txt +0 -0
@@ -15,9 +15,11 @@
|
|
15
15
|
|
16
16
|
"""Example of building a decoder of PaliGemma 3B model which is Gemma1."""
|
17
17
|
|
18
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
18
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
19
20
|
from ai_edge_torch.generative.utilities import model_builder
|
20
21
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
22
|
+
import torch
|
21
23
|
|
22
24
|
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
23
25
|
ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
|
@@ -35,6 +37,41 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
|
35
37
|
)
|
36
38
|
|
37
39
|
|
40
|
+
class Decoder(model_builder.DecoderOnlyModel):
|
41
|
+
"""A decoder of PaliGemma 3B model which is Gemma1.
|
42
|
+
|
43
|
+
Besides a tensor of text token IDs, forward() can also take a tensor of
|
44
|
+
embeddings which may include text or image or both.
|
45
|
+
"""
|
46
|
+
|
47
|
+
@torch.inference_mode
|
48
|
+
def forward(
|
49
|
+
self,
|
50
|
+
tokens: torch.Tensor,
|
51
|
+
input_pos: torch.Tensor,
|
52
|
+
kv_cache: kv_utils.KVCache,
|
53
|
+
input_embeds: torch.Tensor = None,
|
54
|
+
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
55
|
+
if input_embeds is None:
|
56
|
+
return super().forward(tokens, input_pos, kv_cache)
|
57
|
+
|
58
|
+
assert input_embeds is not None
|
59
|
+
|
60
|
+
repo_pos = input_pos + 1 # PaliGemma position is 1-based.
|
61
|
+
cos, sin = self.rope_cache
|
62
|
+
rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
|
63
|
+
|
64
|
+
# The first part of input_embeds are image embeddings. Diagonal causal mask
|
65
|
+
# doesn't work here.
|
66
|
+
embeds_len = input_embeds.shape[1]
|
67
|
+
mask = torch.zeros(embeds_len, self.config.kv_cache_max)
|
68
|
+
mask[:, embeds_len:] = float("-inf")
|
69
|
+
|
70
|
+
return self.forward_with_embeds(
|
71
|
+
input_embeds, rope, mask, input_pos, kv_cache
|
72
|
+
)
|
73
|
+
|
74
|
+
|
38
75
|
def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
39
76
|
"""Returns the model config for the decoder of a PaliGemma 3B model.
|
40
77
|
|
@@ -96,8 +133,9 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
96
133
|
def build_decoder(
|
97
134
|
checkpoint_path: str, **kwargs
|
98
135
|
) -> model_builder.DecoderOnlyModel:
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
)
|
136
|
+
decoder = Decoder(get_decoder_config(**kwargs))
|
137
|
+
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
|
138
|
+
# Loose the strictness because only decoder is being loaded.
|
139
|
+
loader.load(decoder, strict=False)
|
140
|
+
decoder.eval()
|
141
|
+
return decoder
|
@@ -0,0 +1,135 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example of building a full-stack of PaliGemma model."""
|
17
|
+
|
18
|
+
from dataclasses import dataclass
|
19
|
+
|
20
|
+
from ai_edge_torch.generative.examples.paligemma import decoder
|
21
|
+
from ai_edge_torch.generative.examples.paligemma import image_encoder
|
22
|
+
import ai_edge_torch.generative.layers.kv_cache as kv_utils
|
23
|
+
import ai_edge_torch.generative.layers.model_config as cfg
|
24
|
+
import ai_edge_torch.generative.utilities.loader as loading_utils
|
25
|
+
import torch
|
26
|
+
from torch import nn
|
27
|
+
|
28
|
+
PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
|
29
|
+
|
30
|
+
|
31
|
+
@dataclass
|
32
|
+
class PaliGemmaConfig:
|
33
|
+
"""PaliGemma model configurations."""
|
34
|
+
|
35
|
+
image_encoder_config: cfg.ModelConfig
|
36
|
+
decoder_config: cfg.ModelConfig
|
37
|
+
|
38
|
+
image_token_id: int
|
39
|
+
image_projection_use_bias: bool = False
|
40
|
+
|
41
|
+
|
42
|
+
class PaliGemma(nn.Module):
|
43
|
+
"""PaliGemma model from the Edge Generative API."""
|
44
|
+
|
45
|
+
def __init__(self, config: PaliGemmaConfig):
|
46
|
+
super().__init__()
|
47
|
+
|
48
|
+
self.image_encoder = image_encoder.SiglipVisionEncoder(
|
49
|
+
config.image_encoder_config
|
50
|
+
)
|
51
|
+
self.image_projection = nn.Linear(
|
52
|
+
config.image_encoder_config.embedding_dim,
|
53
|
+
config.decoder_config.embedding_dim,
|
54
|
+
bias=config.image_projection_use_bias,
|
55
|
+
)
|
56
|
+
self.decoder = decoder.Decoder(config.decoder_config)
|
57
|
+
self.config = config
|
58
|
+
|
59
|
+
@torch.inference_mode
|
60
|
+
def forward(
|
61
|
+
self,
|
62
|
+
tokens: torch.Tensor,
|
63
|
+
input_pos: torch.Tensor,
|
64
|
+
kv_cache: kv_utils.KVCache,
|
65
|
+
pixel_values: torch.Tensor = None,
|
66
|
+
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
67
|
+
if pixel_values is None:
|
68
|
+
return self.decoder(tokens, input_pos, kv_cache)
|
69
|
+
|
70
|
+
input_embeds = self.decoder.tok_embedding(tokens)
|
71
|
+
|
72
|
+
image_encoded = self.image_encoder(pixel_values=pixel_values)
|
73
|
+
image_embeds = self.image_projection(image_encoded)
|
74
|
+
if self.config.decoder_config.embedding_scale is not None:
|
75
|
+
image_embeds = image_embeds / self.config.decoder_config.embedding_scale
|
76
|
+
|
77
|
+
# Merge image_embeds into text_embeds as PaliGemmaForConditionalGeneration.
|
78
|
+
image_mask = tokens == self.config.image_token_id
|
79
|
+
image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
|
80
|
+
input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
|
81
|
+
|
82
|
+
return self.decoder(
|
83
|
+
tokens=None,
|
84
|
+
input_pos=input_pos,
|
85
|
+
kv_cache=kv_cache,
|
86
|
+
input_embeds=input_embeds,
|
87
|
+
)
|
88
|
+
|
89
|
+
|
90
|
+
def get_model_config() -> PaliGemmaConfig:
|
91
|
+
"""Returns the model config for a PaliGemma 3B-224 model.
|
92
|
+
|
93
|
+
Returns:
|
94
|
+
The model config for a PaliGemma 3B model.
|
95
|
+
"""
|
96
|
+
return PaliGemmaConfig(
|
97
|
+
image_encoder_config=image_encoder.get_image_encoder_config(),
|
98
|
+
decoder_config=decoder.get_decoder_config(),
|
99
|
+
image_projection_use_bias=True,
|
100
|
+
image_token_id=257152,
|
101
|
+
)
|
102
|
+
|
103
|
+
|
104
|
+
def get_fake_image_encoder_config() -> PaliGemmaConfig:
|
105
|
+
return PaliGemmaConfig(
|
106
|
+
image_encoder_config=image_encoder.get_fake_image_encoder_config(),
|
107
|
+
decoder_config=decoder.get_fake_decoder_config(),
|
108
|
+
image_projection_use_bias=True,
|
109
|
+
image_token_id=257152,
|
110
|
+
)
|
111
|
+
|
112
|
+
|
113
|
+
def build_model(checkpoint_path: str) -> PaliGemma:
|
114
|
+
config = get_model_config()
|
115
|
+
model = PaliGemma(config)
|
116
|
+
# Load the parameters of image encoder.
|
117
|
+
loader = loading_utils.ModelLoader(
|
118
|
+
checkpoint_path, image_encoder.TENSOR_NAMES
|
119
|
+
)
|
120
|
+
loader.load(model.image_encoder, strict=False)
|
121
|
+
# Load the parameters of decoder.
|
122
|
+
loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
|
123
|
+
loader.load(model.decoder, strict=False)
|
124
|
+
|
125
|
+
# Load the parameters of image projection.
|
126
|
+
loader = loading_utils.ModelLoader(checkpoint_path, None)
|
127
|
+
state = loader.get_state()
|
128
|
+
converted_state = dict()
|
129
|
+
converted_state["weight"] = state.pop(f"{PROJECTION_TENSOR_NAME}.weight")
|
130
|
+
if config.image_projection_use_bias:
|
131
|
+
converted_state["bias"] = state.pop(f"{PROJECTION_TENSOR_NAME}.bias")
|
132
|
+
model.image_projection.load_state_dict(converted_state)
|
133
|
+
|
134
|
+
model.eval()
|
135
|
+
return model
|
@@ -0,0 +1,134 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Verifies the reauthored PaliGemma 3B model."""
|
17
|
+
|
18
|
+
import logging
|
19
|
+
import pathlib
|
20
|
+
from absl import app
|
21
|
+
from absl import flags
|
22
|
+
from ai_edge_torch.generative.examples.paligemma import paligemma
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache
|
24
|
+
from ai_edge_torch.generative.utilities import verifier
|
25
|
+
from PIL import Image
|
26
|
+
import requests
|
27
|
+
import torch
|
28
|
+
import transformers
|
29
|
+
|
30
|
+
_IMAGE_URL = flags.DEFINE_string(
|
31
|
+
"image_url",
|
32
|
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
|
33
|
+
"The image URI to encode.",
|
34
|
+
)
|
35
|
+
_PROMPTS = flags.DEFINE_string(
|
36
|
+
"prompts",
|
37
|
+
"Caption en",
|
38
|
+
"The input prompts to generate answers.",
|
39
|
+
)
|
40
|
+
_MAX_NEW_TOKENS = flags.DEFINE_integer(
|
41
|
+
"max_new_tokens",
|
42
|
+
30,
|
43
|
+
"The maximum size of the generated tokens.",
|
44
|
+
)
|
45
|
+
|
46
|
+
|
47
|
+
class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
|
48
|
+
"""Reauthored PaliGemma model wrapper."""
|
49
|
+
|
50
|
+
def _init_kv_cache(self):
|
51
|
+
return kv_cache.KVCache.from_model_config(self.model.config.decoder_config)
|
52
|
+
|
53
|
+
|
54
|
+
def main(_):
|
55
|
+
checkpoint = "google/paligemma-3b-mix-224"
|
56
|
+
logging.info("Loading the original model from: %s", checkpoint)
|
57
|
+
original_model = (
|
58
|
+
transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
|
59
|
+
)
|
60
|
+
|
61
|
+
# Locate the cached dir.
|
62
|
+
cached_config_file = transformers.utils.cached_file(
|
63
|
+
checkpoint, transformers.utils.CONFIG_NAME
|
64
|
+
)
|
65
|
+
reauthored_checkpoint = pathlib.Path(cached_config_file).parent
|
66
|
+
logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
|
67
|
+
reauthored_model = paligemma.build_model(reauthored_checkpoint)
|
68
|
+
|
69
|
+
logging.info("Loading the processor from: %s", checkpoint)
|
70
|
+
# It works only when GemmaTokenizerFast is available. In some environments,
|
71
|
+
# use_fast=False doeesn't work either if the tokenizer cannot load the
|
72
|
+
# sentencepiece model file properly.
|
73
|
+
processor = transformers.AutoProcessor.from_pretrained(checkpoint)
|
74
|
+
|
75
|
+
logging.info("Loading the image from: %s", _IMAGE_URL.value)
|
76
|
+
image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
|
77
|
+
inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")
|
78
|
+
|
79
|
+
logging.info("Verifying the reauthored model with model.forward()...")
|
80
|
+
logging.info("Forwarding the original model...")
|
81
|
+
outputs_original = original_model.forward(
|
82
|
+
input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"]
|
83
|
+
)
|
84
|
+
outputs_original = outputs_original.logits
|
85
|
+
logging.info("outputs_original: %s", outputs_original)
|
86
|
+
|
87
|
+
logging.info("Forwarding the reauthored model...")
|
88
|
+
wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)
|
89
|
+
outputs_reauthored = wrapped_reauthored_model.forward(
|
90
|
+
tokens=inputs["input_ids"],
|
91
|
+
pixel_values=inputs["pixel_values"],
|
92
|
+
)
|
93
|
+
logging.info("outputs_reauthored: %s", outputs_reauthored)
|
94
|
+
|
95
|
+
try:
|
96
|
+
assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-03)
|
97
|
+
except AssertionError as e:
|
98
|
+
logging.error("*** FAILED *** verify with forward()")
|
99
|
+
raise e
|
100
|
+
else:
|
101
|
+
logging.info("*** PASSED *** verify with forward()")
|
102
|
+
|
103
|
+
logging.info("Verifying the reauthored model with model.generate()...")
|
104
|
+
logging.info("Generating answer with the original model...")
|
105
|
+
outputs_original = original_model.generate(
|
106
|
+
**inputs, max_new_tokens=_MAX_NEW_TOKENS.value, do_sample=False
|
107
|
+
)
|
108
|
+
response_original = processor.decode(
|
109
|
+
outputs_original[0], skip_special_tokens=True
|
110
|
+
)
|
111
|
+
logging.info("outputs_from_original_model: [[%s]]", response_original)
|
112
|
+
|
113
|
+
logging.info("Generating answer with the reauthored model...")
|
114
|
+
outputs_reauthored = wrapped_reauthored_model.generate(
|
115
|
+
prompts=inputs["input_ids"],
|
116
|
+
pixel_values=inputs["pixel_values"],
|
117
|
+
max_new_tokens=_MAX_NEW_TOKENS.value,
|
118
|
+
)
|
119
|
+
response_reauthored = processor.decode(
|
120
|
+
outputs_reauthored[0], skip_special_tokens=True
|
121
|
+
)
|
122
|
+
logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
|
123
|
+
|
124
|
+
try:
|
125
|
+
assert response_original == response_reauthored
|
126
|
+
except AssertionError as e:
|
127
|
+
logging.error("*** FAILED *** verify with generate()")
|
128
|
+
raise e
|
129
|
+
else:
|
130
|
+
logging.info("*** PASSED *** verify with generate()")
|
131
|
+
|
132
|
+
|
133
|
+
if __name__ == "__main__":
|
134
|
+
app.run(main)
|
@@ -131,6 +131,9 @@ class ModelLoader:
|
|
131
131
|
self._names = names
|
132
132
|
self._loader = self._get_loader()
|
133
133
|
|
134
|
+
def get_state(self) -> Dict[str, torch.Tensor]:
|
135
|
+
return self._loader(self._file_name)
|
136
|
+
|
134
137
|
def load(
|
135
138
|
self, model: torch.nn.Module, strict: bool = True
|
136
139
|
) -> Tuple[List[str], List[str]]:
|
@@ -150,7 +153,7 @@ class ModelLoader:
|
|
150
153
|
ValueError: If conversion results in unmapped tensors and strict mode is
|
151
154
|
enabled.
|
152
155
|
"""
|
153
|
-
state = self.
|
156
|
+
state = self.get_state()
|
154
157
|
state = state["model_state_dict"] if "model_state_dict" in state else state
|
155
158
|
converted_state = dict()
|
156
159
|
if self._names.embedding is not None:
|
@@ -16,6 +16,7 @@
|
|
16
16
|
"""Utilities to be used for re-authoring transformer models."""
|
17
17
|
|
18
18
|
import copy
|
19
|
+
from typing import Tuple
|
19
20
|
|
20
21
|
from ai_edge_torch.generative.layers import attention
|
21
22
|
from ai_edge_torch.generative.layers import builder
|
@@ -98,26 +99,40 @@ class DecoderOnlyModel(nn.Module):
|
|
98
99
|
f"Cannot forward sequence of length {seq_len}, max seq length is only"
|
99
100
|
f" {self.config.max_seq_len}"
|
100
101
|
)
|
101
|
-
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
102
|
-
"The number of transformer blocks and the number of KV cache entries"
|
103
|
-
" must be the same."
|
104
|
-
)
|
105
102
|
|
103
|
+
# token embeddings of shape (b, t, n_embd)
|
104
|
+
input_embeds = self.tok_embedding(tokens)
|
106
105
|
cos, sin = self.rope_cache
|
107
|
-
|
108
|
-
sin = sin.index_select(0, input_pos)
|
106
|
+
rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
|
109
107
|
mask = self.mask_cache.index_select(2, input_pos)
|
110
108
|
mask = mask[:, :, :, : self.config.kv_cache_max]
|
111
109
|
|
112
|
-
|
113
|
-
|
110
|
+
return self.forward_with_embeds(
|
111
|
+
input_embeds, rope, mask, input_pos, kv_cache
|
112
|
+
)
|
113
|
+
|
114
|
+
def forward_with_embeds(
|
115
|
+
self,
|
116
|
+
input_embeds: torch.Tensor,
|
117
|
+
rope: Tuple[torch.Tensor, torch.Tensor],
|
118
|
+
mask: torch.Tensor,
|
119
|
+
input_pos: torch.Tensor,
|
120
|
+
kv_cache: kv_utils.KVCache,
|
121
|
+
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
122
|
+
"""Forwards the model with input embeddings."""
|
123
|
+
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
124
|
+
"The number of transformer blocks and the number of KV cache entries"
|
125
|
+
" must be the same."
|
126
|
+
)
|
127
|
+
|
128
|
+
x = input_embeds
|
114
129
|
if self.config.embedding_scale is not None:
|
115
130
|
x = x * self.config.embedding_scale
|
116
131
|
|
117
132
|
updated_kv_entires = []
|
118
133
|
for i, block in enumerate(self.transformer_blocks):
|
119
134
|
kv_entry = kv_cache.caches[i] if kv_cache else None
|
120
|
-
x, kv_entry = block(x,
|
135
|
+
x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
|
121
136
|
if kv_entry:
|
122
137
|
updated_kv_entires.append(kv_entry)
|
123
138
|
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
|
@@ -41,7 +41,9 @@ class ModelWrapper(torch.nn.Module):
|
|
41
41
|
super().__init__()
|
42
42
|
self.model = model
|
43
43
|
|
44
|
-
def forward(
|
44
|
+
def forward(
|
45
|
+
self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
|
46
|
+
) -> torch.Tensor:
|
45
47
|
"""Gets output logits by forwarding the input tokens.
|
46
48
|
|
47
49
|
Args:
|
@@ -54,7 +56,10 @@ class ModelWrapper(torch.nn.Module):
|
|
54
56
|
raise NotImplementedError("forward() is not implemented.")
|
55
57
|
|
56
58
|
def generate(
|
57
|
-
self,
|
59
|
+
self,
|
60
|
+
prompts: torch.Tensor,
|
61
|
+
max_new_tokens: int,
|
62
|
+
pixel_values: torch.Tensor = None,
|
58
63
|
) -> torch.IntTensor:
|
59
64
|
"""Returns the response token IDs to the given prompts tensor.
|
60
65
|
|
@@ -83,35 +88,59 @@ class ReauthoredModelWrapper(ModelWrapper):
|
|
83
88
|
def _forward_with_kv_cache(
|
84
89
|
self,
|
85
90
|
tokens: torch.Tensor,
|
91
|
+
input_pos: torch.Tensor,
|
86
92
|
kv_cache: kv_utils.KVCache,
|
93
|
+
pixel_values: torch.Tensor,
|
87
94
|
) -> tuple[torch.Tensor, kv_utils.KVCache]:
|
88
95
|
"""Forwards the model and updates an external KV cache.
|
89
96
|
|
90
97
|
Args:
|
91
98
|
tokens (torch.Tensor): The input tokens to forward.
|
99
|
+
input_pos (torch.Tensor): The input positions to forward.
|
92
100
|
kv_cache (KVCache): The KV cache to forward.
|
101
|
+
pixel_values (torch.Tensor): The input pixel values to forward.
|
93
102
|
|
94
103
|
Returns:
|
95
104
|
The output logits and the updated KV cache.
|
96
105
|
"""
|
97
|
-
|
98
|
-
|
106
|
+
# Since the reauthored model doesn't include keyword arguments, pass
|
107
|
+
# pixel_values only when it is not None. Otherwise, it may raise an error.
|
108
|
+
if pixel_values is None:
|
109
|
+
output = self.model.forward(tokens, input_pos, kv_cache)
|
110
|
+
else:
|
111
|
+
output = self.model.forward(
|
112
|
+
tokens, input_pos, kv_cache, pixel_values=pixel_values
|
113
|
+
)
|
99
114
|
return output["logits"], output["kv_cache"]
|
100
115
|
|
101
|
-
def forward(
|
102
|
-
|
116
|
+
def forward(
|
117
|
+
self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
|
118
|
+
) -> torch.Tensor:
|
119
|
+
input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
|
120
|
+
logits, _ = self._forward_with_kv_cache(
|
121
|
+
tokens, input_pos, self._init_kv_cache(), pixel_values
|
122
|
+
)
|
103
123
|
return logits
|
104
124
|
|
105
125
|
def generate(
|
106
|
-
self,
|
126
|
+
self,
|
127
|
+
prompts: torch.Tensor,
|
128
|
+
max_new_tokens: int,
|
129
|
+
pixel_values: torch.Tensor = None,
|
107
130
|
) -> torch.IntTensor:
|
108
131
|
input_ids = prompts[0].int().tolist()
|
132
|
+
tokens = torch.tensor([input_ids])
|
133
|
+
input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
|
109
134
|
kv_cache = self._init_kv_cache()
|
110
135
|
for _ in range(max_new_tokens):
|
111
|
-
|
112
|
-
|
136
|
+
logits, kv_cache = self._forward_with_kv_cache(
|
137
|
+
tokens, input_pos, kv_cache, pixel_values
|
138
|
+
)
|
113
139
|
generated_token = logits[0][-1].argmax().item()
|
114
140
|
input_ids.append(generated_token)
|
141
|
+
tokens = torch.tensor([[generated_token]])
|
142
|
+
input_pos = torch.tensor([len(input_ids) - 1])
|
143
|
+
pixel_values = None # Pass only for the first time.
|
115
144
|
return torch.tensor([input_ids])
|
116
145
|
|
117
146
|
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241116
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=VA2R7z515pfD79tg2AjlwXASYb6LSz0-kch5NJzdj3k,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -61,8 +61,10 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKF
|
|
61
61
|
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
|
62
62
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
|
63
63
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
64
|
-
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=
|
64
|
+
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=XMeznGBbjRJidv725L6_7XzkYskS2cDjf8NGB18FNhg,4944
|
65
65
|
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=v19_EKALhAP9FjkINKqpv8JsVaQ6iH_7X5FpnhE6abw,5500
|
66
|
+
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=mbq9CBp2znXPIQdzIQTiQGRh4Ql3bn9kyX-k_LXKTms,4537
|
67
|
+
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
|
66
68
|
ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
|
67
69
|
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
|
68
70
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -139,12 +141,12 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
|
|
139
141
|
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
|
140
142
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
141
143
|
ai_edge_torch/generative/utilities/converter.py,sha256=17O83wVifH1vQJCI4WC3DaNiCIOyK2gys1GzohbLrRs,5554
|
142
|
-
ai_edge_torch/generative/utilities/loader.py,sha256=
|
143
|
-
ai_edge_torch/generative/utilities/model_builder.py,sha256=
|
144
|
+
ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
|
145
|
+
ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
|
144
146
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
145
147
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
146
148
|
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
|
147
|
-
ai_edge_torch/generative/utilities/verifier.py,sha256=
|
149
|
+
ai_edge_torch/generative/utilities/verifier.py,sha256=GLh7h8pcpSKtCKoPyxJhv3TmvENd2h6ek_cnbe2s3Ak,11418
|
148
150
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
149
151
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
150
152
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
|
@@ -191,8 +193,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
191
193
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
192
194
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
193
195
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
194
|
-
ai_edge_torch_nightly-0.3.0.
|
195
|
-
ai_edge_torch_nightly-0.3.0.
|
196
|
-
ai_edge_torch_nightly-0.3.0.
|
197
|
-
ai_edge_torch_nightly-0.3.0.
|
198
|
-
ai_edge_torch_nightly-0.3.0.
|
196
|
+
ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
197
|
+
ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/METADATA,sha256=OyMmJ6EACAhEKbHNgLaGAogbjR8DwCLHYfIDdKW7iMI,1897
|
198
|
+
ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
|
199
|
+
ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
200
|
+
ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/RECORD,,
|
File without changes
|
File without changes
|