ai-edge-torch-nightly 0.2.0.dev20240603__py3-none-any.whl → 0.2.0.dev20240605__py3-none-any.whl

This diff shows the contents of two publicly available package versions that were released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. (The "Click here for more details" link from the original page was not preserved in this copy.)

@@ -99,6 +99,36 @@ def _aten_hardswish(gm: GraphModule, node: Node):
99
99
  node.target = hardswish
100
100
 
101
101
 
102
+ @_register_composite_builder(torch.ops.aten.gelu.default)
103
+ def _aten_gelu(gm: GraphModule, node: Node):
104
+ op = node.target
105
+ args_mapper = TorchOpArgumentsMapper(op)
106
+
107
+ def gelu(*args, **kwargs):
108
+ nonlocal op, args_mapper
109
+
110
+ full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
111
+
112
+ # TFLite supports exact and tanh approximate.
113
+ if full_kwargs["approximate"] != "none" and full_kwargs["approximate"] != "tanh":
114
+ return op(*args, **kwargs)
115
+
116
+ builder = StableHLOCompositeBuilder(
117
+ "aten.gelu.default",
118
+ attr=_tree_map_to_composite_attr_values(
119
+ {
120
+ "approximate": full_kwargs["approximate"],
121
+ }
122
+ ),
123
+ )
124
+ full_kwargs["self"] = builder.mark_inputs(full_kwargs["self"])
125
+ output = op(full_kwargs["self"])
126
+ output = builder.mark_outputs(output)
127
+ return output
128
+
129
+ node.target = gelu
130
+
131
+
102
132
  @_register_composite_builder(torch.ops.aten.avg_pool2d.default)
103
133
  def _aten_avg_pool2d(gm: GraphModule, node: Node):
104
134
  op = node.target
@@ -212,7 +212,6 @@ class TestConvert(unittest.TestCase):
212
212
  self.assertTrue(flags["key1"], "new_value1")
213
213
  self.assertTrue(flags["key2"]["subkey2"], "new_subvalue2")
214
214
 
215
- @unittest.skip("https://b.corp.google.com/issues/331463544")
216
215
  def test_convert_add_backdoor_flags(self):
217
216
  """Tests conversion of an add module setting a tflite converter flag."""
218
217
 
@@ -228,13 +227,13 @@ class TestConvert(unittest.TestCase):
228
227
  torch_module = Add().eval()
229
228
 
230
229
  with tempfile.TemporaryDirectory() as tmp_dir_path:
231
- mlir_dump_path = os.path.join(
230
+ ir_dump_path = os.path.join(
232
231
  tmp_dir_path, "test_convert_add_backdoor_flags_mlir_dump"
233
232
  )
234
233
  ai_edge_torch.convert(
235
- torch_module, args, _ai_edge_converter_flags={"mlir_dump_dir": mlir_dump_path}
234
+ torch_module, args, _ai_edge_converter_flags={"ir_dump_dir": ir_dump_path}
236
235
  )
237
- self.assertTrue(os.path.isdir(mlir_dump_path))
236
+ self.assertTrue(os.path.isdir(ir_dump_path))
238
237
 
239
238
  def test_convert_model_with_dynamic_batch(self):
240
239
  """
@@ -169,6 +169,24 @@ class TestConvertComposites(unittest.TestCase):
169
169
  model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
170
170
  )
171
171
 
172
+ def test_convert_gelu(self):
173
+ """Tests conversion of a GELU module."""
174
+
175
+ args = (torch.randn((5, 10)),)
176
+ torch_module = torch.nn.GELU().eval()
177
+ edge_model = ai_edge_torch.convert(torch_module, args)
178
+
179
+ self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
180
+
181
+ def test_convert_gelu_approximate(self):
182
+ """Tests conversion of an Approximate GELU module."""
183
+
184
+ args = (torch.randn((5, 10)),)
185
+ torch_module = torch.nn.GELU('tanh').eval()
186
+ edge_model = ai_edge_torch.convert(torch_module, args)
187
+
188
+ self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
189
+
172
190
 
173
191
  if __name__ == '__main__':
174
192
  unittest.main()
@@ -15,65 +15,99 @@
15
15
 
16
16
  import torch
17
17
  from torch import nn
18
- from torch._prims_common import mask_tensor
19
- from torch._prims_common.wrappers import out_wrapper
20
18
 
21
- from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention # NOQA
19
+ from ai_edge_torch.generative.layers.attention import TransformerBlock
20
+ import ai_edge_torch.generative.layers.attention_utils as attention_utils
21
+ import ai_edge_torch.generative.layers.builder as builder
22
+ import ai_edge_torch.generative.layers.model_config as cfg
23
+ import ai_edge_torch.generative.utilities.loader as loading_utils
24
+
25
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
26
+ ff_up_proj="layers.{}.linear_1",
27
+ ff_down_proj="layers.{}.linear_2",
28
+ ff_gate_proj="layers.{}.linear_1",
29
+ attn_fused_qkv_proj="layers.{}.attention.in_proj",
30
+ attn_output_proj="layers.{}.attention.out_proj",
31
+ pre_attn_norm="layers.{}.layernorm_1",
32
+ pre_ff_norm="layers.{}.layernorm_2",
33
+ embedding="embedding.token_embedding",
34
+ embedding_position="embedding.position_value",
35
+ final_norm="layernorm",
36
+ lm_head=None,
37
+ )
22
38
 
23
39
 
24
- class CLIPEmbedding(nn.Module):
25
-
26
- def __init__(self, n_vocab: int, n_embd: int, n_token: int):
27
- super().__init__()
28
- self.token_embedding = nn.Embedding(n_vocab, n_embd)
29
- self.position_value = nn.Parameter(torch.zeros((n_token, n_embd)))
30
-
31
- def forward(self, tokens):
32
- x = self.token_embedding(tokens)
33
- x += self.position_value
34
- return x
35
-
36
-
37
- class CLIPLayer(nn.Module):
40
+ class CLIP(nn.Module):
41
+ """CLIP text encoder
42
+ For details, see https://arxiv.org/abs/2103.00020
43
+ """
38
44
 
39
- def __init__(self, n_head: int, n_embd: int):
45
+ def __init__(self, config: cfg.ModelConfig):
40
46
  super().__init__()
41
- self.layernorm_1 = nn.LayerNorm(n_embd)
42
- self.attention = SelfAttention(n_head, n_embd)
43
- self.layernorm_2 = nn.LayerNorm(n_embd)
44
- self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
45
- self.linear_2 = nn.Linear(4 * n_embd, n_embd)
46
-
47
- def forward(self, x):
48
- residue = x
49
- x = self.layernorm_1(x)
50
- x = self.attention(x, causal_mask=True)
51
- x += residue
47
+ self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
48
+ self.tok_embedding_position = nn.Parameter(
49
+ torch.zeros((config.max_seq_len, config.embedding_dim))
50
+ )
52
51
 
53
- residue = x
54
- x = self.layernorm_2(x)
55
- x = self.linear_1(x)
56
- x = x * torch.sigmoid(1.702 * x) # QuickGELU activation function
57
- x = self.linear_2(x)
58
- x += residue
52
+ self.config = config
53
+ self.transformer_blocks = nn.ModuleList(
54
+ TransformerBlock(config) for _ in range(config.num_layers)
55
+ )
56
+ self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)
59
57
 
60
- return x
61
-
62
-
63
- class CLIP(nn.Module):
64
-
65
- def __init__(self):
66
- super().__init__()
67
- self.embedding = CLIPEmbedding(49408, 768, 77)
68
- self.layers = nn.ModuleList([CLIPLayer(12, 768) for i in range(12)])
69
- self.layernorm = nn.LayerNorm(768)
58
+ self.mask_cache = attention_utils.build_causal_mask_cache(
59
+ size=config.max_seq_len, dtype=torch.float32
60
+ )
70
61
 
71
62
  @torch.inference_mode
72
63
  def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
73
64
  tokens = tokens.type(torch.long)
74
65
 
75
- state = self.embedding(tokens)
76
- for layer in self.layers:
77
- state = layer(state)
78
- output = self.layernorm(state)
66
+ state = self.tok_embedding(tokens) + self.tok_embedding_position
67
+ for layer in self.transformer_blocks:
68
+ state = layer(state, mask=self.mask_cache)
69
+ output = self.final_norm(state)
79
70
  return output
71
+
72
+
73
+ def get_model_config() -> cfg.ModelConfig:
74
+ max_seq_len = 77
75
+ vocab_size = 49408
76
+ num_layers = 12
77
+ num_heads = 12
78
+ num_query_groups = 12
79
+ embedding_dim = 768
80
+
81
+ attn_config = cfg.AttentionConfig(
82
+ num_heads=num_heads,
83
+ num_query_groups=num_query_groups,
84
+ rotary_percentage=0.0,
85
+ qkv_use_bias=True,
86
+ qkv_transpose_before_split=True,
87
+ output_proj_use_bias=True,
88
+ enable_kv_cache=False,
89
+ )
90
+
91
+ ff_config = cfg.FeedForwardConfig(
92
+ type=cfg.FeedForwardType.SEQUENTIAL,
93
+ activation=cfg.ActivationType.GELU_QUICK,
94
+ intermediate_size=embedding_dim * 4,
95
+ use_bias=True,
96
+ )
97
+
98
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
99
+
100
+ config = cfg.ModelConfig(
101
+ vocab_size=vocab_size,
102
+ num_layers=num_layers,
103
+ max_seq_len=max_seq_len,
104
+ embedding_dim=embedding_dim,
105
+ attn_config=attn_config,
106
+ ff_config=ff_config,
107
+ pre_attention_norm_config=norm_config,
108
+ pre_ff_norm_config=norm_config,
109
+ final_norm_config=norm_config,
110
+ enable_hlfb=True,
111
+ )
112
+
113
+ return config
@@ -19,11 +19,12 @@ from pathlib import Path
19
19
  import torch
20
20
 
21
21
  import ai_edge_torch
22
- from ai_edge_torch.generative.examples.stable_diffusion.clip import CLIP
22
+ import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
23
23
  from ai_edge_torch.generative.examples.stable_diffusion.decoder import Decoder
24
24
  from ai_edge_torch.generative.examples.stable_diffusion.diffusion import Diffusion # NOQA
25
25
  from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
26
26
  import ai_edge_torch.generative.examples.stable_diffusion.util as util
27
+ import ai_edge_torch.generative.utilities.loader as loading_utils
27
28
 
28
29
 
29
30
  @torch.inference_mode
@@ -36,8 +37,9 @@ def convert_stable_diffusion_to_tflite(
36
37
  image_width: int = 512,
37
38
  ):
38
39
 
39
- clip = CLIP()
40
- clip.load_state_dict(torch.load(clip_ckpt_path))
40
+ clip_model = clip.CLIP(clip.get_model_config())
41
+ loader = loading_utils.ModelLoader(clip_ckpt_path, clip.TENSOR_NAMES)
42
+ loader.load(clip_model, strict=False)
41
43
 
42
44
  encoder = Encoder()
43
45
  encoder.load_state_dict(torch.load(encoder_ckpt_path))
@@ -59,13 +61,13 @@ def convert_stable_diffusion_to_tflite(
59
61
  )
60
62
 
61
63
  input_latents = encoder(input_image, noise)
62
- context_cond = clip(prompt_tokens)
64
+ context_cond = clip_model(prompt_tokens)
63
65
  context_uncond = torch.zeros_like(context_cond)
64
66
  context = torch.cat([context_cond, context_uncond], axis=0)
65
67
  time_embedding = util.get_time_embedding(timestamp)
66
68
 
67
69
  # CLIP text encoder
68
- ai_edge_torch.signature('encode', clip, (prompt_tokens,)).convert().export(
70
+ ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
69
71
  '/tmp/stable_diffusion/clip.tflite'
70
72
  )
71
73
 
@@ -202,11 +202,6 @@ class UNet(nn.Module):
202
202
 
203
203
  x = self.bottleneck(x, context, time)
204
204
 
205
- # print('x shape:')
206
- # print(list(x.shape))
207
- # print('time shape:')
208
- # print(list(time.shape))
209
-
210
205
  for layers in self.decoders:
211
206
  x = torch.cat((x, skip_connections.pop()), dim=1)
212
207
  x = layers(x, context, time)
@@ -214,199 +209,6 @@ class UNet(nn.Module):
214
209
  return x
215
210
 
216
211
 
217
- # The encoder component.
218
- class UNetEncoder(nn.Module):
219
-
220
- def __init__(self):
221
- super().__init__()
222
- self.time_embedding = TimeEmbedding(320)
223
- self.encoders = nn.ModuleList(
224
- [
225
- SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
226
- SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
227
- SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
228
- SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
229
- SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)),
230
- SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)),
231
- SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
232
- SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)),
233
- SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)),
234
- SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
235
- SwitchSequential(ResidualBlock(1280, 1280)),
236
- SwitchSequential(ResidualBlock(1280, 1280)),
237
- ]
238
- )
239
-
240
- def forward(self, x, context, time):
241
- time_embedding = self.time_embedding(time)
242
- skip_connections = []
243
- for layers in self.encoders:
244
- x = layers(x, context, time_embedding)
245
- skip_connections.append(x)
246
-
247
- return x, skip_connections, time_embedding
248
-
249
-
250
- class UNetBottleNeck(nn.Module):
251
-
252
- def __init__(self):
253
- super().__init__()
254
- self.bottleneck = SwitchSequential(
255
- ResidualBlock(1280, 1280),
256
- AttentionBlock(8, 160),
257
- ResidualBlock(1280, 1280),
258
- )
259
-
260
- def forward(self, x, context, time):
261
- x = self.bottleneck(x, context, time)
262
- # print('shape')
263
- # print(list(x.shape))
264
- return x
265
-
266
-
267
- # Unet decoder.
268
- class UNetDecoder1(nn.Module):
269
-
270
- def __init__(self):
271
- super().__init__()
272
- self.decoders = nn.ModuleList(
273
- [
274
- SwitchSequential(ResidualBlock(2560, 1280)),
275
- SwitchSequential(ResidualBlock(2560, 1280)),
276
- SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
277
- SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
278
- ]
279
- )
280
-
281
- def forward(self, x, context, time, s9, s10, s11, s12):
282
- x = torch.cat((x, s12), dim=1)
283
- x = self.decoders[0](x, context, time)
284
- x = torch.cat((x, s11), dim=1)
285
- x = self.decoders[1](x, context, time)
286
- x = torch.cat((x, s10), dim=1)
287
- x = self.decoders[2](x, context, time)
288
- x = torch.cat((x, s9), dim=1)
289
- x = self.decoders[3](x, context, time)
290
-
291
- return x
292
-
293
-
294
- class UNetDecoder2(nn.Module):
295
-
296
- def __init__(self):
297
- super().__init__()
298
- self.decoders = nn.ModuleList(
299
- [
300
- SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
301
- SwitchSequential(
302
- ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
303
- ),
304
- SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
305
- SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
306
- ]
307
- )
308
-
309
- def forward(self, x, context, time, s5, s6, s7, s8):
310
- x = torch.cat((x, s8), dim=1)
311
- x = self.decoders[0](x, context, time)
312
- x = torch.cat((x, s7), dim=1)
313
- x = self.decoders[1](x, context, time)
314
- x = torch.cat((x, s6), dim=1)
315
- x = self.decoders[2](x, context, time)
316
- x = torch.cat((x, s5), dim=1)
317
- x = self.decoders[3](x, context, time)
318
- return x
319
-
320
-
321
- class UNetDecoder3(nn.Module):
322
-
323
- def __init__(self):
324
- super().__init__()
325
- self.decoders = nn.ModuleList(
326
- [
327
- SwitchSequential(
328
- ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
329
- ),
330
- SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
331
- SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
332
- SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
333
- ]
334
- )
335
- self.final = FinalLayer(320, 4)
336
-
337
- def forward(self, x, context, time, s1, s2, s3, s4):
338
- x = torch.cat((x, s4), dim=1)
339
- x = self.decoders[0](x, context, time)
340
- x = torch.cat((x, s3), dim=1)
341
- x = self.decoders[1](x, context, time)
342
- x = torch.cat((x, s2), dim=1)
343
- x = self.decoders[2](x, context, time)
344
- x = torch.cat((x, s1), dim=1)
345
- x = self.decoders[3](x, context, time)
346
-
347
- x = self.final(x)
348
- return x
349
-
350
-
351
- class UNetDecoder(nn.Module):
352
-
353
- def __init__(self):
354
- super().__init__()
355
- self.decoders = nn.ModuleList(
356
- [
357
- SwitchSequential(ResidualBlock(2560, 1280)),
358
- SwitchSequential(ResidualBlock(2560, 1280)),
359
- SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
360
- SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
361
- SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
362
- SwitchSequential(
363
- ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)
364
- ),
365
- SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
366
- SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
367
- SwitchSequential(
368
- ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)
369
- ),
370
- SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
371
- SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
372
- SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
373
- ]
374
- )
375
- self.final = FinalLayer(320, 4)
376
-
377
- def forward(
378
- self, x, context, time, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12
379
- ):
380
- x = torch.cat((x, s12), dim=1)
381
- x = self.decoders[0](x, context, time)
382
- x = torch.cat((x, s11), dim=1)
383
- x = self.decoders[1](x, context, time)
384
- x = torch.cat((x, s10), dim=1)
385
- x = self.decoders[2](x, context, time)
386
- x = torch.cat((x, s9), dim=1)
387
- x = self.decoders[3](x, context, time)
388
- x = torch.cat((x, s8), dim=1)
389
- x = self.decoders[4](x, context, time)
390
- x = torch.cat((x, s7), dim=1)
391
- x = self.decoders[5](x, context, time)
392
- x = torch.cat((x, s6), dim=1)
393
- x = self.decoders[6](x, context, time)
394
- x = torch.cat((x, s5), dim=1)
395
- x = self.decoders[7](x, context, time)
396
- x = torch.cat((x, s4), dim=1)
397
- x = self.decoders[0](x, context, time)
398
- x = torch.cat((x, s3), dim=1)
399
- x = self.decoders[1](x, context, time)
400
- x = torch.cat((x, s2), dim=1)
401
- x = self.decoders[2](x, context, time)
402
- x = torch.cat((x, s1), dim=1)
403
- x = self.decoders[3](x, context, time)
404
-
405
- x = self.final(x)
406
-
407
- return x
408
-
409
-
410
212
  class FinalLayer(nn.Module):
411
213
 
412
214
  def __init__(self, in_channels, out_channels):
@@ -432,68 +234,6 @@ class Diffusion(nn.Module):
432
234
  @torch.inference_mode
433
235
  def forward(self, latent, context, time):
434
236
  time = self.time_embedding(time)
435
- # print('time:')
436
- # print(list(time.shape))
437
237
  output = self.unet(latent, context, time)
438
238
  output = self.final(output)
439
239
  return output
440
-
441
-
442
- # Calling code as if Diffusion is splitted into two parts.
443
- class DiffusionSplitted(nn.Module):
444
-
445
- def __init__(self):
446
- super().__init__()
447
- self.unet_encoder = UNetEncoder()
448
- self.bottleneck = UNetBottleNeck()
449
- self.unet_decoder1 = UNetDecoder1()
450
- self.unet_decoder2 = UNetDecoder2()
451
- self.unet_decoder3 = UNetDecoder3()
452
-
453
- def get_skip_connections(self, latent, context, time):
454
- _, skip_connections, _ = self.unet_encoder(latent, context, time)
455
- return skip_connections
456
-
457
- def forward(self, latent, context, time):
458
- output, skip_connections, time = self.unet_encoder(latent, context, time)
459
- # print("output shape of unet encoder...")
460
- # print(list(output.shape))
461
- # print("output shape of time...")
462
- # print(list(time.shape))
463
- output = self.bottleneck(output, context, time)
464
- # print("output shape of bn")
465
- # print(list(output.shape))
466
- output = self.unet_decoder1(
467
- output,
468
- context,
469
- time,
470
- skip_connections[8],
471
- skip_connections[9],
472
- skip_connections[10],
473
- skip_connections[11],
474
- )
475
- # print("output shape of d1:")
476
- # print(list(output.shape))
477
-
478
- output = self.unet_decoder2(
479
- output,
480
- context,
481
- time,
482
- skip_connections[4],
483
- skip_connections[5],
484
- skip_connections[6],
485
- skip_connections[7],
486
- )
487
-
488
- # print("output shape of d2:")
489
- # print(list(output.shape))
490
- output = self.unet_decoder3(
491
- output,
492
- context,
493
- time,
494
- skip_connections[0],
495
- skip_connections[1],
496
- skip_connections[2],
497
- skip_connections[3],
498
- )
499
- return output
@@ -20,11 +20,11 @@ import torch
20
20
  from torch import nn
21
21
  import torch.nn.functional as F
22
22
 
23
- from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention # NOQA
24
- from ai_edge_torch.generative.layers.attention import scaled_dot_product_attention_with_hlfb # NOQA
25
23
  import ai_edge_torch.generative.layers.builder as builder
26
24
  from ai_edge_torch.generative.layers.kv_cache import KVCache
27
25
  import ai_edge_torch.generative.layers.model_config as cfg
26
+ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention # NOQA
27
+ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA
28
28
 
29
29
 
30
30
  class EncoderDecoderBlock(nn.Module):
@@ -0,0 +1,161 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # A toy example which has basic transformer block (w/ externalized KV-Cache).
16
+
17
+ from typing import List, Tuple
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch_xla
23
+
24
+ import ai_edge_torch
25
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
26
+ import ai_edge_torch.generative.layers.builder as builder
27
+ from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
28
+ import ai_edge_torch.generative.layers.model_config as cfg
29
+
30
+ RoPECache = Tuple[torch.Tensor, torch.Tensor]
31
+
32
+
33
+ class ToyModelWithExternalKV(torch.nn.Module):
34
+
35
+ def __init__(self, config: cfg.ModelConfig) -> None:
36
+ super().__init__()
37
+ self.lm_head = nn.Linear(
38
+ config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
39
+ )
40
+ self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
41
+ self.transformer_blocks = nn.ModuleList(
42
+ TransformerBlock(config) for _ in range(config.num_layers)
43
+ )
44
+ self.final_norm = builder.build_norm(
45
+ config.embedding_dim,
46
+ config.final_norm_config,
47
+ )
48
+ self.rope_cache = attn_utils.build_rope_cache(
49
+ size=config.max_seq_len,
50
+ dim=int(config.attn_config.rotary_percentage * config.head_dim),
51
+ base=10_000,
52
+ condense_ratio=1,
53
+ dtype=torch.float32,
54
+ device=torch.device('cpu'),
55
+ )
56
+ self.mask_cache = attn_utils.build_causal_mask_cache(
57
+ size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
58
+ )
59
+ self.config = config
60
+
61
+ def forward(
62
+ self,
63
+ idx: torch.Tensor,
64
+ input_pos: torch.Tensor,
65
+ k_caches: torch.Tensor,
66
+ v_caches: torch.Tensor,
67
+ ) -> (torch.Tensor, torch.Tensor, torch.Tensor):
68
+ x = self.tok_embedding(idx)
69
+ cos, sin = self.rope_cache
70
+ cos = cos.index_select(0, input_pos)
71
+ sin = sin.index_select(0, input_pos)
72
+ mask = self.mask_cache.index_select(2, input_pos)
73
+ mask = mask[:, :, :, : self.config.max_seq_len]
74
+
75
+ for i, block in enumerate(self.transformer_blocks):
76
+ input_k, input_v = k_caches[i], v_caches[i]
77
+ x, (updated_k, updated_v) = block(
78
+ x, (cos, sin), mask, input_pos, (input_k, input_v)
79
+ )
80
+ k_caches[i], v_caches[i] = updated_k, updated_v
81
+
82
+ x = self.final_norm(x)
83
+ return self.lm_head(x), k_caches, v_caches
84
+
85
+
86
+ def _export_stablehlo_mlir(model, args):
87
+ ep = torch.export.export(model, args)
88
+ stablehlo_gm = torch_xla.stablehlo.exported_program_to_stablehlo(ep)
89
+ return stablehlo_gm.get_stablehlo_text()
90
+
91
+
92
+ def get_model_config() -> cfg.ModelConfig:
93
+ attn_config = cfg.AttentionConfig(
94
+ num_heads=32, num_query_groups=4, rotary_percentage=1.0
95
+ )
96
+ ff_config = cfg.FeedForwardConfig(
97
+ type=cfg.FeedForwardType.GATED,
98
+ activation=cfg.ActivationType.SILU,
99
+ intermediate_size=256,
100
+ )
101
+ norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
102
+ config = cfg.ModelConfig(
103
+ vocab_size=150,
104
+ num_layers=2,
105
+ max_seq_len=100,
106
+ embedding_dim=128,
107
+ attn_config=attn_config,
108
+ ff_config=ff_config,
109
+ pre_attention_norm_config=norm_config,
110
+ pre_ff_norm_config=norm_config,
111
+ final_norm_config=norm_config,
112
+ enable_hlfb=True,
113
+ )
114
+ return config
115
+
116
+
117
+ def get_sample_prefill_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
118
+ idx = torch.unsqueeze(torch.arange(0, 100), 0)
119
+ input_pos = torch.arange(0, 100)
120
+ return idx, input_pos
121
+
122
+
123
+ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
124
+ idx = torch.tensor([[1]], dtype=torch.long)
125
+ input_pos = torch.tensor([10])
126
+ return idx, input_pos
127
+
128
+
129
+ def define_and_run() -> None:
130
+ dump_mlir = False
131
+
132
+ config = get_model_config()
133
+ model = ToyModelWithExternalKV(config)
134
+ print('running an inference')
135
+ k_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
136
+ v_caches = torch.zeros((2, 1, 100, 4, 4), dtype=torch.float32)
137
+
138
+ idx, input_pos = get_sample_prefill_inputs()
139
+ decode_idx, decode_input_pos = get_sample_decode_inputs()
140
+ print(model.forward(idx, input_pos, k_caches, v_caches))
141
+
142
+ if dump_mlir:
143
+ mlir_text = _export_stablehlo_mlir(model, (idx, input_pos, k_caches, v_caches))
144
+ with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
145
+ f.write(mlir_text)
146
+
147
+ # Convert model to tflite with 2 signatures (prefill + decode).
148
+ # TODO(b/344014416): currently conversion will fail, because we generate int64 index
149
+ # in dynamic update slice op.
150
+ print('converting toy model to tflite with 2 signatures (prefill + decode)')
151
+ edge_model = (
152
+ ai_edge_torch.signature('prefill', model, (idx, input_pos, k_caches, v_caches))
153
+ .signature('decode', model, (decode_idx, decode_input_pos, k_caches, v_caches))
154
+ .convert()
155
+ )
156
+ edge_model.export('/tmp/toy_external_kv_cache.tflite')
157
+
158
+
159
+ if __name__ == '__main__':
160
+ with torch.inference_mode():
161
+ define_and_run()
@@ -14,7 +14,6 @@
14
14
  # ==============================================================================
15
15
  # Common building blocks for Attention layer.
16
16
 
17
- import math
18
17
  from typing import Optional, Tuple
19
18
 
20
19
  import torch
@@ -25,101 +24,8 @@ import ai_edge_torch.generative.layers.builder as builder
25
24
  from ai_edge_torch.generative.layers.kv_cache import KVCache
26
25
  import ai_edge_torch.generative.layers.model_config as cfg
27
26
  import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
28
- from ai_edge_torch.hlfb import StableHLOCompositeBuilder
29
-
30
-
31
- def scaled_dot_product_attention(
32
- q: torch.Tensor,
33
- k: torch.Tensor,
34
- v: torch.Tensor,
35
- head_size: int,
36
- mask: Optional[torch.Tensor] = None,
37
- scale: Optional[float] = None,
38
- ):
39
- """Scaled dot product attention.
40
-
41
- Args:
42
- q (torch.Tensor): Query tensor, with shape [B, T, N, H].
43
- k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
44
- v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
45
- head_size (int): head dimension.
46
- mask (torch.Tensor): the optional mask tensor.
47
-
48
- Returns:
49
- The output tensor of scaled_dot_product_attention.
50
- """
51
-
52
- if scale is None:
53
- scale = 1.0 / math.sqrt(head_size)
54
-
55
- q = q.transpose(1, 2)
56
- k = k.transpose(1, 2)
57
- v = v.transpose(1, 2)
58
- if q.size() != k.size():
59
- # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
60
- k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
61
- v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
62
- y = F.scaled_dot_product_attention(
63
- q,
64
- k,
65
- v,
66
- attn_mask=mask,
67
- dropout_p=0.0,
68
- is_causal=mask is None,
69
- scale=scale,
70
- )
71
- return y.transpose(1, 2)
72
-
73
-
74
- def scaled_dot_product_attention_with_hlfb(
75
- q: torch.Tensor,
76
- k: torch.Tensor,
77
- v: torch.Tensor,
78
- head_size: int,
79
- mask: Optional[torch.Tensor] = None,
80
- scale: Optional[float] = None,
81
- ):
82
- """Scaled dot product attention with high-level function boundary enabled.
83
-
84
- Args:
85
- q (torch.Tensor): Query tensor, with shape [B, T, N, H].
86
- k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
87
- v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
88
- head_size (int): head dimension.
89
- mask (torch.Tensor): the optional mask tensor.
90
-
91
- Returns:
92
- The output tensor of scaled_dot_product_attention.
93
- """
94
-
95
- if scale is None:
96
- scale = 1.0 / math.sqrt(head_size)
97
-
98
- builder = StableHLOCompositeBuilder(
99
- name="odml.scaled_dot_product_attention", attr={"scale": scale}
100
- )
101
- q, k, v, mask = builder.mark_inputs(q, k, v, mask)
102
-
103
- q = q.transpose(1, 2)
104
- k = k.transpose(1, 2)
105
- v = v.transpose(1, 2)
106
- if q.size() != k.size():
107
- # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
108
- k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
109
- v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
110
- y = F.scaled_dot_product_attention(
111
- q,
112
- k,
113
- v,
114
- attn_mask=mask,
115
- dropout_p=0.0,
116
- is_causal=mask is None,
117
- scale=scale,
118
- )
119
-
120
- result = y.transpose(1, 2)
121
- result = builder.mark_outputs(result)
122
- return result
27
+ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention # NOQA
28
+ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA
123
29
 
124
30
 
125
31
  class TransformerBlock(nn.Module):
@@ -151,7 +57,7 @@ class TransformerBlock(nn.Module):
151
57
  def forward(
152
58
  self,
153
59
  x: torch.Tensor,
154
- rope: Tuple[torch.Tensor, torch.Tensor],
60
+ rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
155
61
  mask: Optional[torch.Tensor] = None,
156
62
  input_pos: Optional[torch.Tensor] = None,
157
63
  ) -> torch.Tensor:
@@ -182,7 +88,6 @@ class TransformerBlock(nn.Module):
182
88
  return output
183
89
 
184
90
 
185
- # CausalSelfAttention which can support MHQ, MQA or GQA.
186
91
  class CausalSelfAttention(nn.Module):
187
92
 
188
93
  def __init__(
@@ -229,11 +134,12 @@ class CausalSelfAttention(nn.Module):
229
134
  def forward(
230
135
  self,
231
136
  x: torch.Tensor,
232
- rope: Tuple[torch.Tensor, torch.Tensor],
137
+ rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
233
138
  mask: Optional[torch.Tensor] = None,
234
139
  input_pos: Optional[torch.Tensor] = None,
235
140
  ) -> torch.Tensor:
236
- """Forward function of the CausalSelfAttention layer.
141
+ """Forward function of the CausalSelfAttention layer, which can support
142
+ MQA, GQA and MHA.
237
143
 
238
144
  Args:
239
145
  x (torch.Tensor): the input tensor.
@@ -253,28 +159,35 @@ class CausalSelfAttention(nn.Module):
253
159
  # Assemble into a number of query groups to support MHA, MQA and GQA.
254
160
  q_per_kv = self.config.num_heads // self.config.num_query_groups
255
161
  total_qkv = q_per_kv + 2 # Each group has >=1 queries, 1 key, and 1 value.
256
- qkv = qkv.view(
257
- B, T, self.config.num_query_groups, total_qkv, self.head_dim
258
- ) # (B, T, num_query_groups, total_qkv, head_dim)
162
+ if self.config.qkv_transpose_before_split:
163
+ qkv = qkv.view(
164
+ B, T, total_qkv, self.config.num_query_groups, self.head_dim
165
+ ) # (B, T, total_qkv, num_query_groups, head_dim)
166
+ qkv_axis = -3
167
+ else:
168
+ qkv = qkv.view(
169
+ B, T, self.config.num_query_groups, total_qkv, self.head_dim
170
+ ) # (B, T, num_query_groups, total_qkv, head_dim)
171
+ qkv_axis = -2
259
172
 
260
173
  # Split batched computation into three.
261
- q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)
262
-
174
+ q, k, v = qkv.split((q_per_kv, 1, 1), dim=qkv_axis)
263
175
  q = q.reshape(B, T, -1, self.head_dim)
264
176
  k = k.reshape(B, T, -1, self.head_dim)
265
177
  v = v.reshape(B, T, -1, self.head_dim)
266
178
 
267
179
  # Compute rotary positional embedding for query and key.
268
180
  n_elem = int(self.config.rotary_percentage * self.head_dim)
269
- cos, sin = rope
270
- q_roped = rotary_pos_emb.apply_rope(
271
- q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
272
- )
273
- k_roped = rotary_pos_emb.apply_rope(
274
- k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
275
- )
276
- q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
277
- k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
181
+ if n_elem > 0:
182
+ cos, sin = rope
183
+ q_roped = rotary_pos_emb.apply_rope(
184
+ q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
185
+ )
186
+ k_roped = rotary_pos_emb.apply_rope(
187
+ k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
188
+ )
189
+ q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
190
+ k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
278
191
 
279
192
  if self.kv_cache is not None:
280
193
  # TODO(haoliang): Handle when execeeding max sequence length.
@@ -97,6 +97,10 @@ def _get_activation(type_: cfg.ActivationType):
97
97
  return F.gelu
98
98
  elif type_ == cfg.ActivationType.GELU_TANH:
99
99
  return lambda x: F.gelu(x, approximate="tanh")
100
+ elif type_ == cfg.ActivationType.GELU_QUICK:
101
+ # GELU approximation that is fast but somewhat inaccurate.
102
+ # See: https://github.com/hendrycks/GELUs
103
+ return lambda x: x * F.sigmoid(1.702 * x)
100
104
  elif type_ == cfg.ActivationType.RELU:
101
105
  return F.relu
102
106
  else:
@@ -27,6 +27,7 @@ class ActivationType(enum.Enum):
27
27
  SILU = enum.auto()
28
28
  GELU = enum.auto()
29
29
  GELU_TANH = enum.auto()
30
+ GELU_QUICK = enum.auto()
30
31
  RELU = enum.auto()
31
32
 
32
33
 
@@ -46,7 +47,7 @@ class FeedForwardType(enum.Enum):
46
47
 
47
48
  # `output = linear(act(linear(x)))`.
48
49
  SEQUENTIAL = enum.auto()
49
- # `output = linear(act(linear(x)) * lienar(x))`.
50
+ # `output = linear_2(act(linear_1(x)) * lienar_3(x))`.
50
51
  GATED = enum.auto()
51
52
 
52
53
 
@@ -60,6 +61,9 @@ class AttentionConfig:
60
61
  num_query_groups: Optional[int]
61
62
  # Percentage of Rotary Positional Embedding added Q and K projections.
62
63
  rotary_percentage: Optional[float] = None
64
+ # Whether to transpose the query groups of qkv bundled tensor before
65
+ # splitting into separated tensors.
66
+ qkv_transpose_before_split: bool = False
63
67
  # Whether to use bias with Query, Key, and Value projection.
64
68
  qkv_use_bias: bool = False
65
69
  # Whether to use bias with attention output projection.
@@ -0,0 +1,117 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Implements scaled dot product attention.
16
+
17
+ import math
18
+ from typing import Optional
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+
23
+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
24
+
25
+
26
+ def scaled_dot_product_attention(
27
+ q: torch.Tensor,
28
+ k: torch.Tensor,
29
+ v: torch.Tensor,
30
+ head_size: int,
31
+ mask: Optional[torch.Tensor] = None,
32
+ scale: Optional[float] = None,
33
+ ):
34
+ """Scaled dot product attention.
35
+
36
+ Args:
37
+ q (torch.Tensor): Query tensor, with shape [B, T, N, H].
38
+ k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
39
+ v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
40
+ head_size (int): head dimension.
41
+ mask (torch.Tensor): the optional mask tensor.
42
+
43
+ Returns:
44
+ The output tensor of scaled_dot_product_attention.
45
+ """
46
+
47
+ if scale is None:
48
+ scale = 1.0 / math.sqrt(head_size)
49
+
50
+ q = q.transpose(1, 2)
51
+ k = k.transpose(1, 2)
52
+ v = v.transpose(1, 2)
53
+ if q.size() != k.size():
54
+ # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
55
+ k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
56
+ v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
57
+ y = F.scaled_dot_product_attention(
58
+ q,
59
+ k,
60
+ v,
61
+ attn_mask=mask,
62
+ dropout_p=0.0,
63
+ is_causal=mask is None,
64
+ scale=scale,
65
+ )
66
+ return y.transpose(1, 2)
67
+
68
+
69
+ def scaled_dot_product_attention_with_hlfb(
70
+ q: torch.Tensor,
71
+ k: torch.Tensor,
72
+ v: torch.Tensor,
73
+ head_size: int,
74
+ mask: Optional[torch.Tensor] = None,
75
+ scale: Optional[float] = None,
76
+ ):
77
+ """Scaled dot product attention with high-level function boundary enabled.
78
+
79
+ Args:
80
+ q (torch.Tensor): Query tensor, with shape [B, T, N, H].
81
+ k (torch.Tensor): Key tensor, with shape [B, T, KV_LEN, H].
82
+ v (torch.Tensor): Value tensor, with shape [B, T, KV_LEN, H].
83
+ head_size (int): head dimension.
84
+ mask (torch.Tensor): the optional mask tensor.
85
+
86
+ Returns:
87
+ The output tensor of scaled_dot_product_attention.
88
+ """
89
+
90
+ if scale is None:
91
+ scale = 1.0 / math.sqrt(head_size)
92
+
93
+ builder = StableHLOCompositeBuilder(
94
+ name="odml.scaled_dot_product_attention", attr={"scale": scale}
95
+ )
96
+ q, k, v, mask = builder.mark_inputs(q, k, v, mask)
97
+
98
+ q = q.transpose(1, 2)
99
+ k = k.transpose(1, 2)
100
+ v = v.transpose(1, 2)
101
+ if q.size() != k.size():
102
+ # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
103
+ k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
104
+ v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
105
+ y = F.scaled_dot_product_attention(
106
+ q,
107
+ k,
108
+ v,
109
+ attn_mask=mask,
110
+ dropout_p=0.0,
111
+ is_causal=mask is None,
112
+ scale=scale,
113
+ )
114
+
115
+ result = y.transpose(1, 2)
116
+ result = builder.mark_outputs(result)
117
+ return result
@@ -69,10 +69,16 @@ def load_pytorch_statedict(full_path: str):
69
69
  Raises:
70
70
  ValueError: If no tensors are loaded from the provided directory or file.
71
71
  """
72
- pattern = os.path.join(full_path, "*.bin") if os.path.isdir(full_path) else full_path
73
72
  files = []
74
- for file in glob.glob(pattern):
75
- files.append(file)
73
+ patterns = []
74
+ if os.path.isdir(full_path):
75
+ patterns.append(os.path.join(full_path, "*.bin"))
76
+ patterns.append(os.path.join(full_path, "*.pt"))
77
+ else:
78
+ patterns.append(full_path)
79
+ for pattern in patterns:
80
+ for file in glob.glob(pattern):
81
+ files.append(file)
76
82
 
77
83
  tensors = {}
78
84
  for file in files:
@@ -93,18 +99,20 @@ class ModelLoader:
93
99
 
94
100
  @dataclass
95
101
  class TensorNames:
96
- attn_query_proj: str
97
- attn_key_proj: str
98
- attn_value_proj: str
99
- attn_output_proj: str
100
-
101
- ff_up_proj: str
102
- ff_down_proj: str
102
+ attn_query_proj: str = None
103
+ attn_key_proj: str = None
104
+ attn_value_proj: str = None
105
+ attn_fused_qkv_proj: str = None
106
+ attn_output_proj: str = None
107
+
108
+ ff_up_proj: str = None
109
+ ff_down_proj: str = None
103
110
  ff_gate_proj: str = None
104
111
 
105
112
  pre_attn_norm: str = None
106
113
  pre_ff_norm: str = None
107
114
  embedding: str = None
115
+ embedding_position: str = None
108
116
  final_norm: str = None
109
117
  lm_head: str = None
110
118
 
@@ -129,6 +137,10 @@ class ModelLoader:
129
137
  strict (bool, optional): Whether the converted keys are strictly
130
138
  matched. Defaults to True.
131
139
 
140
+ Returns:
141
+ missing_keys (List[str]): a list of str containing the missing keys
142
+ unexpected_keys (List[str]): a list of str containing the unexpected keys
143
+
132
144
  Raises:
133
145
  ValueError: If conversion results in unmapped tensors and strict mode is
134
146
  enabled.
@@ -139,6 +151,10 @@ class ModelLoader:
139
151
  converted_state["tok_embedding.weight"] = state.pop(
140
152
  f"{self._names.embedding}.weight"
141
153
  )
154
+ if self._names.embedding_position is not None:
155
+ converted_state["tok_embedding_position"] = state.pop(
156
+ f"{self._names.embedding_position}"
157
+ )
142
158
  if self._names.lm_head is not None:
143
159
  converted_state["lm_head.weight"] = state.pop(f"{self._names.lm_head}.weight")
144
160
  if model.config.lm_head_use_bias:
@@ -158,7 +174,7 @@ class ModelLoader:
158
174
  raise ValueError(
159
175
  f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
160
176
  )
161
- model.load_state_dict(converted_state, strict=strict)
177
+ return model.load_state_dict(converted_state, strict=strict)
162
178
 
163
179
  def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
164
180
  """A best effort method for finding appropriate state loader.
@@ -172,13 +188,15 @@ class ModelLoader:
172
188
  if os.path.isdir(self._file_name):
173
189
  if glob.glob(os.path.join(self._file_name, "*.safetensors")):
174
190
  return load_safetensors
175
- if glob.glob(os.path.join(self._file_name, "*.bin")):
191
+ if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
192
+ os.path.join(self._file_name, "*.pt")
193
+ ):
176
194
  return load_pytorch_statedict
177
195
 
178
196
  if self._file_name.endswith(".safetensors"):
179
197
  return load_safetensors
180
198
 
181
- if self._file_name.endswith(".bin"):
199
+ if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
182
200
  return load_pytorch_statedict
183
201
 
184
202
  raise ValueError(f"File format not supported.")
@@ -225,22 +243,33 @@ class ModelLoader:
225
243
  converted_state: Dict[str, torch.Tensor],
226
244
  ):
227
245
  prefix = f"transformer_blocks.{idx}"
228
- q_name = self._names.attn_query_proj.format(idx)
229
- k_name = self._names.attn_key_proj.format(idx)
230
- v_name = self._names.attn_value_proj.format(idx)
231
- converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
232
- config,
233
- state.pop(f"{q_name}.weight"),
234
- state.pop(f"{k_name}.weight"),
235
- state.pop(f"{v_name}.weight"),
236
- )
237
- if config.attn_config.qkv_use_bias:
238
- converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
246
+ if self._names.attn_fused_qkv_proj:
247
+ fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
248
+ converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(
249
+ f"{fused_qkv_name}.weight"
250
+ )
251
+ else:
252
+ q_name = self._names.attn_query_proj.format(idx)
253
+ k_name = self._names.attn_key_proj.format(idx)
254
+ v_name = self._names.attn_value_proj.format(idx)
255
+ converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
239
256
  config,
240
- state.pop(f"{q_name}.bias"),
241
- state.pop(f"{k_name}.bias"),
242
- state.pop(f"{v_name}.bias"),
257
+ state.pop(f"{q_name}.weight"),
258
+ state.pop(f"{k_name}.weight"),
259
+ state.pop(f"{v_name}.weight"),
243
260
  )
261
+ if config.attn_config.qkv_use_bias:
262
+ if self._names.attn_fused_qkv_proj:
263
+ converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
264
+ f"{fused_qkv_name}.bias"
265
+ )
266
+ else:
267
+ converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
268
+ config,
269
+ state.pop(f"{q_name}.bias"),
270
+ state.pop(f"{k_name}.bias"),
271
+ state.pop(f"{v_name}.bias"),
272
+ )
244
273
 
245
274
  o_name = self._names.attn_output_proj.format(idx)
246
275
  converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240603
3
+ Version: 0.2.0.dev20240605
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -6,7 +6,7 @@ ai_edge_torch/convert/conversion_utils.py,sha256=NpVm3Ms81_cIW5IYgGsr0BVganJJgBK
6
6
  ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
7
7
  ai_edge_torch/convert/fx_passes/__init__.py,sha256=Ll2nNwufjcV5nSruQPXiloq7F1E7pWJ2T5clXmy1lk8,2825
8
8
  ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
9
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=wHVWNNMu5h_ya6GnnJn0cNif9xmdSqr8Vm-R7lllxZM,6213
9
+ ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=2yqUwJJ2R233_X9FNMOP9oYRTTzH34TR_BIUj-wfnKw,7080
10
10
  ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py,sha256=76XYoIlFDgrzp5QemoaEalPFcEbfszkEH_PLvO1ASCk,2607
11
11
  ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
12
12
  ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
@@ -21,8 +21,8 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partition
21
21
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
22
22
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=FlNKt2EhIKnlVEeUWTiv5sz446YKU6Yy1H0Gd6VRgkU,6432
23
23
  ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
24
- ai_edge_torch/convert/test/test_convert.py,sha256=USduDO6PaO3nlA82jMihTct--mCU_ugILZDin00lcJ8,8092
25
- ai_edge_torch/convert/test/test_convert_composites.py,sha256=SrVn_cEMtQhYYCMOUKK0K7M57MQNQX-lOUwieln0HGA,6616
24
+ ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
25
+ ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
26
26
  ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
27
27
  ai_edge_torch/debug/__init__.py,sha256=TKvmnjVk3asvYcVh6C-LPr6srgAF_nppSAupWEXqwPY,707
28
28
  ai_edge_torch/debug/culprit.py,sha256=vklaxBUfINdo44OsH7csILK70N41gEThCGchGEfbTZw,12789
@@ -40,10 +40,10 @@ ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=6nOuwx9q3AUlY
40
40
  ai_edge_torch/generative/examples/phi2/phi2.py,sha256=VvigzPQ_LJHeADTsMliwFwPe2BcnOhFgKDqr_WZ2JQ8,5540
41
41
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
42
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
43
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=KR1Ci4rlJeeGfsFRliCxUve9K7RTJLZfTRMgFtfQ4MU,2434
44
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=6REAYy1Bv-Iv5zcmA_m_W6fH6jt5a3IS6Vge18jS_Wo,3633
43
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=yUCJemEh4n8ez-yLgVU0HZAki-PZ9nY04DFjgpx9PUc,3698
44
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=MI73RjOeD4Kh7AL0j5_QXiZq-rl_qCdibSE6eCQCyeY,3804
45
45
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=AgVAdUbSkHXONVUjAyBQEXhIUUlinf9kNljcBpWnj3A,3276
46
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=nq94VpQ103eOimnmdyg7u3Xk1LH1IxGlmIbr2AttRIk,16224
46
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=TfbfsmuKoGsBENF9fYIAN_SMEQNhj-kjNdqQXFJGxpg,7784
47
47
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=L6hLaMQGb8-_BwSvTLIuDnZwfTqn0K4swBUjfPnYWZo,2341
48
48
  ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
49
49
  ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
@@ -56,22 +56,24 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5i
56
56
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
57
57
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
58
58
  ai_edge_torch/generative/examples/t5/t5.py,sha256=q2gG5RRo7RgNzvHXYC0Juh6Tgt5d_RTMSWFaYvOKiZU,21065
59
- ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=anR99IrzR21x6yswFHYG5QQtPDZ7rVicf6STfMp54fU,8998
59
+ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rRgwCEdVtzcJEaGbbBjw8HxCxrCX3pXA5nelawdYiME,9036
60
60
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
61
61
  ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=EV07_MEG3fv9g0ZGu9gbBd5BjjrGkxCT1pv7dvhz4TI,3791
62
+ ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=rzL5h7Z5DIEgfpc1pWgYHdKt2aR8ha_CUqTKQBSPBaU,5521
62
63
  ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=MUr6fSj2hBuYSlNbZtrBBpzqB_0WY-l_xYcd_TFFUjY,4831
63
64
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
64
65
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
65
66
  ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=hVGpuI8gpj4Rn9k4otsRE22MSLFHBDlUOgioY6Ru6VI,5629
66
67
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
67
- ai_edge_torch/generative/layers/attention.py,sha256=PxixRZb00v5BQkWbDwaJgke4Rd5LwzdWe0zH9SG4Tj0,9127
68
+ ai_edge_torch/generative/layers/attention.py,sha256=zNIBXxCOA5Mz_F_dfBbKpIovhtcB6q5a-i8oAxls1d0,7071
68
69
  ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
69
- ai_edge_torch/generative/layers/builder.py,sha256=ZSBVLv5EOtCkSW_Z8C2Hd7jN52nIAA2as1-qpmHGbCg,3201
70
+ ai_edge_torch/generative/layers/builder.py,sha256=WLTeDId9t3Xwt0h1zxzqoYyFvfrNzPKLskcl39q8Aqw,3403
70
71
  ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
71
72
  ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
72
- ai_edge_torch/generative/layers/model_config.py,sha256=KpJRIHV5BJH8QOa7h6LXLZyC7UDWgbCEsw0CvArz49Q,4064
73
+ ai_edge_torch/generative/layers/model_config.py,sha256=2zT9nyoyuuyk5ziiww0VSJ6_JO7pDf7uOYbO9O3OQc4,4249
73
74
  ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
74
75
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
76
+ ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
75
77
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
76
78
  ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
77
79
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=ffBALrrbrfiG_mrOr-f3B1Gc6PlAma9gtvVnfP7SDzI,1862
@@ -84,7 +86,7 @@ ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-y
84
86
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=1NfZxKo9Gx6CmVfd86K1FkmsNQnjzIV1ojBS85UGvT0,6500
85
87
  ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
86
88
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
87
- ai_edge_torch/generative/utilities/loader.py,sha256=c-ZOIDBVnat_5l2W5sWU7HQm7CL-wducS8poSu5PlUg,10107
89
+ ai_edge_torch/generative/utilities/loader.py,sha256=r-_hSanSjLZ_YXFpZUb0Up94u5F8JHp70Vf2nlONPSg,11269
88
90
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=guDTv-12UUvJGl4eDvvZX3t4rRKewfXO8SpcYXM6gbc,16156
89
91
  ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
90
92
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
@@ -100,8 +102,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCz
100
102
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
101
103
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
102
104
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
103
- ai_edge_torch_nightly-0.2.0.dev20240603.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
104
- ai_edge_torch_nightly-0.2.0.dev20240603.dist-info/METADATA,sha256=lSrdb2AHtqNqaAx0xXSxvAU2VZBpRslbse2gWrwNEo0,1748
105
- ai_edge_torch_nightly-0.2.0.dev20240603.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
106
- ai_edge_torch_nightly-0.2.0.dev20240603.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
107
- ai_edge_torch_nightly-0.2.0.dev20240603.dist-info/RECORD,,
105
+ ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
106
+ ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/METADATA,sha256=GJzwmKkM4T0H-vTvMyoxiD80WfppEpE_sd2Ip4aSbgM,1748
107
+ ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
108
+ ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
109
+ ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/RECORD,,