ai-edge-torch-nightly 0.3.0.dev20241114__py3-none-any.whl → 0.3.0.dev20241116__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/paligemma/decoder.py +43 -5
- ai_edge_torch/generative/examples/paligemma/paligemma.py +135 -0
- ai_edge_torch/generative/examples/paligemma/verify.py +134 -0
- ai_edge_torch/generative/utilities/loader.py +4 -1
- ai_edge_torch/generative/utilities/model_builder.py +24 -9
- ai_edge_torch/generative/utilities/verifier.py +38 -9
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241114.dist-info → ai_edge_torch_nightly-0.3.0.dev20241116.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241114.dist-info → ai_edge_torch_nightly-0.3.0.dev20241116.dist-info}/RECORD +12 -10
- {ai_edge_torch_nightly-0.3.0.dev20241114.dist-info → ai_edge_torch_nightly-0.3.0.dev20241116.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241114.dist-info → ai_edge_torch_nightly-0.3.0.dev20241116.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241114.dist-info → ai_edge_torch_nightly-0.3.0.dev20241116.dist-info}/top_level.txt +0 -0
@@ -15,9 +15,11 @@
|
|
15
15
|
|
16
16
|
"""Example of building a decoder of PaliGemma 3B model which is Gemma1."""
|
17
17
|
|
18
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
18
19
|
import ai_edge_torch.generative.layers.model_config as cfg
|
19
20
|
from ai_edge_torch.generative.utilities import model_builder
|
20
21
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
22
|
+
import torch
|
21
23
|
|
22
24
|
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
23
25
|
ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
|
@@ -35,6 +37,41 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
|
35
37
|
)
|
36
38
|
|
37
39
|
|
40
|
+
class Decoder(model_builder.DecoderOnlyModel):
  """A decoder of PaliGemma 3B model which is Gemma1.

  Besides a tensor of text token IDs, forward() can also take a tensor of
  embeddings which may include text or image or both.
  """

  @torch.inference_mode
  def forward(
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
      input_embeds: torch.Tensor = None,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    """Forwards the decoder with either token IDs or precomputed embeddings.

    Args:
      tokens: Text token IDs. Used only when `input_embeds` is None and
        ignored otherwise.
      input_pos: 0-based positions of the inputs.
      kv_cache: The KV cache handed to the transformer blocks.
      input_embeds: Optional precomputed input embeddings (e.g. image
        embeddings merged with text embeddings). When given, rope positions
        are shifted to 1-based and a non-diagonal attention mask is used.

    Returns:
      Whatever the base model's forward()/forward_with_embeds() returns.
    """
    if input_embeds is None:
      return super().forward(tokens, input_pos, kv_cache)

    rope_pos = input_pos + 1  # PaliGemma position is 1-based.
    cos, sin = self.rope_cache
    rope = (cos.index_select(0, rope_pos), sin.index_select(0, rope_pos))

    # The first part of input_embeds are image embeddings, so the diagonal
    # causal mask doesn't work here. Every row may attend to KV slots
    # [0, embeds_len) and the rest of the cache is masked out (this assumes
    # input_pos starts at 0 - TODO confirm for non-prefill calls).
    # Create the mask alongside the embeddings (same device and dtype) so it
    # can be combined with the attention scores without an implicit
    # host->device transfer or dtype mismatch.
    embeds_len = input_embeds.shape[1]
    mask = torch.zeros(
        embeds_len,
        self.config.kv_cache_max,
        dtype=input_embeds.dtype,
        device=input_embeds.device,
    )
    mask[:, embeds_len:] = float("-inf")

    return self.forward_with_embeds(
        input_embeds, rope, mask, input_pos, kv_cache
    )
|
73
|
+
|
74
|
+
|
38
75
|
def get_decoder_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
39
76
|
"""Returns the model config for the decoder of a PaliGemma 3B model.
|
40
77
|
|
@@ -96,8 +133,9 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
|
|
96
133
|
def build_decoder(
    checkpoint_path: str, **kwargs
) -> model_builder.DecoderOnlyModel:
  """Builds the PaliGemma decoder and loads its weights.

  Args:
    checkpoint_path: Checkpoint location understood by
      `loading_utils.ModelLoader`.
    **kwargs: Forwarded to `get_decoder_config` (e.g. `kv_cache_max_len`).

  Returns:
    The decoder, switched to eval mode.
  """
  model = Decoder(get_decoder_config(**kwargs))
  # Loosen the strictness: only the decoder is being loaded, so tensors of
  # other parts of the checkpoint are expected to stay unmapped.
  weights = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
  weights.load(model, strict=False)
  model.eval()
  return model
|
@@ -0,0 +1,135 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example of building a full-stack of PaliGemma model."""
|
17
|
+
|
18
|
+
from dataclasses import dataclass
|
19
|
+
|
20
|
+
from ai_edge_torch.generative.examples.paligemma import decoder
|
21
|
+
from ai_edge_torch.generative.examples.paligemma import image_encoder
|
22
|
+
import ai_edge_torch.generative.layers.kv_cache as kv_utils
|
23
|
+
import ai_edge_torch.generative.layers.model_config as cfg
|
24
|
+
import ai_edge_torch.generative.utilities.loader as loading_utils
|
25
|
+
import torch
|
26
|
+
from torch import nn
|
27
|
+
|
28
|
+
PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
|
29
|
+
|
30
|
+
|
31
|
+
@dataclass
class PaliGemmaConfig:
  """PaliGemma model configurations.

  Groups the configs of the two sub-models that PaliGemma is built from,
  plus the settings used to join them.
  """

  # Config of the image encoder (consumed by SiglipVisionEncoder).
  image_encoder_config: cfg.ModelConfig
  # Config of the text decoder (consumed by decoder.Decoder).
  decoder_config: cfg.ModelConfig

  # Token ID marking image placeholder positions in the input tokens; those
  # positions are overwritten with projected image embeddings.
  image_token_id: int
  # Whether the image-embedding -> decoder-embedding projection has a bias.
  image_projection_use_bias: bool = False
|
40
|
+
|
41
|
+
|
42
|
+
class PaliGemma(nn.Module):
  """PaliGemma model from the Edge Generative API.

  Composes an image encoder, a linear projection from the image embedding
  size to the decoder embedding size, and the text decoder.
  """

  def __init__(self, config: PaliGemmaConfig):
    super().__init__()

    vision_cfg = config.image_encoder_config
    text_cfg = config.decoder_config
    # Submodules are created in a fixed order: encoder, projection, decoder.
    self.image_encoder = image_encoder.SiglipVisionEncoder(vision_cfg)
    self.image_projection = nn.Linear(
        vision_cfg.embedding_dim,
        text_cfg.embedding_dim,
        bias=config.image_projection_use_bias,
    )
    self.decoder = decoder.Decoder(text_cfg)
    self.config = config

  @torch.inference_mode
  def forward(
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
      pixel_values: torch.Tensor = None,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    """Runs the model on text tokens and, optionally, an image.

    Args:
      tokens: Input token IDs. Positions equal to `config.image_token_id`
        are replaced with image embeddings when `pixel_values` is given.
      input_pos: Positions of the inputs.
      kv_cache: The KV cache passed through to the decoder.
      pixel_values: Optional image input for the image encoder.

    Returns:
      The decoder's output.
    """
    # Text-only: the decoder embeds the tokens by itself.
    if pixel_values is None:
      return self.decoder(tokens, input_pos, kv_cache)

    text_embeds = self.decoder.tok_embedding(tokens)

    projected = self.image_projection(
        self.image_encoder(pixel_values=pixel_values)
    )
    scale = self.config.decoder_config.embedding_scale
    if scale is not None:
      # The decoder multiplies its input embeddings by embedding_scale, so
      # pre-divide the image embeddings to cancel that out for them.
      projected = projected / scale

    # Merge image embeddings into text embeddings as
    # PaliGemmaForConditionalGeneration does.
    is_image_slot = (tokens == self.config.image_token_id).unsqueeze(-1)
    merged = text_embeds.masked_scatter(
        is_image_slot.expand_as(text_embeds), projected
    )

    return self.decoder(
        tokens=None,
        input_pos=input_pos,
        kv_cache=kv_cache,
        input_embeds=merged,
    )
|
88
|
+
|
89
|
+
|
90
|
+
def get_model_config() -> PaliGemmaConfig:
  """Returns the model config for a PaliGemma 3B-224 model.

  The config pairs the image encoder config with the decoder config, enables
  the bias of the image projection and sets the image placeholder token ID.

  Returns:
    The model config for a PaliGemma 3B model.
  """
  vision_config = image_encoder.get_image_encoder_config()
  text_config = decoder.get_decoder_config()
  return PaliGemmaConfig(
      image_encoder_config=vision_config,
      decoder_config=text_config,
      image_token_id=257152,  # ID of the image placeholder token.
      image_projection_use_bias=True,
  )
|
102
|
+
|
103
|
+
|
104
|
+
def get_fake_model_config() -> PaliGemmaConfig:
  """Returns a fake PaliGemma config.

  Built from the fake image encoder config and the fake decoder config
  (presumably for tests); the projection bias and image token ID match
  get_model_config().

  Returns:
    A PaliGemmaConfig made of fake sub-model configs.
  """
  return PaliGemmaConfig(
      image_encoder_config=image_encoder.get_fake_image_encoder_config(),
      decoder_config=decoder.get_fake_decoder_config(),
      image_projection_use_bias=True,
      image_token_id=257152,
  )


def get_fake_image_encoder_config() -> PaliGemmaConfig:
  """Returns a fake config of the full PaliGemma model.

  Despite its name, this returns a complete PaliGemmaConfig (fake image
  encoder AND fake decoder), unlike image_encoder.get_fake_image_encoder_config.
  Kept for backward compatibility; prefer get_fake_model_config().
  """
  return get_fake_model_config()
|
111
|
+
|
112
|
+
|
113
|
+
def build_model(checkpoint_path: str) -> PaliGemma:
  """Builds a PaliGemma model and loads all of its weights.

  The image encoder, the decoder and the image projection are loaded
  separately from the same checkpoint.

  Args:
    checkpoint_path: Checkpoint location understood by
      `loading_utils.ModelLoader`.

  Returns:
    The PaliGemma model, switched to eval mode.

  Raises:
    KeyError: If the checkpoint lacks the image projection weight, or its
      bias when the config uses one.
  """
  config = get_model_config()
  model = PaliGemma(config)
  # Load the parameters of image encoder. Non-strict, since the checkpoint
  # also holds tensors of the other parts.
  loader = loading_utils.ModelLoader(
      checkpoint_path, image_encoder.TENSOR_NAMES
  )
  loader.load(model.image_encoder, strict=False)
  # Load the parameters of decoder.
  loader = loading_utils.ModelLoader(checkpoint_path, decoder.TENSOR_NAMES)
  loader.load(model.decoder, strict=False)

  # Load the parameters of image projection. get_state() returns the raw
  # checkpoint; unwrap "model_state_dict" the same way ModelLoader.load()
  # does so that wrapped checkpoints work for this part too.
  loader = loading_utils.ModelLoader(checkpoint_path, None)
  state = loader.get_state()
  state = state["model_state_dict"] if "model_state_dict" in state else state
  converted_state = dict()
  converted_state["weight"] = state.pop(f"{PROJECTION_TENSOR_NAME}.weight")
  if config.image_projection_use_bias:
    converted_state["bias"] = state.pop(f"{PROJECTION_TENSOR_NAME}.bias")
  model.image_projection.load_state_dict(converted_state)

  model.eval()
  return model
|
@@ -0,0 +1,134 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Verifies the reauthored PaliGemma 3B model."""
|
17
|
+
|
18
|
+
import logging
|
19
|
+
import pathlib
|
20
|
+
from absl import app
|
21
|
+
from absl import flags
|
22
|
+
from ai_edge_torch.generative.examples.paligemma import paligemma
|
23
|
+
from ai_edge_torch.generative.layers import kv_cache
|
24
|
+
from ai_edge_torch.generative.utilities import verifier
|
25
|
+
from PIL import Image
|
26
|
+
import requests
|
27
|
+
import torch
|
28
|
+
import transformers
|
29
|
+
|
30
|
+
# Command-line flags of the verification script.
_IMAGE_URL = flags.DEFINE_string(
    name="image_url",
    default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
    help="The image URI to encode.",
)
_PROMPTS = flags.DEFINE_string(
    name="prompts",
    default="Caption en",
    help="The input prompts to generate answers.",
)
_MAX_NEW_TOKENS = flags.DEFINE_integer(
    name="max_new_tokens",
    default=30,
    help="The maximum size of the generated tokens.",
)
|
45
|
+
|
46
|
+
|
47
|
+
class ReauthoredPaliGemmaWrapper(verifier.ReauthoredModelWrapper):
  """Verifier wrapper around a reauthored PaliGemma model.

  PaliGemma's config nests the text model config under `decoder_config`, so
  the KV cache is created from that sub-config instead of the top-level one.
  """

  def _init_kv_cache(self):
    text_config = self.model.config.decoder_config
    return kv_cache.KVCache.from_model_config(text_config)
|
52
|
+
|
53
|
+
|
54
|
+
def main(_):
  """Compares the reauthored PaliGemma with the HF original.

  Checks forward() logits (allclose, atol=1e-3) and greedy generate() text;
  re-raises the AssertionError on the first mismatch.
  """
  checkpoint = "google/paligemma-3b-mix-224"
  logging.info("Loading the original model from: %s", checkpoint)
  original_model = (
      transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
  )

  # Locate the cached dir.
  cached_config_file = transformers.utils.cached_file(
      checkpoint, transformers.utils.CONFIG_NAME
  )
  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
  reauthored_model = paligemma.build_model(reauthored_checkpoint)

  logging.info("Loading the processor from: %s", checkpoint)
  # It works only when GemmaTokenizerFast is available. In some environments,
  # use_fast=False doesn't work either if the tokenizer cannot load the
  # sentencepiece model file properly.
  processor = transformers.AutoProcessor.from_pretrained(checkpoint)

  logging.info("Loading the image from: %s", _IMAGE_URL.value)
  # Bound the network wait (seconds), fail on HTTP errors instead of handing
  # an error page to PIL, and release the connection once decoded.
  with requests.get(_IMAGE_URL.value, stream=True, timeout=60) as response:
    response.raise_for_status()
    image = Image.open(response.raw)
    image.load()  # Decode now; the response is closed when the block exits.
  inputs = processor(text=_PROMPTS.value, images=image, return_tensors="pt")

  logging.info("Verifying the reauthored model with model.forward()...")
  logging.info("Forwarding the original model...")
  # Gradients are not needed for verification; skip building autograd graph.
  with torch.no_grad():
    outputs_original = original_model.forward(
        input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"]
    )
  outputs_original = outputs_original.logits
  logging.info("outputs_original: %s", outputs_original)

  logging.info("Forwarding the reauthored model...")
  wrapped_reauthored_model = ReauthoredPaliGemmaWrapper(reauthored_model)
  outputs_reauthored = wrapped_reauthored_model.forward(
      tokens=inputs["input_ids"],
      pixel_values=inputs["pixel_values"],
  )
  logging.info("outputs_reauthored: %s", outputs_reauthored)

  try:
    assert torch.allclose(outputs_original, outputs_reauthored, atol=1e-03)
  except AssertionError:
    logging.error("*** FAILED *** verify with forward()")
    raise
  else:
    logging.info("*** PASSED *** verify with forward()")

  logging.info("Verifying the reauthored model with model.generate()...")
  logging.info("Generating answer with the original model...")
  outputs_original = original_model.generate(
      **inputs, max_new_tokens=_MAX_NEW_TOKENS.value, do_sample=False
  )
  response_original = processor.decode(
      outputs_original[0], skip_special_tokens=True
  )
  logging.info("outputs_from_original_model: [[%s]]", response_original)

  logging.info("Generating answer with the reauthored model...")
  outputs_reauthored = wrapped_reauthored_model.generate(
      prompts=inputs["input_ids"],
      pixel_values=inputs["pixel_values"],
      max_new_tokens=_MAX_NEW_TOKENS.value,
  )
  response_reauthored = processor.decode(
      outputs_reauthored[0], skip_special_tokens=True
  )
  logging.info("outputs from reauthored model: [[%s]]", response_reauthored)

  try:
    assert response_original == response_reauthored
  except AssertionError:
    logging.error("*** FAILED *** verify with generate()")
    raise
  else:
    logging.info("*** PASSED *** verify with generate()")


if __name__ == "__main__":
  app.run(main)
|
@@ -131,6 +131,9 @@ class ModelLoader:
|
|
131
131
|
self._names = names
|
132
132
|
self._loader = self._get_loader()
|
133
133
|
|
134
|
+
def get_state(self) -> Dict[str, torch.Tensor]:
  """Reads and returns the raw checkpoint state.

  The result comes straight from the configured loader function, before any
  tensor-name conversion; it may still be wrapped in a "model_state_dict"
  entry, which load() unwraps.
  """
  read_checkpoint = self._loader
  return read_checkpoint(self._file_name)
|
136
|
+
|
134
137
|
def load(
|
135
138
|
self, model: torch.nn.Module, strict: bool = True
|
136
139
|
) -> Tuple[List[str], List[str]]:
|
@@ -150,7 +153,7 @@ class ModelLoader:
|
|
150
153
|
ValueError: If conversion results in unmapped tensors and strict mode is
|
151
154
|
enabled.
|
152
155
|
"""
|
153
|
-
state = self.
|
156
|
+
state = self.get_state()
|
154
157
|
state = state["model_state_dict"] if "model_state_dict" in state else state
|
155
158
|
converted_state = dict()
|
156
159
|
if self._names.embedding is not None:
|
@@ -16,6 +16,7 @@
|
|
16
16
|
"""Utilities to be used for re-authoring transformer models."""
|
17
17
|
|
18
18
|
import copy
|
19
|
+
from typing import Tuple
|
19
20
|
|
20
21
|
from ai_edge_torch.generative.layers import attention
|
21
22
|
from ai_edge_torch.generative.layers import builder
|
@@ -98,26 +99,40 @@ class DecoderOnlyModel(nn.Module):
|
|
98
99
|
f"Cannot forward sequence of length {seq_len}, max seq length is only"
|
99
100
|
f" {self.config.max_seq_len}"
|
100
101
|
)
|
101
|
-
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
102
|
-
"The number of transformer blocks and the number of KV cache entries"
|
103
|
-
" must be the same."
|
104
|
-
)
|
105
102
|
|
103
|
+
# token embeddings of shape (b, t, n_embd)
|
104
|
+
input_embeds = self.tok_embedding(tokens)
|
106
105
|
cos, sin = self.rope_cache
|
107
|
-
|
108
|
-
sin = sin.index_select(0, input_pos)
|
106
|
+
rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
|
109
107
|
mask = self.mask_cache.index_select(2, input_pos)
|
110
108
|
mask = mask[:, :, :, : self.config.kv_cache_max]
|
111
109
|
|
112
|
-
|
113
|
-
|
110
|
+
return self.forward_with_embeds(
|
111
|
+
input_embeds, rope, mask, input_pos, kv_cache
|
112
|
+
)
|
113
|
+
|
114
|
+
def forward_with_embeds(
|
115
|
+
self,
|
116
|
+
input_embeds: torch.Tensor,
|
117
|
+
rope: Tuple[torch.Tensor, torch.Tensor],
|
118
|
+
mask: torch.Tensor,
|
119
|
+
input_pos: torch.Tensor,
|
120
|
+
kv_cache: kv_utils.KVCache,
|
121
|
+
) -> dict[torch.Tensor, kv_utils.KVCache]:
|
122
|
+
"""Forwards the model with input embeddings."""
|
123
|
+
assert len(self.transformer_blocks) == len(kv_cache.caches), (
|
124
|
+
"The number of transformer blocks and the number of KV cache entries"
|
125
|
+
" must be the same."
|
126
|
+
)
|
127
|
+
|
128
|
+
x = input_embeds
|
114
129
|
if self.config.embedding_scale is not None:
|
115
130
|
x = x * self.config.embedding_scale
|
116
131
|
|
117
132
|
updated_kv_entires = []
|
118
133
|
for i, block in enumerate(self.transformer_blocks):
|
119
134
|
kv_entry = kv_cache.caches[i] if kv_cache else None
|
120
|
-
x, kv_entry = block(x,
|
135
|
+
x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
|
121
136
|
if kv_entry:
|
122
137
|
updated_kv_entires.append(kv_entry)
|
123
138
|
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
|
@@ -41,7 +41,9 @@ class ModelWrapper(torch.nn.Module):
|
|
41
41
|
super().__init__()
|
42
42
|
self.model = model
|
43
43
|
|
44
|
-
def forward(
|
44
|
+
def forward(
|
45
|
+
self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
|
46
|
+
) -> torch.Tensor:
|
45
47
|
"""Gets output logits by forwarding the input tokens.
|
46
48
|
|
47
49
|
Args:
|
@@ -54,7 +56,10 @@ class ModelWrapper(torch.nn.Module):
|
|
54
56
|
raise NotImplementedError("forward() is not implemented.")
|
55
57
|
|
56
58
|
def generate(
|
57
|
-
self,
|
59
|
+
self,
|
60
|
+
prompts: torch.Tensor,
|
61
|
+
max_new_tokens: int,
|
62
|
+
pixel_values: torch.Tensor = None,
|
58
63
|
) -> torch.IntTensor:
|
59
64
|
"""Returns the response token IDs to the given prompts tensor.
|
60
65
|
|
@@ -83,35 +88,59 @@ class ReauthoredModelWrapper(ModelWrapper):
|
|
83
88
|
def _forward_with_kv_cache(
    self,
    tokens: torch.Tensor,
    input_pos: torch.Tensor,
    kv_cache: kv_utils.KVCache,
    pixel_values: torch.Tensor,
) -> tuple[torch.Tensor, kv_utils.KVCache]:
  """Runs the wrapped model once and hands back logits and the new KV cache.

  Args:
    tokens (torch.Tensor): The input tokens to forward.
    input_pos (torch.Tensor): The input positions to forward.
    kv_cache (KVCache): The KV cache to forward.
    pixel_values (torch.Tensor): The input pixel values to forward, or None.

  Returns:
    A pair of the output logits and the updated KV cache.
  """
  # Text-only reauthored models may not accept a pixel_values argument, so
  # it is supplied as a keyword only when an image is actually present.
  image_kwargs = {} if pixel_values is None else {"pixel_values": pixel_values}
  result = self.model.forward(tokens, input_pos, kv_cache, **image_kwargs)
  return result["logits"], result["kv_cache"]
|
100
115
|
|
101
|
-
def forward(
|
102
|
-
|
116
|
+
def forward(
    self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
) -> torch.Tensor:
  """Returns the logits of one full-sequence pass with a fresh KV cache.

  Args:
    tokens: Input token IDs; dimension 1 is the sequence length.
    pixel_values: Optional image input, forwarded as-is.

  Returns:
    The output logits; the updated KV cache is discarded.
  """
  seq_len = tokens.shape[1]
  positions = torch.arange(0, seq_len, dtype=torch.int)
  fresh_cache = self._init_kv_cache()
  logits, _unused_cache = self._forward_with_kv_cache(
      tokens, positions, fresh_cache, pixel_values
  )
  return logits
|
104
124
|
|
105
125
|
def generate(
    self,
    prompts: torch.Tensor,
    max_new_tokens: int,
    pixel_values: torch.Tensor = None,
) -> torch.IntTensor:
  """Greedily generates exactly max_new_tokens tokens after the prompt.

  Args:
    prompts: Prompt token IDs; only `prompts[0]` (the first row) is used.
    max_new_tokens: Number of tokens to generate; there is no early stop.
    pixel_values: Optional image input, passed only with the prompt (first
      forward). Later steps feed one token at a time and rely on the KV cache.

  Returns:
    A (1, prompt_len + max_new_tokens) tensor with the prompt followed by the
    generated token IDs.
  """
  input_ids = prompts[0].int().tolist()
  tokens = torch.tensor([input_ids])
  input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
  kv_cache = self._init_kv_cache()
  for _ in range(max_new_tokens):
    logits, kv_cache = self._forward_with_kv_cache(
        tokens, input_pos, kv_cache, pixel_values
    )
    generated_token = logits[0][-1].argmax().item()
    input_ids.append(generated_token)
    # Next step feeds only the new token at its position. Keep input_pos
    # int32 like the prefill step so every call sees the same dtype.
    tokens = torch.tensor([[generated_token]])
    input_pos = torch.tensor([len(input_ids) - 1], dtype=torch.int)
    pixel_values = None  # Pass only for the first time.
  return torch.tensor([input_ids])
|
116
145
|
|
117
146
|
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241116
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=VA2R7z515pfD79tg2AjlwXASYb6LSz0-kch5NJzdj3k,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -61,8 +61,10 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKF
|
|
61
61
|
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
|
62
62
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
|
63
63
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
64
|
-
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=
|
64
|
+
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=XMeznGBbjRJidv725L6_7XzkYskS2cDjf8NGB18FNhg,4944
|
65
65
|
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=v19_EKALhAP9FjkINKqpv8JsVaQ6iH_7X5FpnhE6abw,5500
|
66
|
+
ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=mbq9CBp2znXPIQdzIQTiQGRh4Ql3bn9kyX-k_LXKTms,4537
|
67
|
+
ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
|
66
68
|
ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
|
67
69
|
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
|
68
70
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -139,12 +141,12 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
|
|
139
141
|
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
|
140
142
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
141
143
|
ai_edge_torch/generative/utilities/converter.py,sha256=17O83wVifH1vQJCI4WC3DaNiCIOyK2gys1GzohbLrRs,5554
|
142
|
-
ai_edge_torch/generative/utilities/loader.py,sha256=
|
143
|
-
ai_edge_torch/generative/utilities/model_builder.py,sha256=
|
144
|
+
ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
|
145
|
+
ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
|
144
146
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
145
147
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
146
148
|
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
|
147
|
-
ai_edge_torch/generative/utilities/verifier.py,sha256=
|
149
|
+
ai_edge_torch/generative/utilities/verifier.py,sha256=GLh7h8pcpSKtCKoPyxJhv3TmvENd2h6ek_cnbe2s3Ak,11418
|
148
150
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
149
151
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
150
152
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
|
@@ -191,8 +193,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
191
193
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
192
194
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
193
195
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
194
|
-
ai_edge_torch_nightly-0.3.0.
|
195
|
-
ai_edge_torch_nightly-0.3.0.
|
196
|
-
ai_edge_torch_nightly-0.3.0.
|
197
|
-
ai_edge_torch_nightly-0.3.0.
|
198
|
-
ai_edge_torch_nightly-0.3.0.
|
196
|
+
ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
197
|
+
ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/METADATA,sha256=OyMmJ6EACAhEKbHNgLaGAogbjR8DwCLHYfIDdKW7iMI,1897
|
198
|
+
ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
|
199
|
+
ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
200
|
+
ai_edge_torch_nightly-0.3.0.dev20241116.dist-info/RECORD,,
|
File without changes
|
File without changes
|