litert-torch-nightly 0.8.0.dev20260126__py3-none-any.whl → 0.9.0.dev20260127__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
litert_torch/cli.py CHANGED
@@ -16,8 +16,8 @@
 
 # This is experimental and subject to change.
 
-from litert_torch.generative.export_hf import export as hf_export_lib
 import fire
+from litert_torch.generative.export_hf import export as hf_export_lib
 
 
 class CLI:
@@ -26,7 +26,7 @@ class CLI:
     self.hf_export = hf_export_lib.export
 
 
-def main(_):
+def main():
  fire.Fire(CLI())
 
 
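The re-ordered import and the new zero-argument main() line up with how console-script entry points are invoked: the registered function is called with no arguments, and Fire parses sys.argv itself. A minimal, self-contained sketch of the same dispatch pattern, with the export function replaced by a print stub (the stub and the example invocation are illustrative, not the package's behavior):

import fire

class CLI:
  """Sketch of the Fire-based CLI; hf_export is a stand-in stub."""

  def __init__(self):
    # In the real package this is the Hugging Face export function;
    # a print stub keeps the sketch self-contained.
    self.hf_export = lambda model_path: print('would export', model_path)

def main():  # console-script entry points are called with no arguments
  fire.Fire(CLI())

if __name__ == '__main__':
  main()  # e.g. `python cli_sketch.py hf_export --model_path=gpt2`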
litert_torch/generative/export_hf/core/export_lib.py CHANGED
@@ -16,6 +16,8 @@
 
 import os
 import time
+
+import huggingface_hub
 from litert_torch import fx_infra
 from litert_torch._convert import converter as converter_utils
 from litert_torch.generative.export_hf.core import attention as _
@@ -86,7 +88,7 @@ def load_model(
 
   config = transformers.AutoConfig.from_pretrained(
       model_path,
-      torch_dtype=torch.float32,
+      dtype=torch.float32,
       trust_remote_code=trust_remote_code,
   )
   config._attn_implementation = 'lrt_transposed_attention'  # pylint: disable=protected-access
@@ -111,7 +113,20 @@ def load_model(
 
   verify_model_compatibility(model, config, text_model_config)
 
+  # TODO(weiyiw): Refactor into a separate function.
   tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
+  if not hasattr(tokenizer, 'chat_template') or not tokenizer.chat_template:
+    try:
+      if utils.get_model_path_type(model_path) == 'repo_id':
+        template_file = huggingface_hub.hf_hub_download(
+            model_path, filename='chat_template.json'
+        )
+      else:
+        template_file = os.path.join(model_path, 'chat_template.json')
+      with open(template_file, 'rt') as f:
+        tokenizer.chat_template = f.read()
+    except Exception as e:  # pylint: disable=broad-exception-caught
+      print(f'Failed to load chat template: {e}')
 
   return model, config, text_model_config, tokenizer
 
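The fallback above covers checkpoints that ship their chat template as a standalone chat_template.json rather than inside tokenizer_config.json. A minimal sketch of the Hub branch, assuming a repo that actually publishes such a file (the repo id below is hypothetical):

import huggingface_hub

# Hypothetical repo id; substitute a model that ships chat_template.json.
repo_id = 'some-org/some-model'

# hf_hub_download fetches the file into the local cache and returns its path.
template_file = huggingface_hub.hf_hub_download(
    repo_id, filename='chat_template.json'
)
with open(template_file, 'rt') as f:
  template = f.read()
print(template[:200])  # first 200 chars of the template file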
litert_torch/generative/export_hf/core/litert_lm_builder.py CHANGED
@@ -27,7 +27,7 @@ _PH = 'KIMAIRA'
 def parse_chat_template(tokenizer):
   """Parses chat template."""
   if tokenizer.chat_template is None:
-    return (None, None), (None, None), (None, None)
+    return None
   try:
     messages = [
         {'role': 'system', 'content': _PH},
@@ -39,6 +39,10 @@ def parse_chat_template(tokenizer):
         add_generation_prompt=False,
     )
     sys_prompt_parts = sys_prompt.split(_PH)
+    no_sys_prompt = False
+    if len(sys_prompt_parts) == 1:
+      sys_prompt_parts = [sys_prompt_parts[0], '']
+      no_sys_prompt = True
     if len(sys_prompt_parts) != 2:
       raise ValueError(
           f'System prompt {_PH} not found in chat template: {sys_prompt}'
@@ -46,7 +50,10 @@ def parse_chat_template(tokenizer):
     if sys_prompt_parts[0].startswith(str(tokenizer.bos_token)):
       sys_prompt_parts[0] = sys_prompt_parts[0][len(tokenizer.bos_token) :]
 
-    messages.append({'role': 'user', 'content': _PH})
+    if no_sys_prompt:
+      messages = [{'role': 'user', 'content': _PH}]
+    else:
+      messages.append({'role': 'user', 'content': _PH})
     user_prompt = tokenizer.apply_chat_template(
         messages,
         tokenize=False,
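parse_chat_template recovers each role's prefix and suffix by rendering the chat template around a sentinel string (_PH) and splitting on it; the new no_sys_prompt branch handles templates that silently drop the system role, in which case the whole rendering becomes the prefix and the user turn is rebuilt without a system message. A self-contained sketch of the sentinel technique (the checkpoint name is illustrative; any chat-tuned tokenizer works):

from transformers import AutoTokenizer

_PH = 'KIMAIRA'  # sentinel assumed not to occur in the template itself

# Illustrative chat-tuned checkpoint.
tok = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')

rendered = tok.apply_chat_template(
    [{'role': 'system', 'content': _PH}],
    tokenize=False,
    add_generation_prompt=False,
)
parts = rendered.split(_PH)
if len(parts) == 1:
  # Template ignores the system role: whole rendering is prefix, suffix empty.
  parts = [parts[0], '']
sys_prefix, sys_suffix = parts
print(repr(sys_prefix), repr(sys_suffix))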
@@ -133,20 +140,21 @@ def build_llm_metadata(
   if gen_config.temperature:
     sampler_params.temperature = gen_config.temperature
 
-  if isinstance(chat_templates, str):
-    llm_metadata.jinja_prompt_template = chat_templates
-  else:
-    sys_prompt_parts, user_prompt_parts, model_prompt_parts = chat_templates
-    pairs = []
-    if sys_prompt_parts[0] is not None:
-      pairs.append((sys_prompt_parts, llm_metadata.prompt_templates.system))
-    if user_prompt_parts[0] is not None:
-      pairs.append((user_prompt_parts, llm_metadata.prompt_templates.user))
-    if model_prompt_parts[0] is not None:
-      pairs.append((model_prompt_parts, llm_metadata.prompt_templates.model))
-    for pts, fld in pairs:
-      fld.prefix = pts[0]
-      fld.suffix = pts[1]
+  if chat_templates is not None:
+    if isinstance(chat_templates, str):
+      llm_metadata.jinja_prompt_template = chat_templates
+    else:
+      sys_prompt_parts, user_prompt_parts, model_prompt_parts = chat_templates
+      pairs = []
+      if sys_prompt_parts[0] is not None:
+        pairs.append((sys_prompt_parts, llm_metadata.prompt_templates.system))
+      if user_prompt_parts[0] is not None:
+        pairs.append((user_prompt_parts, llm_metadata.prompt_templates.user))
+      if model_prompt_parts[0] is not None:
+        pairs.append((model_prompt_parts, llm_metadata.prompt_templates.model))
+      for pts, fld in pairs:
+        fld.prefix = pts[0]
+        fld.suffix = pts[1]
 
   llm_metadata.max_num_tokens = context_length
 
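After this change, build_llm_metadata accepts three shapes for chat_templates: None (skip template metadata entirely, matching parse_chat_template's new failure value), a raw Jinja template string, or a triple of (prefix, suffix) pairs for the system, user, and model roles. A sketch of that dispatch, with SimpleNamespace standing in for the real metadata proto (the stand-in is hypothetical, not the package's type):

import types

def apply_templates(llm_metadata, chat_templates):
  # Mirrors the dispatch in build_llm_metadata above.
  if chat_templates is None:
    return  # no template information available
  if isinstance(chat_templates, str):
    llm_metadata.jinja_prompt_template = chat_templates
    return
  sys_p, user_p, model_p = chat_templates
  for parts, fld in ((sys_p, llm_metadata.prompt_templates.system),
                     (user_p, llm_metadata.prompt_templates.user),
                     (model_p, llm_metadata.prompt_templates.model)):
    if parts[0] is not None:
      fld.prefix, fld.suffix = parts

meta = types.SimpleNamespace(
    jinja_prompt_template=None,
    prompt_templates=types.SimpleNamespace(
        system=types.SimpleNamespace(prefix=None, suffix=None),
        user=types.SimpleNamespace(prefix=None, suffix=None),
        model=types.SimpleNamespace(prefix=None, suffix=None),
    ),
)
apply_templates(meta, (('<sys>', '</sys>'), ('<user>', '</user>'), (None, None)))
print(meta.prompt_templates.user.prefix)  # <user>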
litert_torch/generative/export_hf/core/utils.py CHANGED
@@ -14,6 +14,8 @@
 # ==============================================================================
 """Utility functions."""
 
+import os
+import re
 import torch
 
 
@@ -89,3 +91,35 @@ def has_sliding_attention(model):
     return False
   layer_types = getattr(model.config, 'layer_types', None)
   return layer_types is not None and 'sliding_attention' in layer_types
+
+
+def get_model_path_type(path_str: str) -> str:
+  """Determines if a string is a local path or a Hugging Face Repo ID.
+
+  Args:
+    path_str: The string to check.
+
+  Returns:
+    "local": If the path exists on disk.
+    "repo_id": If it looks like a Hub ID (e.g., 'meta-llama/Llama-2-7b').
+    "local_not_found": If it looks like a file path but doesn't exist.
+    "unknown": If it matches neither pattern clearly.
+  """
+  # 1. Absolute truth: does it exist on the disk?
+  if os.path.exists(path_str):
+    return 'local'
+
+  # 2. Heuristic: does it have explicit path markers?
+  # Starts with ".", "/", "~", or contains Windows backslashes.
+  if path_str.startswith(('.', '/', '~')) or '\\' in path_str:
+    return 'local_not_found'
+
+  # 3. Heuristic: does it look like a Repo ID?
+  # Pattern: username/repo_name (e.g. "mistralai/Mistral-7B")
+  # or just repo_name for official models (e.g. "gpt2", "bert-base-uncased").
+  # Allowed chars: alphanumeric, underscores, hyphens, periods.
+  repo_id_pattern = r'^(?:[\w\-\.]+\/)?[\w\-\.]+$'
+  if re.match(repo_id_pattern, path_str):
+    return 'repo_id'
+
+  return 'unknown'
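load_model uses this classification to choose between hf_hub_download and a plain os.path.join. A quick sanity check of the heuristics, assuming none of these strings names a path that exists on the machine running it:

from litert_torch.generative.export_hf.core import utils

# Expected results per the heuristics above; 'local' would win for any
# string that happens to exist on disk.
for s in ('gpt2', 'mistralai/Mistral-7B', './checkpoints/model',
          'C:\\models\\llama', 'org/repo/extra'):
  print(s, '->', utils.get_model_path_type(s))
# gpt2 -> repo_id
# mistralai/Mistral-7B -> repo_id
# ./checkpoints/model -> local_not_found
# C:\models\llama -> local_not_found
# org/repo/extra -> unknown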
litert_torch/version.py CHANGED
@@ -15,4 +15,4 @@
 
 # The next version of litert-torch.
 # The minor version code should be bumped after every release.
-__version__ = "0.8.0.dev20260126"
+__version__ = "0.9.0.dev20260127"
litert_torch_nightly-0.9.0.dev20260127.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: litert-torch-nightly
-Version: 0.8.0.dev20260126
+Version: 0.9.0.dev20260127
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/litert-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -41,6 +41,7 @@ Requires-Dist: jax
 Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
 Requires-Dist: jaxtyping
 Requires-Dist: fire
+Requires-Dist: sentencepiece
 Provides-Extra: torch-xla
 Requires-Dist: torch_xla>=2.4.0; extra == "torch-xla"
 Dynamic: classifier
litert_torch_nightly-0.9.0.dev20260127.dist-info/RECORD CHANGED
@@ -1,9 +1,9 @@
 litert_torch/__init__.py,sha256=jgEpTtwnhxMhPGPTRvHJR7pzx6WU_mLbA-G1LjO7fnE,1279
 litert_torch/_config.py,sha256=zDnki83sBsQzDAea6bvzwccylWHnPUzbEyGGRh6B14w,2526
-litert_torch/cli.py,sha256=Svcs5U_HJIZYQUz5sf2Uu5__JTgxnU4-f4SBsZefgv0,985
+litert_torch/cli.py,sha256=TiguLo2O3_wY8cCnKnbXtUvGH4lzyjeSgsOnHsKd9Gg,984
 litert_torch/conftest.py,sha256=gYmFrsR4c_fjIidbyrDnek26yS0crDP6-UoyMvy-WFg,757
 litert_torch/model.py,sha256=KXFTyyfPM6AnP0JoSwsTqQR3lUQbMkTGSr3dUsfQ5Jk,5635
-litert_torch/version.py,sha256=B2Yv1xWPIduypuwtL5Z-8PlcGXPeaOafIegxJq8Dagw,804
+litert_torch/version.py,sha256=vbucmZdeLtxng93Sar6Ki8BLkqBQuisk59IB4CS7klU,804
 litert_torch/_convert/__init__.py,sha256=qdLdbj5NjhNG-QgY5O_8TzOr2XaDoWvmdY9JNPStQmw,670
 litert_torch/_convert/conversion.py,sha256=NuQEphyYp3W19IKvyTWo9pe7zt1-XmWM4zU9PDkUm54,6108
 litert_torch/_convert/conversion_utils.py,sha256=MWpB-3eN-rvQzTtXsPL30cDIK431SQuwvw3ia2K2ONM,2158
@@ -187,11 +187,11 @@ litert_torch/generative/export_hf/core/attention_test.py,sha256=RevOczfPncmbIBth
 litert_torch/generative/export_hf/core/cache.py,sha256=pfWh2SACdhNY2of2Z8KJC0wrSQ2jrkXgPHWe7PSEiuU,10263
 litert_torch/generative/export_hf/core/cache_base.py,sha256=FXMm9B8nDwC8uTyLmuBnYKLTnNtoeGN8gUnWwDCcH08,1714
 litert_torch/generative/export_hf/core/cache_test.py,sha256=y-v-oOGtRNPGWRfIfW3FcpDxvJbzrBU6Pb2o66FkUzU,6203
-litert_torch/generative/export_hf/core/export_lib.py,sha256=Gr3MqU1gBs4aSwpt3ag7dqwfCdlUXTXRmmeo3StF9mo,11622
+litert_torch/generative/export_hf/core/export_lib.py,sha256=Nvg3QiYRZcMiQd7du7w5vohazjLlZJW6YFk_WAbIpAs,12249
 litert_torch/generative/export_hf/core/exportable_module.py,sha256=XEqsV9M34OP-_vsxH7bnmxSCD6erAPl0a9I9JQM7v6k,8305
-litert_torch/generative/export_hf/core/litert_lm_builder.py,sha256=pnm5GbHmwr5KLTfONUrsMdvhwRFC2q5OkRPiRsWW_Ls,7715
+litert_torch/generative/export_hf/core/litert_lm_builder.py,sha256=f8Q2ifVyt65V-kRL0X9FRpQNKIer0R_Yx2lECZTMGPU,7965
 litert_torch/generative/export_hf/core/patches.py,sha256=i1fzs0anIFbBH-Q_PwCtp9VKXy64olJKwnGpnJUjkEo,1815
-litert_torch/generative/export_hf/core/utils.py,sha256=NxLMo4vgqG-8Hhr4ZsqDALtVV3n8rYbI1jiRaQfn-ho,2870
+litert_torch/generative/export_hf/core/utils.py,sha256=5Wgs9aAOKd2i8wmQF_IierLUuFG23v1T6zZPr-azQ7A,4018
 litert_torch/generative/export_hf/core/external_emb/__init__.py,sha256=5xWIp2ziIwapcZcjSKfeaFgBnIooa8ckhTQ7mazZC3c,670
 litert_torch/generative/export_hf/core/external_emb/exportable_module.py,sha256=mWn75lLms3BAeCTEvbkGZ2n4fxtwsqGA8PP4S8-JBdY,3058
 litert_torch/generative/export_hf/core/external_rope/__init__.py,sha256=5xWIp2ziIwapcZcjSKfeaFgBnIooa8ckhTQ7mazZC3c,670
@@ -318,9 +318,9 @@ litert_torch/testing/__init__.py,sha256=AfYP1HwTYSQmupveonEHCDV5dEyshzUgbwUrCUhb
 litert_torch/testing/export.py,sha256=3dR6oxnrdtX0MfqAfMv233cf3sHA4e0F2TBQotoo8xc,3292
 litert_torch/testing/model_coverage/__init__.py,sha256=uPXeAhWiD1O0aMDLCX7FTOSNQiea8yOtoIYPCuHEAG4,763
 litert_torch/testing/model_coverage/model_coverage.py,sha256=EPCI7PbNPb7GV28lo3qQvFdzJwJ_ZDrbCGdpeiBZhVo,4715
-litert_torch_nightly-0.8.0.dev20260126.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-litert_torch_nightly-0.8.0.dev20260126.dist-info/METADATA,sha256=IC63S5u9vXex_QMJhOklzWXMcPE-nr4JtC9S1V57NBU,2470
-litert_torch_nightly-0.8.0.dev20260126.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
-litert_torch_nightly-0.8.0.dev20260126.dist-info/entry_points.txt,sha256=roYAi9hp0uYrMudMR59hGNF2pz0TSAtqNl4vQLJzxnE,55
-litert_torch_nightly-0.8.0.dev20260126.dist-info/top_level.txt,sha256=mGrsl2SYcjQSLBJX4ZXrHnFqHZe6QLRR7uk0tLfzwfM,13
-litert_torch_nightly-0.8.0.dev20260126.dist-info/RECORD,,
+litert_torch_nightly-0.9.0.dev20260127.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+litert_torch_nightly-0.9.0.dev20260127.dist-info/METADATA,sha256=bXq8cxXF9vyEVA9Kr6Q8ZKlz5x4v2L-aa1CfUBdsuAY,2499
+litert_torch_nightly-0.9.0.dev20260127.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
+litert_torch_nightly-0.9.0.dev20260127.dist-info/entry_points.txt,sha256=roYAi9hp0uYrMudMR59hGNF2pz0TSAtqNl4vQLJzxnE,55
+litert_torch_nightly-0.9.0.dev20260127.dist-info/top_level.txt,sha256=mGrsl2SYcjQSLBJX4ZXrHnFqHZe6QLRR7uk0tLfzwfM,13
+litert_torch_nightly-0.9.0.dev20260127.dist-info/RECORD,,