ai-edge-torch-nightly 0.3.0.dev20250105__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl

Files changed (32)
  1. ai_edge_torch/_config.py +26 -9
  2. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
  3. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
  4. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  6. ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  8. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  9. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  10. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  11. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  12. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  13. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  16. ai_edge_torch/generative/layers/attention.py +70 -12
  17. ai_edge_torch/generative/layers/lora.py +557 -0
  18. ai_edge_torch/generative/layers/normalization.py +2 -50
  19. ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
  20. ai_edge_torch/generative/test/test_lora.py +147 -0
  21. ai_edge_torch/generative/utilities/converter.py +100 -47
  22. ai_edge_torch/generative/utilities/model_builder.py +21 -16
  23. ai_edge_torch/generative/utilities/verifier.py +4 -4
  24. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  25. ai_edge_torch/odml_torch/export.py +6 -2
  26. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  27. ai_edge_torch/version.py +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
  29. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +32 -30
  30. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20250105.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/lora.py
@@ -0,0 +1,557 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """LoRA weights for generative models.
+
+ The current implementation supports attention-only LoRA. Additionally, we expect
+ LoRA weights for all projections within the attention module (i.e., Q, K, V, O).
+ """
+
+ import dataclasses
+ from typing import Any, Callable, List, Optional, Tuple
+
+ from ai_edge_torch.generative.layers import model_config
+ import flatbuffers
+ import numpy as np
+ import safetensors
+ import torch
+ import torch.utils._pytree as pytree
+
+ from tensorflow.lite.python import schema_py_generated as schema_fb  # pylint: disable=g-direct-tensorflow-import
+
+ _TFLITE_SCHEMA_VERSION = 3
+ _TFLITE_FILE_IDENTIFIER = b"TFL3"
+
+
+ @dataclasses.dataclass
+ class LoRAWeight:
+   """LoRA weight per projection. The weights are pre-transposed."""
+
+   a_prime: torch.Tensor
+   b_prime: torch.Tensor
+
+   def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
+     if not isinstance(other, LoRAWeight):
+       return False
+     if self.a_prime.shape != other.a_prime.shape:
+       return False
+     if self.b_prime.shape != other.b_prime.shape:
+       return False
+     return torch.allclose(
+         self.a_prime, other.a_prime, rtol=rtol, atol=atol
+     ) and torch.allclose(self.b_prime, other.b_prime, rtol=rtol, atol=atol)
+
+
+ @dataclasses.dataclass
+ class AttentionLoRA:
+   """LoRA weights for attention module."""
+
+   query: LoRAWeight
+   key: LoRAWeight
+   value: LoRAWeight
+   output: LoRAWeight
+
+   def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
+     if not isinstance(other, AttentionLoRA):
+       return False
+     return (
+         self.query.__eq__(other.query, rtol=rtol, atol=atol)
+         and self.key.__eq__(other.key, rtol=rtol, atol=atol)
+         and self.value.__eq__(other.value, rtol=rtol, atol=atol)
+         and self.output.__eq__(other.output, rtol=rtol, atol=atol)
+     )
+
+
+ @dataclasses.dataclass
+ class LoRAEntry:
+   """LoRA weights for a single layer."""
+
+   attention: AttentionLoRA
+
+   def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
+     if not isinstance(other, LoRAEntry):
+       return False
+     return self.attention.__eq__(other.attention, rtol=rtol, atol=atol)
+
+
+ @dataclasses.dataclass
+ class LoRATensorNames:
+   """Tensor names for LoRA weights."""
+
+   attn_query_w_a: str
+   attn_query_w_b: str
+
+   attn_key_w_a: str
+   attn_key_w_b: str
+
+   attn_value_w_a: str
+   attn_value_w_b: str
+
+   attn_output_w_a: str
+   attn_output_w_b: str
+
+
+ @dataclasses.dataclass
+ class LoRA:
+   """LoRA weights for all modules."""
+
+   adapters: Tuple[LoRAEntry, ...]
+
+   def __eq__(self, other: Any, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
+     if not isinstance(other, LoRA):
+       return False
+     if len(self.adapters) != len(other.adapters):
+       return False
+     return all(
+         adapter.__eq__(other_adapter, rtol=rtol, atol=atol)
+         for adapter, other_adapter in zip(self.adapters, other.adapters)
+     )
+
+   def get_rank(self) -> int:
+     """Returns the rank of the LoRA weights."""
+     return self.adapters[0].attention.query.a_prime.shape[1]
+
+   @classmethod
+   def from_safetensors(
+       cls,
+       path: str,
+       scale: float,
+       config: model_config.ModelConfig,
+       lora_tensor_names: LoRATensorNames,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights from a Hugging Face model.
+
+     Args:
+       path: Path to the model.
+       scale: Scale factor for the LoRA weights (applied only to one of the
+         projections). The scaling factor depends on the training configuration.
+         The common values are either `lora_alpha / rank` or `lora_alpha /
+         sqrt(rank)`.
+       config: Model configuration.
+       lora_tensor_names: Tensor names for the LoRA weights.
+       dtype: Data type of the LoRA weights. Currently only float32 is supported.
+
+     Returns:
+       LoRA weights for all modules.
+     """
+     with safetensors.safe_open(path, framework="pt", device="cpu") as f:
+       adapters = []
+       for i in range(config.num_layers):
+         attention_lora = AttentionLoRA(
+             query=LoRAWeight(
+                 a_prime=f.get_tensor(lora_tensor_names.attn_query_w_a.format(i))
+                 .to(dtype)
+                 .T
+                 * scale,
+                 b_prime=f.get_tensor(lora_tensor_names.attn_query_w_b.format(i))
+                 .to(dtype)
+                 .T,
+             ),
+             key=LoRAWeight(
+                 a_prime=f.get_tensor(lora_tensor_names.attn_key_w_a.format(i))
+                 .to(dtype)
+                 .T
+                 * scale,
+                 b_prime=f.get_tensor(lora_tensor_names.attn_key_w_b.format(i))
+                 .to(dtype)
+                 .T,
+             ),
+             value=LoRAWeight(
+                 a_prime=f.get_tensor(lora_tensor_names.attn_value_w_a.format(i))
+                 .to(dtype)
+                 .T
+                 * scale,
+                 b_prime=f.get_tensor(lora_tensor_names.attn_value_w_b.format(i))
+                 .to(dtype)
+                 .T,
+             ),
+             output=LoRAWeight(
+                 a_prime=f.get_tensor(
+                     lora_tensor_names.attn_output_w_a.format(i)
+                 )
+                 .to(dtype)
+                 .T
+                 * scale,
+                 b_prime=f.get_tensor(
+                     lora_tensor_names.attn_output_w_b.format(i)
+                 )
+                 .to(dtype)
+                 .T,
+             ),
+         )
+         adapters.append(LoRAEntry(attention=attention_lora))
+     return cls(adapters=adapters)
+
+   @classmethod
+   def from_flatbuffers(
+       cls,
+       flatbuffer_model: bytearray,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights from FlatBuffers.
+
+     Args:
+       flatbuffer_model: FlatBuffers model.
+       dtype: Data type of the LoRA weights.
+
+     Returns:
+       LoRA weights for all modules.
+     """
+     model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0)
+     model = schema_fb.ModelT.InitFromObj(model)
+
+     flat_names = []
+     tensors = []
+     for tensor in model.subgraphs[0].tensors:
+       name = tensor.name.decode("utf-8")
+       assert name.startswith("lora_")
+       flat_names.append(name.split("lora_")[-1])
+       buffer_bytes = model.buffers[tensor.buffer].data.data.tobytes()
+       arr = np.frombuffer(buffer_bytes, dtype=np.float32).reshape(tensor.shape)
+       torch_tensor = torch.from_numpy(arr).to(dtype)
+       tensors.append(torch_tensor)
+
+     return _unflatten_lora(tensors, (flat_names, []))
+
+   @classmethod
+   def zeros(
+       cls,
+       rank: int,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights with zeros.
+
+     Args:
+       rank: Rank of the LoRA weights.
+       config: Model configuration.
+       dtype: Data type of the LoRA weights. Currently only float32 is supported.
+
+     Returns:
+       LoRA weights with zeros.
+     """
+     return cls._from_tensor_generator(
+         tensor_generator=lambda shape, dtype: torch.zeros(shape, dtype=dtype),
+         rank=rank,
+         config=config,
+         dtype=dtype,
+     )
+
+   @classmethod
+   def random(
+       cls,
+       rank: int,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights with random values.
+
+     Args:
+       rank: Rank of the LoRA weights.
+       config: Model configuration.
+       dtype: Data type of the LoRA weights.
+
+     Returns:
+       LoRA weights with random values.
+     """
+     return cls._from_tensor_generator(
+         tensor_generator=lambda shape, dtype: torch.randint(
+             low=0, high=128, size=shape, dtype=dtype
+         ),
+         rank=rank,
+         config=config,
+         dtype=dtype,
+     )
+
+   @classmethod
+   def _from_tensor_generator(
+       cls,
+       tensor_generator: Callable[[Tuple[int, ...], torch.dtype], torch.Tensor],
+       rank: int,
+       config: model_config.ModelConfig,
+       dtype: torch.dtype = torch.float32,
+   ) -> "LoRA":
+     """Creates LoRA weights from a tensor generator."""
+     adapters = []
+
+     for i in range(config.num_layers):
+       block_config = config.block_config(i)
+       q_per_kv = (
+           block_config.attn_config.num_heads
+           // block_config.attn_config.num_query_groups
+       )
+       q_out_dim = q_per_kv * block_config.attn_config.head_dim
+       k_out_dim = v_out_dim = block_config.attn_config.head_dim
+       attention_lora = AttentionLoRA(
+           query=LoRAWeight(
+               a_prime=tensor_generator((config.embedding_dim, rank), dtype),
+               b_prime=tensor_generator((rank, q_out_dim), dtype),
+           ),
+           key=LoRAWeight(
+               a_prime=tensor_generator((config.embedding_dim, rank), dtype),
+               b_prime=tensor_generator((rank, k_out_dim), dtype),
+           ),
+           value=LoRAWeight(
+               a_prime=tensor_generator((config.embedding_dim, rank), dtype),
+               b_prime=tensor_generator((rank, v_out_dim), dtype),
+           ),
+           output=LoRAWeight(
+               a_prime=tensor_generator(
+                   (
+                       block_config.attn_config.num_heads
+                       * block_config.attn_config.head_dim,
+                       rank,
+                   ),
+                   dtype,
+               ),
+               b_prime=tensor_generator((rank, config.embedding_dim), dtype),
+           ),
+       )
+       adapters.append(LoRAEntry(attention=attention_lora))
+     return cls(adapters=adapters)
+
+   def to_tflite(self) -> bytearray:
+     """Converts LoRA to FlatBuffers."""
+     return _lora_to_flatbuffers(self)
+
+
+ def apply_lora(
+     x: torch.Tensor,
+     lora_weight: LoRAWeight,
+     shape: Optional[Tuple[int, ...]] = None,
+ ) -> torch.Tensor:
+   """Applies LoRA weights to a tensor.
+
+   Args:
+     x: Input tensor.
+     lora_weight: LoRA weight.
+     shape: Output shape. If None, the output shape is the same as the input
+       shape.
+
+   Returns:
+     Output tensor.
+   """
+   output = torch.matmul(
+       torch.matmul(x, lora_weight.a_prime), lora_weight.b_prime
+   )
+   if shape is not None:
+     output = output.reshape(shape)
+   return output
+
+
+ def _flatten_attention_lora(
+     lora: AttentionLoRA, block_index: int
+ ) -> Tuple[List[torch.Tensor], List[str]]:
+   """Flattens LoRA weights for attention module."""
+   flattened = []
+   flat_names = []
+   flattened.append(lora.query.a_prime)
+   flat_names.append(f"atten_q_a_prime_weight_{block_index}")
+   flattened.append(lora.query.b_prime)
+   flat_names.append(f"atten_q_b_prime_weight_{block_index}")
+   flattened.append(lora.key.a_prime)
+   flat_names.append(f"atten_k_a_prime_weight_{block_index}")
+   flattened.append(lora.key.b_prime)
+   flat_names.append(f"atten_k_b_prime_weight_{block_index}")
+   flattened.append(lora.value.a_prime)
+   flat_names.append(f"atten_v_a_prime_weight_{block_index}")
+   flattened.append(lora.value.b_prime)
+   flat_names.append(f"atten_v_b_prime_weight_{block_index}")
+   flattened.append(lora.output.a_prime)
+   flat_names.append(f"atten_o_a_prime_weight_{block_index}")
+   flattened.append(lora.output.b_prime)
+   flat_names.append(f"atten_o_b_prime_weight_{block_index}")
+   return flattened, flat_names
+
+
+ def _flatten_lora(lora: LoRA) -> Tuple[List[torch.Tensor], List[Any]]:
+   """Flattens LoRA weights."""
+   flattened = []
+   flat_names = []
+   none_names = []
+   for i, entry in enumerate(lora.adapters):
+     attn_flattened, attn_flat_names = _flatten_attention_lora(
+         lora=entry.attention, block_index=i
+     )
+     flattened.extend(attn_flattened)
+     flat_names.extend(attn_flat_names)
+   return flattened, [flat_names, none_names]
+
+
+ def _flatten_lora_with_keys(lora: LoRA) -> Tuple[List[Any], List[Any]]:
+   """Flattens LoRA weights with keys."""
+   flattened, (flat_names, _) = _flatten_lora(lora)
+   return [
+       (pytree.MappingKey(k), v) for k, v in zip(flat_names, flattened)
+   ], flat_names
+
+
+ def _unflatten_lora(
+     values: List[torch.Tensor], context: Tuple[List[str], List[Any]]
+ ) -> LoRA:
+   """Unflattens LoRA object."""
+   flat_names, _ = context
+   names_weights = list(zip(flat_names, values))
+   adapters = {}
+   while names_weights:
+     name, weight = names_weights.pop(0)
+     block_idx = int(name.split("_")[-1])
+     if block_idx not in adapters:
+       adapters[block_idx] = LoRAEntry(
+           attention=AttentionLoRA(
+               query=LoRAWeight(
+                   a_prime=None,
+                   b_prime=None,
+               ),
+               key=LoRAWeight(
+                   a_prime=None,
+                   b_prime=None,
+               ),
+               value=LoRAWeight(
+                   a_prime=None,
+                   b_prime=None,
+               ),
+               output=LoRAWeight(
+                   a_prime=None,
+                   b_prime=None,
+               ),
+           )
+       )
+
+     if name.startswith("atten_"):
+       if "q_a_prime" in name:
+         adapters[block_idx].attention.query.a_prime = weight
+       elif "q_b_prime" in name:
+         adapters[block_idx].attention.query.b_prime = weight
+       elif "k_a_prime" in name:
+         adapters[block_idx].attention.key.a_prime = weight
+       elif "k_b_prime" in name:
+         adapters[block_idx].attention.key.b_prime = weight
+       elif "v_a_prime" in name:
+         adapters[block_idx].attention.value.a_prime = weight
+       elif "v_b_prime" in name:
+         adapters[block_idx].attention.value.b_prime = weight
+       elif "o_a_prime" in name:
+         adapters[block_idx].attention.output.a_prime = weight
+       elif "o_b_prime" in name:
+         adapters[block_idx].attention.output.b_prime = weight
+       else:
+         raise ValueError(f"Unsupported name: {name}")
+     else:
+       raise ValueError(f"Unsupported name: {name}")
+
+   return LoRA(adapters=tuple(adapters[key] for key in sorted(adapters)))
+
+
+ pytree.register_pytree_node(
+     LoRA,
+     _flatten_lora,
+     _unflatten_lora,
+     flatten_with_keys_fn=_flatten_lora_with_keys,
+     serialized_type_name="",
+ )
+
+
+ def _add_buffer(builder: flatbuffers.Builder, data: np.ndarray | None) -> int:
+   """Adds a buffer to the FlatBuffers."""
+   if data is not None:
+     assert data.dtype == np.float32
+     schema_fb.BufferStartDataVector(builder, data.size * data.itemsize)
+     for value in reversed(data.flatten().tolist()):
+       builder.PrependFloat32(value)
+     data_offset = builder.EndVector()
+   else:
+     schema_fb.BufferStartDataVector(builder, 0)
+     data_offset = builder.EndVector()
+
+   schema_fb.BufferStart(builder)
+   schema_fb.BufferAddData(builder, data_offset)
+   buffer_offset = schema_fb.BufferEnd(builder)
+   return buffer_offset
+
+
+ def _add_tensor(
+     builder: flatbuffers.Builder,
+     name: str,
+     shape: Tuple[int, ...],
+     buffer_idx: int,
+ ) -> int:
+   """Adds a tensor to the FlatBuffers."""
+   name_offset = builder.CreateString(name)
+   schema_fb.TensorStartShapeVector(builder, len(shape))
+   for dim in reversed(shape):
+     builder.PrependInt32(dim)
+   shape_offset = builder.EndVector()
+   schema_fb.TensorStart(builder)
+   schema_fb.TensorAddName(builder, name_offset)
+   schema_fb.TensorAddShape(builder, shape_offset)
+   schema_fb.TensorAddType(builder, schema_fb.TensorType.FLOAT32)
+   schema_fb.TensorAddBuffer(builder, buffer_idx)
+   tensor_offset = schema_fb.TensorEnd(builder)
+   return tensor_offset
+
+
+ def _lora_to_flatbuffers(lora: LoRA) -> bytearray:
+   """Converts LoRA to FlatBuffers."""
+   tensors, (names, _) = _flatten_lora(lora)
+   # Need to manually add the "lora_" prefix to the names here. The export will
+   # add the prefix automatically.
+   names = [f"lora_{name}" for name in names]
+   builder = flatbuffers.Builder(4096)
+
+   # Convention to add an empty buffer in the beginning.
+   buffer_offsets = [_add_buffer(builder, None)]
+   for tensor in tensors:
+     buffer_offsets.append(
+         _add_buffer(builder, tensor.detach().type(torch.float32).numpy())
+     )
+
+   schema_fb.ModelStartBuffersVector(builder, len(buffer_offsets))
+   for buffer_offset in reversed(buffer_offsets):
+     builder.PrependUOffsetTRelative(buffer_offset)
+   buffers_offset = builder.EndVector()
+
+   tensor_offsets = []
+   for i, (name, tensor) in enumerate(zip(names, tensors)):
+     # Note that the zeroth buffer is empty and reserved for the convention.
+     tensor_offsets.append(_add_tensor(builder, name, tensor.shape, i + 1))
+
+   schema_fb.SubGraphStartTensorsVector(builder, len(tensor_offsets))
+   for tensor_offset in reversed(tensor_offsets):
+     builder.PrependUOffsetTRelative(tensor_offset)
+   tensors_offset = builder.EndVector()
+
+   string_offset = builder.CreateString("lora_params")
+   schema_fb.SubGraphStart(builder)
+   schema_fb.SubGraphAddName(builder, string_offset)
+   schema_fb.SubGraphAddTensors(builder, tensors_offset)
+   subgraph_offset = schema_fb.SubGraphEnd(builder)
+
+   schema_fb.ModelStartSubgraphsVector(builder, 1)
+   builder.PrependUOffsetTRelative(subgraph_offset)
+   subgraphs_offset = builder.EndVector()
+
+   string_offset = builder.CreateString("lora_params")
+   schema_fb.ModelStart(builder)
+   schema_fb.ModelAddVersion(builder, _TFLITE_SCHEMA_VERSION)
+   schema_fb.ModelAddDescription(builder, string_offset)
+   schema_fb.ModelAddBuffers(builder, buffers_offset)
+   schema_fb.ModelAddSubgraphs(builder, subgraphs_offset)
+   model_offset = schema_fb.ModelEnd(builder)
+   builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
+   flatbuffer_model = builder.Output()
+
+   return flatbuffer_model
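
For orientation, a minimal usage sketch of the new module follows; it is not part of the diff. It only exercises apply_lora() with hand-built tensors, and the toy sizes and values are made up. For a real model, LoRA.from_safetensors(...) builds per-layer Q/K/V/O adapters from a checkpoint given a ModelConfig and a LoRATensorNames mapping (with scale typically lora_alpha / rank or lora_alpha / sqrt(rank)), and LoRA.to_tflite() / LoRA.from_flatbuffers(...) round-trip the weights through the FlatBuffers format used by the converter utilities.

# Illustrative sketch, not part of this diff: uses only torch and the
# lora module added above; no real checkpoint or ModelConfig is needed.
import torch

from ai_edge_torch.generative.layers import lora as lora_utils

embedding_dim, rank, out_dim = 64, 4, 64  # arbitrary toy sizes

# LoRAWeight stores pre-transposed factors: a_prime is (embedding_dim, rank)
# and b_prime is (rank, out_dim), so apply_lora computes x @ a_prime @ b_prime.
lora_weight = lora_utils.LoRAWeight(
    a_prime=torch.randn(embedding_dim, rank),
    b_prime=torch.randn(rank, out_dim),
)

x = torch.randn(2, 10, embedding_dim)  # (batch, sequence, embedding)
delta = lora_utils.apply_lora(x, lora_weight)
print(delta.shape)  # torch.Size([2, 10, 64])
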
ai_edge_torch/generative/layers/normalization.py
@@ -80,6 +80,7 @@ class RMSNorm(torch.nn.Module):
      output = self._norm(x.float()).type_as(x)
      return output * w

+
  class GroupNorm(torch.nn.Module):

    def __init__(
@@ -115,16 +116,7 @@ class GroupNorm(torch.nn.Module):
      Returns:
        torch.Tensor: output tensor after applying GroupNorm.
      """
-     if self.enable_hlfb:
-       return group_norm_with_hlfb(
-           x,
-           self.weight,
-           self.bias,
-           self.group_num,
-           self.eps,
-       )
-     else:
-       return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
+     return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)


  class LayerNorm(torch.nn.Module):
@@ -169,46 +161,6 @@ class LayerNorm(torch.nn.Module):
      )


- def group_norm_with_hlfb(
-     x: torch.Tensor,
-     w: torch.Tensor,
-     b: torch.Tensor,
-     num_groups: int,
-     eps: float,
- ):
-   """Group Normalization with high-level function boundary enabled.
-
-   Args:
-     x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
-     w (torch.Tensor): The weight tensor for the normalization.
-     b (torch.Tensor): The bias tensor for the normalization.
-     num_groups (int): Number of groups to separate the channels into.
-     eps (float): A small float value to ensure numerical stability.
-
-   Returns:
-     The output tensor of Group Normalization.
-   """
-   x = torch.permute(x, (0, 2, 3, 1))
-
-   builder = StableHLOCompositeBuilder(
-       name="odml.group_norm",
-       attr={
-           "num_groups": num_groups,
-           "epsilon": eps,
-           "reduction_axes": [3],
-           "channel_axis": 3,
-       },
-   )
-   x, w, b = builder.mark_inputs(x, w, b)
-   x = torch.permute(x, (0, 3, 1, 2))
-   y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
-   y = torch.permute(y, (0, 2, 3, 1))
-   y = builder.mark_outputs(y)
-
-   y = torch.permute(y, (0, 3, 1, 2))
-   return y
-
-
  def rms_norm_with_hlfb(
      x: torch.Tensor,
      w: torch.Tensor,