litert-torch-nightly 0.8.0.dev20260126__py3-none-any.whl → 0.9.0.dev20260128__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- litert_torch/cli.py +2 -2
- litert_torch/generative/export_hf/core/export_lib.py +16 -1
- litert_torch/generative/export_hf/core/litert_lm_builder.py +24 -16
- litert_torch/generative/export_hf/core/utils.py +34 -0
- litert_torch/version.py +1 -1
- {litert_torch_nightly-0.8.0.dev20260126.dist-info → litert_torch_nightly-0.9.0.dev20260128.dist-info}/METADATA +5 -4
- {litert_torch_nightly-0.8.0.dev20260126.dist-info → litert_torch_nightly-0.9.0.dev20260128.dist-info}/RECORD +11 -11
- {litert_torch_nightly-0.8.0.dev20260126.dist-info → litert_torch_nightly-0.9.0.dev20260128.dist-info}/WHEEL +0 -0
- {litert_torch_nightly-0.8.0.dev20260126.dist-info → litert_torch_nightly-0.9.0.dev20260128.dist-info}/entry_points.txt +0 -0
- {litert_torch_nightly-0.8.0.dev20260126.dist-info → litert_torch_nightly-0.9.0.dev20260128.dist-info}/licenses/LICENSE +0 -0
- {litert_torch_nightly-0.8.0.dev20260126.dist-info → litert_torch_nightly-0.9.0.dev20260128.dist-info}/top_level.txt +0 -0
litert_torch/cli.py
CHANGED
@@ -16,8 +16,8 @@


 # This is experimental and subject to change.

-from litert_torch.generative.export_hf import export as hf_export_lib
 import fire
+from litert_torch.generative.export_hf import export as hf_export_lib


 class CLI:

@@ -26,7 +26,7 @@ class CLI:
     self.hf_export = hf_export_lib.export


-def main(
+def main():
   fire.Fire(CLI())

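The cli.py change completes the truncated `def main(` signature into `def main():` and moves the export import below `import fire`. For context, a minimal sketch of the python-fire pattern the CLI uses; the `hf_export` stub here is a hypothetical stand-in for `hf_export_lib.export`:

import fire


class CLI:
  """Hypothetical mirror of litert_torch.cli.CLI, for illustration only."""

  def __init__(self):
    # In litert-torch this attribute is hf_export_lib.export; a stub is
    # used here so the sketch runs standalone.
    self.hf_export = lambda model_path: print(f'exporting {model_path}')


def main():
  # fire.Fire exposes the object's attributes as subcommands, e.g.:
  #   python cli.py hf_export --model_path=gpt2
  fire.Fire(CLI())


if __name__ == '__main__':
  main()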
litert_torch/generative/export_hf/core/export_lib.py
CHANGED

@@ -16,6 +16,8 @@

 import os
 import time
+
+import huggingface_hub
 from litert_torch import fx_infra
 from litert_torch._convert import converter as converter_utils
 from litert_torch.generative.export_hf.core import attention as _

@@ -86,7 +88,7 @@ def load_model(

   config = transformers.AutoConfig.from_pretrained(
       model_path,
-
+      dtype=torch.float32,
       trust_remote_code=trust_remote_code,
   )
   config._attn_implementation = 'lrt_transposed_attention'  # pylint: disable=protected-access

@@ -111,7 +113,20 @@ def load_model(

   verify_model_compatibility(model, config, text_model_config)

+  # TODO(weiyiw): Refactor into a separate function.
   tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
+  if not hasattr(tokenizer, 'chat_template') or not tokenizer.chat_template:
+    try:
+      if utils.get_model_path_type(model_path) == 'repo_id':
+        template_file = huggingface_hub.hf_hub_download(
+            model_path, filename='chat_template.json'
+        )
+      else:
+        template_file = os.path.join(model_path, 'chat_template.json')
+      with open(template_file, 'rt') as f:
+        tokenizer.chat_template = f.read()
+    except Exception as e:  # pylint: disable=broad-exception-caught
+      print(f'Failed to load chat template: {e}')

   return model, config, text_model_config, tokenizer

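The new block in load_model falls back to a standalone chat_template.json file when the tokenizer does not already carry a template. A self-contained sketch of the same logic, assuming only huggingface_hub; the function name and the path_type argument are illustrative (in the diff the branch is driven by the new utils.get_model_path_type helper):

import os

import huggingface_hub


def load_chat_template_fallback(tokenizer, model_path, path_type):
  """Populates tokenizer.chat_template from chat_template.json if missing."""
  if getattr(tokenizer, 'chat_template', None):
    return  # The tokenizer already ships a template; nothing to do.
  try:
    if path_type == 'repo_id':
      # Fetch the standalone template file from the Hugging Face Hub.
      template_file = huggingface_hub.hf_hub_download(
          model_path, filename='chat_template.json'
      )
    else:
      template_file = os.path.join(model_path, 'chat_template.json')
    with open(template_file, 'rt') as f:
      tokenizer.chat_template = f.read()
  except Exception as e:  # A missing file is non-fatal, as in the diff.
    print(f'Failed to load chat template: {e}')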
litert_torch/generative/export_hf/core/litert_lm_builder.py
CHANGED

@@ -27,7 +27,7 @@ _PH = 'KIMAIRA'
 def parse_chat_template(tokenizer):
   """Parses chat template."""
   if tokenizer.chat_template is None:
-    return
+    return None
   try:
     messages = [
         {'role': 'system', 'content': _PH},

@@ -39,6 +39,10 @@ def parse_chat_template(tokenizer):
         add_generation_prompt=False,
     )
     sys_prompt_parts = sys_prompt.split(_PH)
+    no_sys_prompt = False
+    if len(sys_prompt_parts) == 1:
+      sys_prompt_parts = [sys_prompt_parts[0], '']
+      no_sys_prompt = True
     if len(sys_prompt_parts) != 2:
       raise ValueError(
           f'System prompt {_PH} not found in chat template: {sys_prompt}'

@@ -46,7 +50,10 @@ def parse_chat_template(tokenizer):
     if sys_prompt_parts[0].startswith(str(tokenizer.bos_token)):
       sys_prompt_parts[0] = sys_prompt_parts[0][len(tokenizer.bos_token) :]

-
+    if no_sys_prompt:
+      messages = [{'role': 'user', 'content': _PH}]
+    else:
+      messages.append({'role': 'user', 'content': _PH})
     user_prompt = tokenizer.apply_chat_template(
         messages,
         tokenize=False,
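parse_chat_template works by rendering messages whose content is a sentinel string (_PH = 'KIMAIRA') and splitting the rendered text on that sentinel to recover each role's prefix and suffix; the new no_sys_prompt branch handles templates that drop the system role entirely. A quick demonstration of the sentinel trick, assuming network access and that the checkpoint (Qwen/Qwen2-0.5B-Instruct is an arbitrary example) ships a chat template:

import transformers

_PH = 'KIMAIRA'  # Sentinel string from litert_lm_builder.py.

tokenizer = transformers.AutoTokenizer.from_pretrained('Qwen/Qwen2-0.5B-Instruct')
rendered = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': _PH}],
    tokenize=False,
    add_generation_prompt=False,
)
# Everything before the sentinel is the user prefix; everything after is
# the user suffix.
prefix, suffix = rendered.split(_PH)
print(repr(prefix), repr(suffix))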
@@ -133,20 +140,21 @@ def build_llm_metadata(
   if gen_config.temperature:
     sampler_params.temperature = gen_config.temperature

-  if
-
-
-
-
-
-
-
-
-
-
-
-    fld
-
+  if chat_templates is not None:
+    if isinstance(chat_templates, str):
+      llm_metadata.jinja_prompt_template = chat_templates
+    else:
+      sys_prompt_parts, user_prompt_parts, model_prompt_parts = chat_templates
+      pairs = []
+      if sys_prompt_parts[0] is not None:
+        pairs.append((sys_prompt_parts, llm_metadata.prompt_templates.system))
+      if user_prompt_parts[0] is not None:
+        pairs.append((user_prompt_parts, llm_metadata.prompt_templates.user))
+      if model_prompt_parts[0] is not None:
+        pairs.append((model_prompt_parts, llm_metadata.prompt_templates.model))
+      for pts, fld in pairs:
+        fld.prefix = pts[0]
+        fld.suffix = pts[1]

   llm_metadata.max_num_tokens = context_length

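From the branch above, chat_templates is evidently either a raw Jinja template string (stored in jinja_prompt_template) or a 3-tuple of (prefix, suffix) pairs for the system, user, and model roles, with a None prefix marking an absent role. A hypothetical value of the tuple form; the marker tokens below are made up for illustration and real values come from parse_chat_template:

# Hypothetical (prefix, suffix) pairs; actual strings vary per model.
chat_templates = (
    ('<|system|>\n', '<|end|>\n'),     # system role wrapper
    ('<|user|>\n', '<|end|>\n'),       # user role wrapper
    ('<|assistant|>\n', '<|end|>\n'),  # model role wrapper
)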
litert_torch/generative/export_hf/core/utils.py
CHANGED

@@ -14,6 +14,8 @@
 # ==============================================================================
 """Utility functions."""

+import os
+import re
 import torch


@@ -89,3 +91,35 @@ def has_sliding_attention(model):
     return False
   layer_types = getattr(model.config, 'layer_types', None)
   return layer_types is not None and 'sliding_attention' in layer_types
+
+
+def get_model_path_type(path_str: str) -> str:
+  """Determines if a string is a local path or a Hugging Face Repo ID.
+
+  Args:
+    path_str: The string to check.
+
+  Returns:
+    "local": If the path exists on disk.
+    "repo_id": If it looks like a Hub ID (e.g., 'meta-llama/Llama-2-7b').
+    "local_not_found": If it looks like a file path but doesn't exist.
+    "unknown": If it matches neither pattern clearly.
+  """
+  # 1. Absolute Truth: Does it exist on the disk?
+  if os.path.exists(path_str):
+    return 'local'
+
+  # 2. Heuristic: Does it have explicit path markers?
+  # Starts with "./", "/", "~", or contains Windows backslashes
+  if path_str.startswith(('.', '/', '~')) or '\\' in path_str:
+    return 'local_not_found'
+
+  # 3. Heuristic: Does it look like a Repo ID?
+  # Pattern: username/repo_name (e.g. "mistralai/Mistral-7B")
+  # or just repo_name for official models (e.g. "gpt2", "bert-base-uncased")
+  # Allowed chars: Alphanumeric, underscores, hyphens, periods.
+  repo_id_pattern = r'^(?:[\w\-\.]+\/)?[\w\-\.]+$'
+  if re.match(repo_id_pattern, path_str):
+    return 'repo_id'
+
+  return 'unknown'
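A usage sketch for the new helper, assuming litert-torch is installed and that none of the sample strings exist as paths in the working directory:

from litert_torch.generative.export_hf.core import utils

for s in ('gpt2', 'mistralai/Mistral-7B', './checkpoints/model', 'C:\\models\\m'):
  print(s, '->', utils.get_model_path_type(s))
# gpt2 -> repo_id
# mistralai/Mistral-7B -> repo_id
# ./checkpoints/model -> local_not_found
# C:\models\m -> local_not_found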
litert_torch/version.py
CHANGED

{litert_torch_nightly-0.8.0.dev20260126.dist-info → litert_torch_nightly-0.9.0.dev20260128.dist-info}/METADATA
CHANGED

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: litert-torch-nightly
-Version: 0.8.0.dev20260126
-Summary:
+Version: 0.9.0.dev20260128
+Summary: Support PyTorch model conversion with LiteRT.
 Home-page: https://github.com/google-ai-edge/litert-torch
-Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
+Keywords: On-Device ML,AI,Google,TFLite,LiteRT,PyTorch,LLMs,GenAI
 Classifier: Development Status :: 4 - Beta
 Classifier: Intended Audience :: Developers
 Classifier: Intended Audience :: Education

@@ -41,6 +41,7 @@ Requires-Dist: jax
 Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
 Requires-Dist: jaxtyping
 Requires-Dist: fire
+Requires-Dist: sentencepiece
 Provides-Extra: torch-xla
 Requires-Dist: torch_xla>=2.4.0; extra == "torch-xla"
 Dynamic: classifier

@@ -55,7 +56,7 @@ Dynamic: requires-python
 Dynamic: summary

 Library that supports converting PyTorch models into a .tflite format, which can
-then be run with
+then be run with LiteRT. This enables applications for
 Android, iOS and IOT that can run models completely on-device.

 [Install steps](https://github.com/google-ai-edge/litert-torch#installation)
{litert_torch_nightly-0.8.0.dev20260126.dist-info → litert_torch_nightly-0.9.0.dev20260128.dist-info}/RECORD
CHANGED

@@ -1,9 +1,9 @@
 litert_torch/__init__.py,sha256=jgEpTtwnhxMhPGPTRvHJR7pzx6WU_mLbA-G1LjO7fnE,1279
 litert_torch/_config.py,sha256=zDnki83sBsQzDAea6bvzwccylWHnPUzbEyGGRh6B14w,2526
-litert_torch/cli.py,sha256=
+litert_torch/cli.py,sha256=TiguLo2O3_wY8cCnKnbXtUvGH4lzyjeSgsOnHsKd9Gg,984
 litert_torch/conftest.py,sha256=gYmFrsR4c_fjIidbyrDnek26yS0crDP6-UoyMvy-WFg,757
 litert_torch/model.py,sha256=KXFTyyfPM6AnP0JoSwsTqQR3lUQbMkTGSr3dUsfQ5Jk,5635
-litert_torch/version.py,sha256=
+litert_torch/version.py,sha256=mHpimDjWXUBj0_q1ks2U5uwQ-hARKooCxFvQ3Hp4RPU,804
 litert_torch/_convert/__init__.py,sha256=qdLdbj5NjhNG-QgY5O_8TzOr2XaDoWvmdY9JNPStQmw,670
 litert_torch/_convert/conversion.py,sha256=NuQEphyYp3W19IKvyTWo9pe7zt1-XmWM4zU9PDkUm54,6108
 litert_torch/_convert/conversion_utils.py,sha256=MWpB-3eN-rvQzTtXsPL30cDIK431SQuwvw3ia2K2ONM,2158

@@ -187,11 +187,11 @@ litert_torch/generative/export_hf/core/attention_test.py,sha256=RevOczfPncmbIBth
 litert_torch/generative/export_hf/core/cache.py,sha256=pfWh2SACdhNY2of2Z8KJC0wrSQ2jrkXgPHWe7PSEiuU,10263
 litert_torch/generative/export_hf/core/cache_base.py,sha256=FXMm9B8nDwC8uTyLmuBnYKLTnNtoeGN8gUnWwDCcH08,1714
 litert_torch/generative/export_hf/core/cache_test.py,sha256=y-v-oOGtRNPGWRfIfW3FcpDxvJbzrBU6Pb2o66FkUzU,6203
-litert_torch/generative/export_hf/core/export_lib.py,sha256=
+litert_torch/generative/export_hf/core/export_lib.py,sha256=Nvg3QiYRZcMiQd7du7w5vohazjLlZJW6YFk_WAbIpAs,12249
 litert_torch/generative/export_hf/core/exportable_module.py,sha256=XEqsV9M34OP-_vsxH7bnmxSCD6erAPl0a9I9JQM7v6k,8305
-litert_torch/generative/export_hf/core/litert_lm_builder.py,sha256=
+litert_torch/generative/export_hf/core/litert_lm_builder.py,sha256=f8Q2ifVyt65V-kRL0X9FRpQNKIer0R_Yx2lECZTMGPU,7965
 litert_torch/generative/export_hf/core/patches.py,sha256=i1fzs0anIFbBH-Q_PwCtp9VKXy64olJKwnGpnJUjkEo,1815
-litert_torch/generative/export_hf/core/utils.py,sha256=
+litert_torch/generative/export_hf/core/utils.py,sha256=5Wgs9aAOKd2i8wmQF_IierLUuFG23v1T6zZPr-azQ7A,4018
 litert_torch/generative/export_hf/core/external_emb/__init__.py,sha256=5xWIp2ziIwapcZcjSKfeaFgBnIooa8ckhTQ7mazZC3c,670
 litert_torch/generative/export_hf/core/external_emb/exportable_module.py,sha256=mWn75lLms3BAeCTEvbkGZ2n4fxtwsqGA8PP4S8-JBdY,3058
 litert_torch/generative/export_hf/core/external_rope/__init__.py,sha256=5xWIp2ziIwapcZcjSKfeaFgBnIooa8ckhTQ7mazZC3c,670

@@ -318,9 +318,9 @@ litert_torch/testing/__init__.py,sha256=AfYP1HwTYSQmupveonEHCDV5dEyshzUgbwUrCUhb
 litert_torch/testing/export.py,sha256=3dR6oxnrdtX0MfqAfMv233cf3sHA4e0F2TBQotoo8xc,3292
 litert_torch/testing/model_coverage/__init__.py,sha256=uPXeAhWiD1O0aMDLCX7FTOSNQiea8yOtoIYPCuHEAG4,763
 litert_torch/testing/model_coverage/model_coverage.py,sha256=EPCI7PbNPb7GV28lo3qQvFdzJwJ_ZDrbCGdpeiBZhVo,4715
-litert_torch_nightly-0.8.0.dev20260126.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-litert_torch_nightly-0.8.0.dev20260126.dist-info/METADATA,sha256=
-litert_torch_nightly-0.8.0.dev20260126.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
-litert_torch_nightly-0.8.0.dev20260126.dist-info/entry_points.txt,sha256=roYAi9hp0uYrMudMR59hGNF2pz0TSAtqNl4vQLJzxnE,55
-litert_torch_nightly-0.8.0.dev20260126.dist-info/top_level.txt,sha256=mGrsl2SYcjQSLBJX4ZXrHnFqHZe6QLRR7uk0tLfzwfM,13
-litert_torch_nightly-0.8.0.dev20260126.dist-info/RECORD,,
+litert_torch_nightly-0.9.0.dev20260128.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+litert_torch_nightly-0.9.0.dev20260128.dist-info/METADATA,sha256=F-LavniKQ7N46UMAKuq-94y75WvGAlcbfwEcjrKVUTQ,2463
+litert_torch_nightly-0.9.0.dev20260128.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
+litert_torch_nightly-0.9.0.dev20260128.dist-info/entry_points.txt,sha256=roYAi9hp0uYrMudMR59hGNF2pz0TSAtqNl4vQLJzxnE,55
+litert_torch_nightly-0.9.0.dev20260128.dist-info/top_level.txt,sha256=mGrsl2SYcjQSLBJX4ZXrHnFqHZe6QLRR7uk0tLfzwfM,13
+litert_torch_nightly-0.9.0.dev20260128.dist-info/RECORD,,
{litert_torch_nightly-0.8.0.dev20260126.dist-info → litert_torch_nightly-0.9.0.dev20260128.dist-info}/WHEEL
File without changes

{litert_torch_nightly-0.8.0.dev20260126.dist-info → litert_torch_nightly-0.9.0.dev20260128.dist-info}/entry_points.txt
File without changes

{litert_torch_nightly-0.8.0.dev20260126.dist-info → litert_torch_nightly-0.9.0.dev20260128.dist-info}/licenses/LICENSE
File without changes

{litert_torch_nightly-0.8.0.dev20260126.dist-info → litert_torch_nightly-0.9.0.dev20260128.dist-info}/top_level.txt
File without changes