litert-torch-nightly 0.9.0.dev20260204__py3-none-any.whl → 0.9.0.dev20260206__py3-none-any.whl

@@ -80,6 +80,7 @@ def load_model(
     model_path: str,
     trust_remote_code: bool = False,
     auto_model_override: str | None = None,
+    task: str = 'text_generation',
 ):
   """Loads model from checkpoint."""
 
@@ -90,7 +91,12 @@ def load_model(
   )
   config._attn_implementation = 'lrt_transposed_attention'  # pylint: disable=protected-access
 
-  auto_model_cls = transformers.AutoModelForCausalLM
+  if task == 'text_generation':
+    auto_model_cls = transformers.AutoModelForCausalLM
+  elif task == 'image_text_to_text':
+    auto_model_cls = transformers.AutoModelForImageTextToText
+  else:
+    raise ValueError(f'Unsupported task: {task}')
   if auto_model_override is not None:
     auto_model_cls = transformers.__dict__[auto_model_override]
 
@@ -101,14 +107,16 @@ def load_model(
       trust_remote_code=trust_remote_code,
   )
 
-  model.generation_config.cache_implementation = 'static'
-  model.generation_config.do_sample = False
+  if task == 'text_generation':
+    model.generation_config.cache_implementation = 'static'
+    model.generation_config.do_sample = False
 
   text_model_config = config
   if hasattr(config, 'text_config'):
     text_model_config = config.text_config
 
-  verify_model_compatibility(model, config, text_model_config)
+  if task == 'text_generation':
+    verify_model_compatibility(model, config, text_model_config)
 
   # TODO(weiyiw): Refactor into a separate function.
   tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
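
Note on the load_model changes above: the new task argument selects which transformers Auto class materializes the checkpoint, and the text-generation-only setup (static cache, greedy decoding, the compatibility check) is now skipped for image-text-to-text models. A minimal sketch of the same dispatch pattern, assuming only the transformers import (the mapping and helper name are illustrative, not part of the package):

import transformers

# Illustrative mapping mirroring the if/elif chain added in load_model.
_AUTO_MODEL_BY_TASK = {
    'text_generation': transformers.AutoModelForCausalLM,
    'image_text_to_text': transformers.AutoModelForImageTextToText,
}

def resolve_auto_model_cls(task: str):
  """Hypothetical helper: same behavior as the branch in the hunk above."""
  if task not in _AUTO_MODEL_BY_TASK:
    raise ValueError(f'Unsupported task: {task}')
  return _AUTO_MODEL_BY_TASK[task]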
@@ -326,7 +334,7 @@ def export_embedder_model(
       sample_kwargs=sample_inputs,
   )
   lrt_model = converter.convert(strict_export=False)
-  model_path = os.path.join(work_dir, 'model.tflite')
+  model_path = os.path.join(work_dir, 'embedder.tflite')
   lrt_model.export(model_path)
   quantization_recipe_list = (
       quantization_recipe.split(',') if quantization_recipe else [None]
@@ -359,7 +367,10 @@ def export_auxiliary_model(
       sample_kwargs=sample_input,
   )
   # Attention Mask
-  attention_mask_module = split_cache_module.SplitAttentionMaskBuilder(model)
+  attention_mask_module = split_cache_module.SplitAttentionMaskBuilder(
+      export_config.cache_length,
+      # TODO(weiyiw): Add sliding window sizes.
+  )
   sample_inputs = attention_mask_module.get_sample_inputs(
       text_model_config, export_config
   )
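
SplitAttentionMaskBuilder is now constructed from export_config.cache_length instead of the loaded model, so the mask graph depends only on export-time shape information. For intuition, a causal mask over a fixed-size cache needs only the cache length and the query positions; the sketch below is a hedged, standalone illustration of that idea, not the builder's actual implementation (names, shapes, and the mask value are assumptions):

import torch

def causal_mask_for_static_cache(input_pos, cache_length, mask_value=-120.0):
  # input_pos: [batch, seq] int32 positions of the current query tokens.
  # Returns an additive mask of shape [batch, 1, seq, cache_length].
  cache_pos = torch.arange(cache_length, dtype=torch.int32)
  # A cache slot is masked out when it lies after the query token's position.
  future = cache_pos[None, None, :] > input_pos[:, :, None]
  mask = torch.where(future, torch.tensor(mask_value), torch.tensor(0.0))
  return mask.unsqueeze(1)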
@@ -370,7 +381,7 @@ def export_auxiliary_model(
       sample_kwargs=sample_input,
   )
   # Cache Update
-  cache_update_module = split_cache_module.CacheUpdate(model)
+  cache_update_module = split_cache_module.CacheUpdate()
   sample_inputs = cache_update_module.get_sample_inputs(
       text_model_config, export_config
   )
@@ -31,6 +31,7 @@ class ExportableModuleConfig:
 
   # Export configs
   externalize_embedder: bool = False
+  single_token_embedder: bool = False
   externalize_rope: bool = False
 
   split_cache: bool = False
@@ -94,3 +94,33 @@ class LiteRTExportableModuleForEmbedder(torch.nn.Module):
     token_ids = torch.maximum(token_ids, torch.tensor(0, dtype=torch.int32))
     output = self.model(token_ids)
     return {"embeddings": output}
+
+  @classmethod
+  def get_sample_inputs(
+      cls,
+      model_config,
+      export_config: base_exportable_module.ExportableModuleConfig,
+  ):
+    """Gets sample inputs."""
+    batch_size = export_config.batch_size
+    prefill_length = export_config.prefill_lengths[0]
+    prefill_length_dim = export_config.prefill_length_dim
+    del model_config  # Unused.
+    tokens = {"token_ids": torch.ones((batch_size, 1), dtype=torch.int32)}
+    tokens_dynamic_shape = {"token_ids": {1: 1}} if prefill_length_dim else {}
+    if export_config.single_token_embedder:
+      return {"embedder": (tokens, tokens_dynamic_shape)}
+    else:
+      ret = {}
+      ret["decode_embedder"] = (tokens, tokens_dynamic_shape)
+
+      tokens = {
+          "token_ids": torch.ones(
+              (batch_size, prefill_length), dtype=torch.int32
+          )
+      }
+      tokens_dynamic_shape = (
+          {"token_ids": {1: prefill_length_dim}} if prefill_length_dim else {}
+      )
+      ret[f"prefill_embedder_{prefill_length}"] = (tokens, tokens_dynamic_shape)
+      return ret
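
The get_sample_inputs method added above returns one entry per exported signature: a lone "embedder" signature when single_token_embedder is set, otherwise a single-token "decode_embedder" plus a "prefill_embedder_<N>" for the first configured prefill length. For a hypothetical config with batch_size=1, prefill_lengths=[128], and no dynamic prefill dimension, the returned value would look like:

{
    "decode_embedder": (
        {"token_ids": torch.ones((1, 1), dtype=torch.int32)},
        {},  # no dynamic dims: prefill_length_dim is unset
    ),
    "prefill_embedder_128": (
        {"token_ids": torch.ones((1, 128), dtype=torch.int32)},
        {},
    ),
}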
@@ -118,7 +118,8 @@ def build_llm_metadata(
   if isinstance(gen_config.eos_token_id, int):
     stop_tokens.add(gen_config.eos_token_id)
   elif isinstance(gen_config.eos_token_id, list):
-    stop_tokens.update(gen_config.eos_token_id)
+    for token_id in gen_config.eos_token_id:
+      stop_tokens.add(token_id)
   elif hasattr(tokenizer, 'eos_token') and tokenizer.eos_token:
     stop_tokens.add(tokenizer.eos_token)
   for stop_token in stop_tokens:
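
build_llm_metadata normalizes eos_token_id, which HuggingFace generation configs store either as a single int or as a list of ints, into a flat stop-token set; the set.update call is unrolled into an explicit loop, which is behavior-preserving for a flat list. A self-contained illustration (the token ids are made up):

stop_tokens = set()
for eos_token_id in (1, [1, 106]):  # the int form and the list form
  if isinstance(eos_token_id, int):
    stop_tokens.add(eos_token_id)
  elif isinstance(eos_token_id, list):
    for token_id in eos_token_id:
      stop_tokens.add(token_id)
print(sorted(stop_tokens))  # [1, 106]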
@@ -60,3 +60,32 @@ original_use_kernel_forward_from_hub = (
 transformers.integrations.use_kernel_forward_from_hub = (
     _use_kernel_forward_from_hub
 )
+
+
+# TODO(weiyiw): Find a better way to patch Gemma3RMSNorm.
+class Gemma3RMSNorm(torch.nn.Module):
+  """RMSNorm Layer."""
+
+  def __init__(self, dim: int, eps: float = 1e-6):
+    """RMSNorm Layer."""
+    super().__init__()
+    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.variance_epsilon = eps
+    self.hidden_size = dim
+
+  def forward(self, hidden_states):
+    return normalization.rms_norm_with_hlfb(
+        hidden_states,
+        self.weight + 1.0,
+        self.variance_epsilon,
+        torch.ones((self.hidden_size,), dtype=torch.float32),
+    )
+
+  def extra_repr(self):
+    return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+from transformers.models.gemma3 import modeling_gemma3
+
+original_gemma3_rms_norm = modeling_gemma3.Gemma3RMSNorm
+modeling_gemma3.Gemma3RMSNorm = Gemma3RMSNorm
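
The patch keeps Gemma's zero-centered weight convention, where the effective scale is 1 + weight, but routes the forward pass through normalization.rms_norm_with_hlfb so the converter sees one fused high-level op instead of the decomposed graph. Numerically, the forward should match this plain-PyTorch reference (a sketch of the math only, without the HLFB wrapper or its extra offset argument):

import torch

def gemma3_rms_norm_reference(x, weight, eps=1e-6):
  # RMS-normalize over the last dim, then scale by (1 + weight).
  variance = x.float().pow(2).mean(dim=-1, keepdim=True)
  normed = x.float() * torch.rsqrt(variance + eps)
  return (normed * (1.0 + weight.float())).type_as(x)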
@@ -31,6 +31,7 @@ def export(
     quantization_recipe: str = 'dynamic_wi8_afp32',
     enable_dynamic_shape: bool = False,
     externalize_embedder: bool = False,
+    single_token_embedder: bool = False,
     key_ts_idx: int = 2,
     value_ts_idx: int = 3,
     split_cache: bool = False,
@@ -38,6 +39,7 @@ def export(
     # target_accelerator: str | None = None,
     trust_remote_code: bool = False,
     use_jinja_template: bool = False,
+    task: str = 'text_generation',
 ):
   """Exports HuggingFace Transformers model to tflite."""
   # TODO(weiyiw): Use tmp dir for work_dir.
@@ -47,6 +49,7 @@ def export(
       model,
       trust_remote_code=trust_remote_code,
       auto_model_override=auto_model_override,
+      task=task,
   )
   del config  # Unused.
   if split_cache and not externalize_embedder:
@@ -62,6 +65,7 @@ def export(
       if enable_dynamic_shape
       else None,
       externalize_embedder=externalize_embedder,
+      single_token_embedder=single_token_embedder,
       k_ts_idx=key_ts_idx,
       v_ts_idx=value_ts_idx,
       split_cache=split_cache,
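
Taken together, the export() changes thread both new knobs end to end: task is forwarded into load_model, and single_token_embedder into the exportable-module config. A hypothetical invocation for a vision-language checkpoint (the checkpoint id and the positional model argument are illustrative; only the keyword parameters visible in the hunks above are confirmed):

export(
    'google/gemma-3-4b-it',  # illustrative checkpoint
    task='image_text_to_text',
    externalize_embedder=True,
    single_token_embedder=True,
    quantization_recipe='dynamic_wi8_afp32',
)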
litert_torch/version.py CHANGED
@@ -15,4 +15,4 @@
 
 # The next version of litert-torch.
 # The minor version code should be bumped after every release.
-__version__ = "0.9.0.dev20260204"
+__version__ = "0.9.0.dev20260206"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: litert-torch-nightly
-Version: 0.9.0.dev20260204
+Version: 0.9.0.dev20260206
 Summary: Support PyTorch model conversion with LiteRT.
 Home-page: https://github.com/google-ai-edge/litert-torch
 Keywords: On-Device ML,AI,Google,TFLite,LiteRT,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ litert_torch/_config.py,sha256=zDnki83sBsQzDAea6bvzwccylWHnPUzbEyGGRh6B14w,2526
 litert_torch/cli.py,sha256=TiAUbgbWm3ecTUJtJ1_hjKJuC1LrG-Qwnm8_zws-sVY,984
 litert_torch/conftest.py,sha256=gYmFrsR4c_fjIidbyrDnek26yS0crDP6-UoyMvy-WFg,757
 litert_torch/model.py,sha256=KXFTyyfPM6AnP0JoSwsTqQR3lUQbMkTGSr3dUsfQ5Jk,5635
-litert_torch/version.py,sha256=Cx4KT-tDB1CIsBgn-EX34IaDb7a_kd9beo1zv2eOhYQ,804
+litert_torch/version.py,sha256=20bJvaJMGX0olC644fZPyW86ti_tJxn-WOPVm9S1tZ4,804
 litert_torch/_convert/__init__.py,sha256=qdLdbj5NjhNG-QgY5O_8TzOr2XaDoWvmdY9JNPStQmw,670
 litert_torch/_convert/conversion.py,sha256=NuQEphyYp3W19IKvyTWo9pe7zt1-XmWM4zU9PDkUm54,6108
 litert_torch/_convert/conversion_utils.py,sha256=MWpB-3eN-rvQzTtXsPL30cDIK431SQuwvw3ia2K2ONM,2158
@@ -179,7 +179,7 @@ litert_torch/generative/examples/tiny_llama/verify.py,sha256=6geA8OUOSj8_sTRyoo0
 litert_torch/generative/examples/tiny_llama/verify_util.py,sha256=FKMC6Olex6bJbB8HXvC1KwxPbKgRBfT1CjoWcmyaPD8,2989
 litert_torch/generative/export_hf/__init__.py,sha256=5xWIp2ziIwapcZcjSKfeaFgBnIooa8ckhTQ7mazZC3c,670
 litert_torch/generative/export_hf/__main__.py,sha256=8VuBDkZ2sL-q2XdQ45qwzeHQk39-MM_6TdkxOU_23xE,782
-litert_torch/generative/export_hf/export.py,sha256=HC_nwBg3WMGL_qMfOn7OB2SATKed6UQe2KqqE-6CHIA,3656
+litert_torch/generative/export_hf/export.py,sha256=koqs0znGe9QXlEoRF7TuvDtrjbiXa79qgBhGK6MENwk,3800
 litert_torch/generative/export_hf/export_main.py,sha256=bQidNXz0MEP_gil86LSfnpCW0pUiqZq2-F9ZOrSb3Yk,1183
 litert_torch/generative/export_hf/core/__init__.py,sha256=5xWIp2ziIwapcZcjSKfeaFgBnIooa8ckhTQ7mazZC3c,670
 litert_torch/generative/export_hf/core/attention.py,sha256=bXuTHNeVtKwWf6YXgb5I2j08vvgb9M7r1RYGvdjl9QI,4798
@@ -187,14 +187,14 @@ litert_torch/generative/export_hf/core/attention_test.py,sha256=KBSyYjHoTKYi6Se6
 litert_torch/generative/export_hf/core/cache.py,sha256=UnuTBpJvplEyig1myrhA1d0QJ05pNJgWbm-GrsUu5Uk,11763
 litert_torch/generative/export_hf/core/cache_base.py,sha256=6s-6L6iSa-qn0PLdAAhpHdOU9qwqEE-JVdlIsYyCPt4,2180
 litert_torch/generative/export_hf/core/cache_test.py,sha256=y-v-oOGtRNPGWRfIfW3FcpDxvJbzrBU6Pb2o66FkUzU,6203
-litert_torch/generative/export_hf/core/export_lib.py,sha256=W1jG6L9oqu3hYnXaN0lQLtEqc5ZPkTyDVOzGOsLLkAU,14142
+litert_torch/generative/export_hf/core/export_lib.py,sha256=qDOyLxCMeQtRbJQlLcI8Gq3PmL9yMJ2gUvPhoEcVty8,14516
 litert_torch/generative/export_hf/core/exportable_module.py,sha256=niCS0na0VvFLiwebnL4JeXZh2hT8FCQmp-vnyTBh7pA,8257
-litert_torch/generative/export_hf/core/exportable_module_config.py,sha256=cpqtagzOglvbr91NFC0K_QX-7mr5Q7gnhQ8Srqral9Y,1284
-litert_torch/generative/export_hf/core/litert_lm_builder.py,sha256=f8Q2ifVyt65V-kRL0X9FRpQNKIer0R_Yx2lECZTMGPU,7965
-litert_torch/generative/export_hf/core/patches.py,sha256=i1fzs0anIFbBH-Q_PwCtp9VKXy64olJKwnGpnJUjkEo,1815
+litert_torch/generative/export_hf/core/exportable_module_config.py,sha256=oJOWBBKWYpLq5A5qXEAIZbLwvCpY22nstHx6L88CXqU,1322
+litert_torch/generative/export_hf/core/litert_lm_builder.py,sha256=ai-5Njn8fGKco_5jiRnmACBIKu1EL2b5SY5ArmsmttM,7998
+litert_torch/generative/export_hf/core/patches.py,sha256=h4TCTNPT0N9xcMFfJ54XnpCHt1iKwS8mU-GhAxdsUrc,2636
 litert_torch/generative/export_hf/core/utils.py,sha256=5Wgs9aAOKd2i8wmQF_IierLUuFG23v1T6zZPr-azQ7A,4018
 litert_torch/generative/export_hf/core/external_emb/__init__.py,sha256=5xWIp2ziIwapcZcjSKfeaFgBnIooa8ckhTQ7mazZC3c,670
-litert_torch/generative/export_hf/core/external_emb/exportable_module.py,sha256=mWn75lLms3BAeCTEvbkGZ2n4fxtwsqGA8PP4S8-JBdY,3058
+litert_torch/generative/export_hf/core/external_emb/exportable_module.py,sha256=1ke2mugD--1bIqeJhAJ4Ly7o6NRW8RZL79UzCqRHLNY,4113
 litert_torch/generative/export_hf/core/external_rope/__init__.py,sha256=5xWIp2ziIwapcZcjSKfeaFgBnIooa8ckhTQ7mazZC3c,670
 litert_torch/generative/export_hf/core/external_rope/exportable_module.py,sha256=czTf835b9Nw4XcDo6cd9chsmBdbIsdqMtnEkwuwMgX0,2478
 litert_torch/generative/export_hf/core/external_rope/preprocess_model.py,sha256=NL3zROb7EZNAvZfutIhLk4KqXid_HklQMUjHZqZYOH4,1735
@@ -319,9 +319,9 @@ litert_torch/testing/__init__.py,sha256=AfYP1HwTYSQmupveonEHCDV5dEyshzUgbwUrCUhb
 litert_torch/testing/export.py,sha256=3dR6oxnrdtX0MfqAfMv233cf3sHA4e0F2TBQotoo8xc,3292
 litert_torch/testing/model_coverage/__init__.py,sha256=uPXeAhWiD1O0aMDLCX7FTOSNQiea8yOtoIYPCuHEAG4,763
 litert_torch/testing/model_coverage/model_coverage.py,sha256=EPCI7PbNPb7GV28lo3qQvFdzJwJ_ZDrbCGdpeiBZhVo,4715
-litert_torch_nightly-0.9.0.dev20260204.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-litert_torch_nightly-0.9.0.dev20260204.dist-info/METADATA,sha256=-DAJh0KO6GPV9RjXiU3oOK4KKJTj1szkhHO-F6XI99o,2463
-litert_torch_nightly-0.9.0.dev20260204.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
-litert_torch_nightly-0.9.0.dev20260204.dist-info/entry_points.txt,sha256=roYAi9hp0uYrMudMR59hGNF2pz0TSAtqNl4vQLJzxnE,55
-litert_torch_nightly-0.9.0.dev20260204.dist-info/top_level.txt,sha256=mGrsl2SYcjQSLBJX4ZXrHnFqHZe6QLRR7uk0tLfzwfM,13
-litert_torch_nightly-0.9.0.dev20260204.dist-info/RECORD,,
+litert_torch_nightly-0.9.0.dev20260206.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+litert_torch_nightly-0.9.0.dev20260206.dist-info/METADATA,sha256=mRNKi6UzLyaT9u1_6oqV95e9vs3f_77JAhDxkBtU3Ao,2463
+litert_torch_nightly-0.9.0.dev20260206.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
+litert_torch_nightly-0.9.0.dev20260206.dist-info/entry_points.txt,sha256=roYAi9hp0uYrMudMR59hGNF2pz0TSAtqNl4vQLJzxnE,55
+litert_torch_nightly-0.9.0.dev20260206.dist-info/top_level.txt,sha256=mGrsl2SYcjQSLBJX4ZXrHnFqHZe6QLRR7uk0tLfzwfM,13
+litert_torch_nightly-0.9.0.dev20260206.dist-info/RECORD,,