litert-torch-nightly 0.9.0.dev20260204__py3-none-any.whl → 0.9.0.dev20260206__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- litert_torch/generative/export_hf/core/export_lib.py +18 -7
- litert_torch/generative/export_hf/core/exportable_module_config.py +1 -0
- litert_torch/generative/export_hf/core/external_emb/exportable_module.py +30 -0
- litert_torch/generative/export_hf/core/litert_lm_builder.py +2 -1
- litert_torch/generative/export_hf/core/patches.py +29 -0
- litert_torch/generative/export_hf/export.py +4 -0
- litert_torch/version.py +1 -1
- {litert_torch_nightly-0.9.0.dev20260204.dist-info → litert_torch_nightly-0.9.0.dev20260206.dist-info}/METADATA +1 -1
- {litert_torch_nightly-0.9.0.dev20260204.dist-info → litert_torch_nightly-0.9.0.dev20260206.dist-info}/RECORD +13 -13
- {litert_torch_nightly-0.9.0.dev20260204.dist-info → litert_torch_nightly-0.9.0.dev20260206.dist-info}/WHEEL +0 -0
- {litert_torch_nightly-0.9.0.dev20260204.dist-info → litert_torch_nightly-0.9.0.dev20260206.dist-info}/entry_points.txt +0 -0
- {litert_torch_nightly-0.9.0.dev20260204.dist-info → litert_torch_nightly-0.9.0.dev20260206.dist-info}/licenses/LICENSE +0 -0
- {litert_torch_nightly-0.9.0.dev20260204.dist-info → litert_torch_nightly-0.9.0.dev20260206.dist-info}/top_level.txt +0 -0
|
@@ -80,6 +80,7 @@ def load_model(
|
|
|
80
80
|
model_path: str,
|
|
81
81
|
trust_remote_code: bool = False,
|
|
82
82
|
auto_model_override: str | None = None,
|
|
83
|
+
task: str = 'text_generation',
|
|
83
84
|
):
|
|
84
85
|
"""Loads model from checkpoint."""
|
|
85
86
|
|
|
@@ -90,7 +91,12 @@ def load_model(
|
|
|
90
91
|
)
|
|
91
92
|
config._attn_implementation = 'lrt_transposed_attention' # pylint: disable=protected-access
|
|
92
93
|
|
|
93
|
-
|
|
94
|
+
if task == 'text_generation':
|
|
95
|
+
auto_model_cls = transformers.AutoModelForCausalLM
|
|
96
|
+
elif task == 'image_text_to_text':
|
|
97
|
+
auto_model_cls = transformers.AutoModelForImageTextToText
|
|
98
|
+
else:
|
|
99
|
+
raise ValueError(f'Unsupported task: {task}')
|
|
94
100
|
if auto_model_override is not None:
|
|
95
101
|
auto_model_cls = transformers.__dict__[auto_model_override]
|
|
96
102
|
|
|
@@ -101,14 +107,16 @@ def load_model(
|
|
|
101
107
|
trust_remote_code=trust_remote_code,
|
|
102
108
|
)
|
|
103
109
|
|
|
104
|
-
|
|
105
|
-
|
|
110
|
+
if task == 'text_generation':
|
|
111
|
+
model.generation_config.cache_implementation = 'static'
|
|
112
|
+
model.generation_config.do_sample = False
|
|
106
113
|
|
|
107
114
|
text_model_config = config
|
|
108
115
|
if hasattr(config, 'text_config'):
|
|
109
116
|
text_model_config = config.text_config
|
|
110
117
|
|
|
111
|
-
|
|
118
|
+
if task == 'text_generation':
|
|
119
|
+
verify_model_compatibility(model, config, text_model_config)
|
|
112
120
|
|
|
113
121
|
# TODO(weiyiw): Refactor into a separate function.
|
|
114
122
|
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
|
|
@@ -326,7 +334,7 @@ def export_embedder_model(
|
|
|
326
334
|
sample_kwargs=sample_inputs,
|
|
327
335
|
)
|
|
328
336
|
lrt_model = converter.convert(strict_export=False)
|
|
329
|
-
model_path = os.path.join(work_dir, '
|
|
337
|
+
model_path = os.path.join(work_dir, 'embedder.tflite')
|
|
330
338
|
lrt_model.export(model_path)
|
|
331
339
|
quantization_recipe_list = (
|
|
332
340
|
quantization_recipe.split(',') if quantization_recipe else [None]
|
|
@@ -359,7 +367,10 @@ def export_auxiliary_model(
|
|
|
359
367
|
sample_kwargs=sample_input,
|
|
360
368
|
)
|
|
361
369
|
# Attention Mask
|
|
362
|
-
attention_mask_module = split_cache_module.SplitAttentionMaskBuilder(
|
|
370
|
+
attention_mask_module = split_cache_module.SplitAttentionMaskBuilder(
|
|
371
|
+
export_config.cache_length,
|
|
372
|
+
# TODO(weiyiw): Add sliding window sizes.
|
|
373
|
+
)
|
|
363
374
|
sample_inputs = attention_mask_module.get_sample_inputs(
|
|
364
375
|
text_model_config, export_config
|
|
365
376
|
)
|
|
@@ -370,7 +381,7 @@ def export_auxiliary_model(
|
|
|
370
381
|
sample_kwargs=sample_input,
|
|
371
382
|
)
|
|
372
383
|
# Cache Update
|
|
373
|
-
cache_update_module = split_cache_module.CacheUpdate(
|
|
384
|
+
cache_update_module = split_cache_module.CacheUpdate()
|
|
374
385
|
sample_inputs = cache_update_module.get_sample_inputs(
|
|
375
386
|
text_model_config, export_config
|
|
376
387
|
)
|
|
@@ -94,3 +94,33 @@ class LiteRTExportableModuleForEmbedder(torch.nn.Module):
|
|
|
94
94
|
token_ids = torch.maximum(token_ids, torch.tensor(0, dtype=torch.int32))
|
|
95
95
|
output = self.model(token_ids)
|
|
96
96
|
return {"embeddings": output}
|
|
97
|
+
|
|
98
|
+
@classmethod
|
|
99
|
+
def get_sample_inputs(
|
|
100
|
+
cls,
|
|
101
|
+
model_config,
|
|
102
|
+
export_config: base_exportable_module.ExportableModuleConfig,
|
|
103
|
+
):
|
|
104
|
+
"""Gets sample inputs."""
|
|
105
|
+
batch_size = export_config.batch_size
|
|
106
|
+
prefill_length = export_config.prefill_lengths[0]
|
|
107
|
+
prefill_length_dim = export_config.prefill_length_dim
|
|
108
|
+
del model_config # Unused.
|
|
109
|
+
tokens = {"token_ids": torch.ones((batch_size, 1), dtype=torch.int32)}
|
|
110
|
+
tokens_dynamic_shape = {"token_ids": {1: 1}} if prefill_length_dim else {}
|
|
111
|
+
if export_config.single_token_embedder:
|
|
112
|
+
return {"embedder": (tokens, tokens_dynamic_shape)}
|
|
113
|
+
else:
|
|
114
|
+
ret = {}
|
|
115
|
+
ret["decode_embedder"] = (tokens, tokens_dynamic_shape)
|
|
116
|
+
|
|
117
|
+
tokens = {
|
|
118
|
+
"token_ids": torch.ones(
|
|
119
|
+
(batch_size, prefill_length), dtype=torch.int32
|
|
120
|
+
)
|
|
121
|
+
}
|
|
122
|
+
tokens_dynamic_shape = (
|
|
123
|
+
{"token_ids": {1: prefill_length_dim}} if prefill_length_dim else {}
|
|
124
|
+
)
|
|
125
|
+
ret[f"prefill_embedder_{prefill_length}"] = (tokens, tokens_dynamic_shape)
|
|
126
|
+
return ret
|
|
@@ -118,7 +118,8 @@ def build_llm_metadata(
|
|
|
118
118
|
if isinstance(gen_config.eos_token_id, int):
|
|
119
119
|
stop_tokens.add(gen_config.eos_token_id)
|
|
120
120
|
elif isinstance(gen_config.eos_token_id, list):
|
|
121
|
-
|
|
121
|
+
for token_id in gen_config.eos_token_id:
|
|
122
|
+
stop_tokens.add(token_id)
|
|
122
123
|
elif hasattr(tokenizer, 'eos_token') and tokenizer.eos_token:
|
|
123
124
|
stop_tokens.add(tokenizer.eos_token)
|
|
124
125
|
for stop_token in stop_tokens:
|
|
@@ -60,3 +60,32 @@ original_use_kernel_forward_from_hub = (
|
|
|
60
60
|
transformers.integrations.use_kernel_forward_from_hub = (
|
|
61
61
|
_use_kernel_forward_from_hub
|
|
62
62
|
)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# TODO(weiyiw): Find a better way to patch Gemma3RMSNorm.
|
|
66
|
+
class Gemma3RMSNorm(torch.nn.Module):
|
|
67
|
+
"""RMSNorm Layer."""
|
|
68
|
+
|
|
69
|
+
def __init__(self, dim: int, eps: float = 1e-6):
|
|
70
|
+
"""RMSNorm Layer."""
|
|
71
|
+
super().__init__()
|
|
72
|
+
self.weight = torch.nn.Parameter(torch.ones(dim))
|
|
73
|
+
self.variance_epsilon = eps
|
|
74
|
+
self.hidden_size = dim
|
|
75
|
+
|
|
76
|
+
def forward(self, hidden_states):
|
|
77
|
+
return normalization.rms_norm_with_hlfb(
|
|
78
|
+
hidden_states,
|
|
79
|
+
self.weight + 1.0,
|
|
80
|
+
self.variance_epsilon,
|
|
81
|
+
torch.ones((self.hidden_size,), dtype=torch.float32),
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
def extra_repr(self):
|
|
85
|
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
from transformers.models.gemma3 import modeling_gemma3
|
|
89
|
+
|
|
90
|
+
original_gemma3_rms_norm = modeling_gemma3.Gemma3RMSNorm
|
|
91
|
+
modeling_gemma3.Gemma3RMSNorm = Gemma3RMSNorm
|
|
@@ -31,6 +31,7 @@ def export(
|
|
|
31
31
|
quantization_recipe: str = 'dynamic_wi8_afp32',
|
|
32
32
|
enable_dynamic_shape: bool = False,
|
|
33
33
|
externalize_embedder: bool = False,
|
|
34
|
+
single_token_embedder: bool = False,
|
|
34
35
|
key_ts_idx: int = 2,
|
|
35
36
|
value_ts_idx: int = 3,
|
|
36
37
|
split_cache: bool = False,
|
|
@@ -38,6 +39,7 @@ def export(
|
|
|
38
39
|
# target_accelerator: str | None = None,
|
|
39
40
|
trust_remote_code: bool = False,
|
|
40
41
|
use_jinja_template: bool = False,
|
|
42
|
+
task: str = 'text_generation',
|
|
41
43
|
):
|
|
42
44
|
"""Exports HuggingFace Transformers model to tflite."""
|
|
43
45
|
# TODO(weiyiw): Use tmp dir for work_dir.
|
|
@@ -47,6 +49,7 @@ def export(
|
|
|
47
49
|
model,
|
|
48
50
|
trust_remote_code=trust_remote_code,
|
|
49
51
|
auto_model_override=auto_model_override,
|
|
52
|
+
task=task,
|
|
50
53
|
)
|
|
51
54
|
del config # Unused.
|
|
52
55
|
if split_cache and not externalize_embedder:
|
|
@@ -62,6 +65,7 @@ def export(
|
|
|
62
65
|
if enable_dynamic_shape
|
|
63
66
|
else None,
|
|
64
67
|
externalize_embedder=externalize_embedder,
|
|
68
|
+
single_token_embedder=single_token_embedder,
|
|
65
69
|
k_ts_idx=key_ts_idx,
|
|
66
70
|
v_ts_idx=value_ts_idx,
|
|
67
71
|
split_cache=split_cache,
|
litert_torch/version.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: litert-torch-nightly
|
|
3
|
-
Version: 0.9.0.dev20260204
|
|
3
|
+
Version: 0.9.0.dev20260206
|
|
4
4
|
Summary: Support PyTorch model conversion with LiteRT.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/litert-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,LiteRT,PyTorch,LLMs,GenAI
|
|
@@ -3,7 +3,7 @@ litert_torch/_config.py,sha256=zDnki83sBsQzDAea6bvzwccylWHnPUzbEyGGRh6B14w,2526
|
|
|
3
3
|
litert_torch/cli.py,sha256=TiAUbgbWm3ecTUJtJ1_hjKJuC1LrG-Qwnm8_zws-sVY,984
|
|
4
4
|
litert_torch/conftest.py,sha256=gYmFrsR4c_fjIidbyrDnek26yS0crDP6-UoyMvy-WFg,757
|
|
5
5
|
litert_torch/model.py,sha256=KXFTyyfPM6AnP0JoSwsTqQR3lUQbMkTGSr3dUsfQ5Jk,5635
|
|
6
|
-
litert_torch/version.py,sha256=
|
|
6
|
+
litert_torch/version.py,sha256=20bJvaJMGX0olC644fZPyW86ti_tJxn-WOPVm9S1tZ4,804
|
|
7
7
|
litert_torch/_convert/__init__.py,sha256=qdLdbj5NjhNG-QgY5O_8TzOr2XaDoWvmdY9JNPStQmw,670
|
|
8
8
|
litert_torch/_convert/conversion.py,sha256=NuQEphyYp3W19IKvyTWo9pe7zt1-XmWM4zU9PDkUm54,6108
|
|
9
9
|
litert_torch/_convert/conversion_utils.py,sha256=MWpB-3eN-rvQzTtXsPL30cDIK431SQuwvw3ia2K2ONM,2158
|
|
@@ -179,7 +179,7 @@ litert_torch/generative/examples/tiny_llama/verify.py,sha256=6geA8OUOSj8_sTRyoo0
|
|
|
179
179
|
litert_torch/generative/examples/tiny_llama/verify_util.py,sha256=FKMC6Olex6bJbB8HXvC1KwxPbKgRBfT1CjoWcmyaPD8,2989
|
|
180
180
|
litert_torch/generative/export_hf/__init__.py,sha256=5xWIp2ziIwapcZcjSKfeaFgBnIooa8ckhTQ7mazZC3c,670
|
|
181
181
|
litert_torch/generative/export_hf/__main__.py,sha256=8VuBDkZ2sL-q2XdQ45qwzeHQk39-MM_6TdkxOU_23xE,782
|
|
182
|
-
litert_torch/generative/export_hf/export.py,sha256=
|
|
182
|
+
litert_torch/generative/export_hf/export.py,sha256=koqs0znGe9QXlEoRF7TuvDtrjbiXa79qgBhGK6MENwk,3800
|
|
183
183
|
litert_torch/generative/export_hf/export_main.py,sha256=bQidNXz0MEP_gil86LSfnpCW0pUiqZq2-F9ZOrSb3Yk,1183
|
|
184
184
|
litert_torch/generative/export_hf/core/__init__.py,sha256=5xWIp2ziIwapcZcjSKfeaFgBnIooa8ckhTQ7mazZC3c,670
|
|
185
185
|
litert_torch/generative/export_hf/core/attention.py,sha256=bXuTHNeVtKwWf6YXgb5I2j08vvgb9M7r1RYGvdjl9QI,4798
|
|
@@ -187,14 +187,14 @@ litert_torch/generative/export_hf/core/attention_test.py,sha256=KBSyYjHoTKYi6Se6
|
|
|
187
187
|
litert_torch/generative/export_hf/core/cache.py,sha256=UnuTBpJvplEyig1myrhA1d0QJ05pNJgWbm-GrsUu5Uk,11763
|
|
188
188
|
litert_torch/generative/export_hf/core/cache_base.py,sha256=6s-6L6iSa-qn0PLdAAhpHdOU9qwqEE-JVdlIsYyCPt4,2180
|
|
189
189
|
litert_torch/generative/export_hf/core/cache_test.py,sha256=y-v-oOGtRNPGWRfIfW3FcpDxvJbzrBU6Pb2o66FkUzU,6203
|
|
190
|
-
litert_torch/generative/export_hf/core/export_lib.py,sha256=
|
|
190
|
+
litert_torch/generative/export_hf/core/export_lib.py,sha256=qDOyLxCMeQtRbJQlLcI8Gq3PmL9yMJ2gUvPhoEcVty8,14516
|
|
191
191
|
litert_torch/generative/export_hf/core/exportable_module.py,sha256=niCS0na0VvFLiwebnL4JeXZh2hT8FCQmp-vnyTBh7pA,8257
|
|
192
|
-
litert_torch/generative/export_hf/core/exportable_module_config.py,sha256=
|
|
193
|
-
litert_torch/generative/export_hf/core/litert_lm_builder.py,sha256=
|
|
194
|
-
litert_torch/generative/export_hf/core/patches.py,sha256=
|
|
192
|
+
litert_torch/generative/export_hf/core/exportable_module_config.py,sha256=oJOWBBKWYpLq5A5qXEAIZbLwvCpY22nstHx6L88CXqU,1322
|
|
193
|
+
litert_torch/generative/export_hf/core/litert_lm_builder.py,sha256=ai-5Njn8fGKco_5jiRnmACBIKu1EL2b5SY5ArmsmttM,7998
|
|
194
|
+
litert_torch/generative/export_hf/core/patches.py,sha256=h4TCTNPT0N9xcMFfJ54XnpCHt1iKwS8mU-GhAxdsUrc,2636
|
|
195
195
|
litert_torch/generative/export_hf/core/utils.py,sha256=5Wgs9aAOKd2i8wmQF_IierLUuFG23v1T6zZPr-azQ7A,4018
|
|
196
196
|
litert_torch/generative/export_hf/core/external_emb/__init__.py,sha256=5xWIp2ziIwapcZcjSKfeaFgBnIooa8ckhTQ7mazZC3c,670
|
|
197
|
-
litert_torch/generative/export_hf/core/external_emb/exportable_module.py,sha256=
|
|
197
|
+
litert_torch/generative/export_hf/core/external_emb/exportable_module.py,sha256=1ke2mugD--1bIqeJhAJ4Ly7o6NRW8RZL79UzCqRHLNY,4113
|
|
198
198
|
litert_torch/generative/export_hf/core/external_rope/__init__.py,sha256=5xWIp2ziIwapcZcjSKfeaFgBnIooa8ckhTQ7mazZC3c,670
|
|
199
199
|
litert_torch/generative/export_hf/core/external_rope/exportable_module.py,sha256=czTf835b9Nw4XcDo6cd9chsmBdbIsdqMtnEkwuwMgX0,2478
|
|
200
200
|
litert_torch/generative/export_hf/core/external_rope/preprocess_model.py,sha256=NL3zROb7EZNAvZfutIhLk4KqXid_HklQMUjHZqZYOH4,1735
|
|
@@ -319,9 +319,9 @@ litert_torch/testing/__init__.py,sha256=AfYP1HwTYSQmupveonEHCDV5dEyshzUgbwUrCUhb
|
|
|
319
319
|
litert_torch/testing/export.py,sha256=3dR6oxnrdtX0MfqAfMv233cf3sHA4e0F2TBQotoo8xc,3292
|
|
320
320
|
litert_torch/testing/model_coverage/__init__.py,sha256=uPXeAhWiD1O0aMDLCX7FTOSNQiea8yOtoIYPCuHEAG4,763
|
|
321
321
|
litert_torch/testing/model_coverage/model_coverage.py,sha256=EPCI7PbNPb7GV28lo3qQvFdzJwJ_ZDrbCGdpeiBZhVo,4715
|
|
322
|
-
litert_torch_nightly-0.9.0.
|
|
323
|
-
litert_torch_nightly-0.9.0.
|
|
324
|
-
litert_torch_nightly-0.9.0.
|
|
325
|
-
litert_torch_nightly-0.9.0.
|
|
326
|
-
litert_torch_nightly-0.9.0.
|
|
327
|
-
litert_torch_nightly-0.9.0.
|
|
322
|
+
litert_torch_nightly-0.9.0.dev20260206.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
323
|
+
litert_torch_nightly-0.9.0.dev20260206.dist-info/METADATA,sha256=mRNKi6UzLyaT9u1_6oqV95e9vs3f_77JAhDxkBtU3Ao,2463
|
|
324
|
+
litert_torch_nightly-0.9.0.dev20260206.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
|
|
325
|
+
litert_torch_nightly-0.9.0.dev20260206.dist-info/entry_points.txt,sha256=roYAi9hp0uYrMudMR59hGNF2pz0TSAtqNl4vQLJzxnE,55
|
|
326
|
+
litert_torch_nightly-0.9.0.dev20260206.dist-info/top_level.txt,sha256=mGrsl2SYcjQSLBJX4ZXrHnFqHZe6QLRR7uk0tLfzwfM,13
|
|
327
|
+
litert_torch_nightly-0.9.0.dev20260206.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|