ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. ai_edge_torch/_convert/test/test_convert.py +35 -13
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  4. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  7. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  8. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  9. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  10. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  11. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  12. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  16. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  17. ai_edge_torch/generative/layers/attention.py +77 -73
  18. ai_edge_torch/generative/layers/builder.py +5 -3
  19. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  20. ai_edge_torch/generative/layers/model_config.py +38 -19
  21. ai_edge_torch/generative/layers/normalization.py +158 -0
  22. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  23. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  24. ai_edge_torch/generative/test/test_loader.py +1 -1
  25. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  26. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  27. ai_edge_torch/generative/test/utils.py +54 -0
  28. ai_edge_torch/generative/utilities/loader.py +15 -15
  29. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  30. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  31. ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
  32. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  33. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  34. ai_edge_torch/version.py +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  36. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +41 -47
  37. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  38. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  40. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  41. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  42. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  43. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  44. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  45. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  46. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  47. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  49. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  50. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/test/test_model_conversion.py

@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Testing model conversion for a few gen-ai models.
-import copy
+
+"""Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.gemma import gemma, gemma2
-from ai_edge_torch.generative.examples.phi2 import phi2
-from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.testing import model_coverage
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
 
@@ -49,22 +48,32 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
         [10], dtype=torch.int64
     )
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-        model_coverage.compare_tflite_torch(
+        test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-            (idx, input_pos),
-            num_valid_inputs=1,
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
            atol=1e-5,
            rtol=1e-5,
        )
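The old positional `(idx, input_pos)` sample tuple is gone: sample inputs, including the cache, now travel through `sample_kwargs`, and the KV cache is an explicit value built with `kv_cache.KVCache.from_model_config`. A minimal end-to-end sketch of the new calling convention, assuming this nightly is installed (the export path is hypothetical):

    import ai_edge_torch
    import torch
    from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
    from ai_edge_torch.generative.layers import kv_cache

    config = toy_model_with_kv_cache.get_model_config()
    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()

    # Zero-initialized cache whose layout is derived from the model config.
    kv = kv_cache.KVCache.from_model_config(config)

    edge_model = ai_edge_torch.convert(
        pytorch_model,
        sample_kwargs={
            "tokens": torch.tensor([[1]], dtype=torch.long),
            "input_pos": torch.tensor([10], dtype=torch.int64),
            "kv_cache": kv,
        },
    )
    edge_model.export("/tmp/toy_model_kv.tflite")  # hypothetical output path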
@@ -77,22 +86,32 @@ class TestModelConversion(googletest.TestCase):
   def test_toy_model_with_kv_cache_with_hlfb(self):
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = True
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
+    tokens, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
         [10], dtype=torch.int64
     )
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-        model_coverage.compare_tflite_torch(
+        test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-            (idx, input_pos),
-            num_valid_inputs=1,
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
            atol=1e-5,
            rtol=1e-5,
        )

@@ -104,7 +123,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.TinyLLamma(config).eval()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
 
     # prefill
     seq_len = 10

@@ -117,37 +136,56 @@ class TestModelConversion(googletest.TestCase):
     decode_token = torch.tensor([[1]], dtype=torch.long)
     decode_input_pos = torch.tensor([5], dtype=torch.int64)
 
+    kv = kv_cache.KVCache.from_model_config(config)
+
     edge_model = (
         ai_edge_torch.signature(
-            "prefill", pytorch_model, (prefill_tokens, prefill_input_pos)
+            "prefill",
+            pytorch_model,
+            sample_kwargs={
+                "tokens": prefill_tokens,
+                "input_pos": prefill_input_pos,
+                "kv_cache": kv,
+            },
+        )
+        .signature(
+            "decode",
+            pytorch_model,
+            sample_kwargs={
+                "tokens": decode_token,
+                "input_pos": decode_input_pos,
+                "kv_cache": kv,
+            },
        )
-        .signature("decode", pytorch_model, (decode_token, decode_input_pos))
         .convert()
     )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
-    copied_model = copy.deepcopy(pytorch_model)
-    copied_edge = copy.deepcopy(edge_model)
-
     self.assertTrue(
-        model_coverage.compare_tflite_torch(
+        test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-            (prefill_tokens, prefill_input_pos),
+            prefill_tokens,
+            prefill_input_pos,
+            kv,
            signature_name="prefill",
-            num_valid_inputs=1,
+            atol=1e-5,
+            rtol=1e-5,
        )
    )
 
     self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            copied_edge,
-            copied_model,
-            (decode_token, decode_input_pos),
+        test_utils.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            decode_token,
+            decode_input_pos,
+            kv,
            signature_name="decode",
-            num_valid_inputs=1,
+            atol=1e-5,
+            rtol=1e-5,
        )
    )
 
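Note the deleted `copy.deepcopy` calls: with the cache passed in and returned functionally, decode no longer observes state mutated by prefill, so one model instance serves both signatures. Each signature of the converted artifact is invoked by name, with the cache supplied as individual flattened tensors; a sketch mirroring the new `test_utils.compare_tflite_torch` helper (added later in this diff), reusing `kv` and the prefill tensors from the test above:

    from ai_edge_torch.lowertools import common_utils
    from torch.utils import _pytree as pytree

    # Flatten the KVCache pytree into named numpy inputs for the signature.
    values, spec = pytree.tree_flatten({"kv_cache": kv})
    flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
    flat_kv = {name: t.numpy() for name, t in zip(flat_names, values)}

    prefill_out = edge_model(
        signature_name="prefill",
        tokens=prefill_tokens.numpy(),
        input_pos=prefill_input_pos.numpy(),
        **flat_kv,
    )
    logits = prefill_out["logits"]  # outputs are keyed by name, per the helper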
ai_edge_torch/generative/test/test_model_conversion_large.py

@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Testing model conversion for a few gen-ai models.
-import copy
+
+"""Testing model conversion for a few gen-ai models."""
 
 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.gemma import gemma, gemma2
-from ai_edge_torch.generative.examples.phi2 import phi2
-from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
-from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.testing import model_coverage
+from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
 import torch
 
@@ -55,18 +55,28 @@ class TestModelConversion(googletest.TestCase):
     tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-        model_coverage.compare_tflite_torch(
+        test_utils.compare_tflite_torch(
            edge_model,
            model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
            atol=1e-2,
            rtol=1e-5,
        )

@@ -85,23 +95,31 @@ class TestModelConversion(googletest.TestCase):
     prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
     prefill_tokens[0, :4] = idx
     prefill_input_pos = torch.arange(0, 10)
+    kv = kv_cache.KVCache.from_model_config(config)
 
     edge_model = ai_edge_torch.signature(
-        "prefill", model, (prefill_tokens, prefill_input_pos)
+        "prefill",
+        model,
+        sample_kwargs={
+            "tokens": prefill_tokens,
+            "input_pos": prefill_input_pos,
+            "kv_cache": kv,
+        },
     ).convert()
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-        model_coverage.compare_tflite_torch(
+        test_utils.compare_tflite_torch(
            edge_model,
            model,
-            (prefill_tokens, prefill_input_pos),
+            prefill_tokens,
+            prefill_input_pos,
+            kv,
            signature_name="prefill",
-            num_valid_inputs=1,
-            atol=1e-2,
-            rtol=1e-5,
+            atol=1e-1,
+            rtol=1e-3,
        )
    )
 
@@ -117,18 +135,28 @@ class TestModelConversion(googletest.TestCase):
     tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
+    kv = kv_cache.KVCache.from_model_config(config)
+
+    edge_model = ai_edge_torch.convert(
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+        },
+    )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
     self.assertTrue(
-        model_coverage.compare_tflite_torch(
+        test_utils.compare_tflite_torch(
            edge_model,
            pytorch_model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
+            tokens,
+            input_pos,
+            kv,
+            signature_name="serving_default",
            atol=1e-3,
            rtol=1e-3,
        )

ai_edge_torch/generative/test/utils.py (new file)

@@ -0,0 +1,54 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common utils for testing."""
+
+from ai_edge_torch import model
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.lowertools import common_utils
+import numpy as np
+import torch
+from torch.utils import _pytree as pytree
+
+
+def compare_tflite_torch(
+    edge_model: model.Model,
+    torch_model: torch.nn.Module,
+    tokens: torch.Tensor,
+    input_pos: torch.Tensor,
+    kv_cache: kv_utils.KVCache,
+    signature_name: str,
+    atol: float = 1e-5,
+    rtol: float = 1e-5,
+):
+  """Compares torch models and TFLite models."""
+  values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
+  flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
+  torch_output = torch_model(tokens, input_pos, kv_cache)
+
+  input_kv_flatten = {k: v.numpy() for k, v in zip(flat_names, values)}
+  edge_output = edge_model(
+      signature_name=signature_name,
+      tokens=tokens.numpy(),
+      input_pos=input_pos.numpy(),
+      **input_kv_flatten,
+  )
+
+  return np.allclose(
+      edge_output["logits"],
+      torch_output["logits"].detach().numpy(),
+      atol=atol,
+      rtol=rtol,
+  )
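The only non-obvious step in the helper is the pytree flattening, which linearizes the nested cache into the flat tensor names of the TFLite signature (`KVCache` is presumably registered as a pytree node in the rewritten `kv_cache.py`). A self-contained demo of the same mechanism on a plain dict standing in for the cache:

    import torch
    from torch.utils import _pytree as pytree

    fake_cache = {"k_0": torch.zeros(1, 4), "v_0": torch.zeros(1, 4)}
    values, spec = pytree.tree_flatten({"kv_cache": fake_cache})
    # `values` is the list of leaf tensors in a deterministic order; `spec`
    # records the structure, from which flat_dict_names derives signature
    # input names (exact naming scheme assumed from lowertools.common_utils).
    assert len(values) == 2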
ai_edge_torch/generative/utilities/loader.py

@@ -221,7 +221,8 @@ class ModelLoader:
       converted_state: Dict[str, torch.Tensor],
   ):
     prefix = f"transformer_blocks.{idx}"
-    if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
+    ff_config = config.block_config(idx).ff_config
+    if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
       ff_up_proj_name = self._names.ff_up_proj.format(idx)
       ff_down_proj_name = self._names.ff_down_proj.format(idx)
       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(

@@ -230,7 +231,7 @@ class ModelLoader:
       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
           f"{ff_down_proj_name}.weight"
       )
-      if config.ff_config.use_bias:
+      if ff_config.use_bias:
        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
            f"{ff_up_proj_name}.bias"
        )
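These loader hunks replace the model-wide `config.ff_config` / `config.attn_config` with per-block lookups through `config.block_config(idx)`, presumably so models whose layers differ can carry distinct configs per layer (gemma2.py changes substantially in this release). A hypothetical sketch of the accessor shape, not the actual `model_config.py` source:

    from dataclasses import dataclass
    from typing import Any, List, Union

    @dataclass
    class TransformerBlockConfig:
      attn_config: Any  # stand-in for the real AttentionConfig
      ff_config: Any  # stand-in for the real FeedForwardConfig

    @dataclass
    class ModelConfig:
      # Assumption: either one shared block config or one per layer.
      block_configs: Union[TransformerBlockConfig, List[TransformerBlockConfig]]

      def block_config(self, idx: int) -> TransformerBlockConfig:
        if isinstance(self.block_configs, list):
          return self.block_configs[idx]
        return self.block_configs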
@@ -250,7 +251,7 @@ class ModelLoader:
       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
           f"{ff_gate_proj_name}.weight"
       )
-      if config.ff_config.use_bias:
+      if ff_config.use_bias:
        converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
            f"{ff_up_proj_name}.bias"
        )

@@ -289,6 +290,7 @@ class ModelLoader:
       converted_state: Dict[str, torch.Tensor],
   ):
     prefix = f"transformer_blocks.{idx}"
+    attn_config = config.block_config(idx).attn_config
     if self._names.attn_fused_qkv_proj:
       fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
       converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(

@@ -300,13 +302,13 @@ class ModelLoader:
       v_name = self._names.attn_value_proj.format(idx)
       converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = (
           self._fuse_qkv(
-              config,
+              attn_config,
              state.pop(f"{q_name}.weight"),
              state.pop(f"{k_name}.weight"),
              state.pop(f"{v_name}.weight"),
          )
      )
-    if config.attn_config.qkv_use_bias:
+    if attn_config.qkv_use_bias:
       if self._names.attn_fused_qkv_proj:
         converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
             f"{fused_qkv_name}.bias"

@@ -314,7 +316,7 @@ class ModelLoader:
       else:
         converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = (
             self._fuse_qkv(
-                config,
+                attn_config,
                state.pop(f"{q_name}.bias"),
                state.pop(f"{k_name}.bias"),
                state.pop(f"{v_name}.bias"),

@@ -325,7 +327,7 @@ class ModelLoader:
     converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
         state.pop(f"{o_name}.weight")
     )
-    if config.attn_config.output_proj_use_bias:
+    if attn_config.output_proj_use_bias:
       converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
           state.pop(f"{o_name}.bias")
       )

@@ -360,18 +362,16 @@ class ModelLoader:
 
   def _fuse_qkv(
       self,
-      config: model_config.ModelConfig,
+      attn_config: model_config.AttentionConfig,
       q: torch.Tensor,
       k: torch.Tensor,
       v: torch.Tensor,
   ) -> torch.Tensor:
-    if config.attn_config.qkv_fused_interleaved:
-      q_per_kv = (
-          config.attn_config.num_heads // config.attn_config.num_query_groups
-      )
-      qs = torch.split(q, config.attn_config.head_dim * q_per_kv)
-      ks = torch.split(k, config.attn_config.head_dim)
-      vs = torch.split(v, config.attn_config.head_dim)
+    if attn_config.qkv_fused_interleaved:
+      q_per_kv = attn_config.num_heads // attn_config.num_query_groups
+      qs = torch.split(q, attn_config.head_dim * q_per_kv)
+      ks = torch.split(k, attn_config.head_dim)
+      vs = torch.split(v, attn_config.head_dim)
       cycled = [t for group in zip(qs, ks, vs) for t in group]
       return torch.cat(cycled)
     else:
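`_fuse_qkv` now takes the `AttentionConfig` directly rather than digging it out of the whole `ModelConfig`; the interleaving logic itself is unchanged. A runnable illustration of what the `qkv_fused_interleaved` branch produces for grouped-query attention, with made-up dimensions (4 query heads, 2 KV groups, head_dim 8):

    import torch

    num_heads, num_query_groups, head_dim, in_dim = 4, 2, 8, 16
    q = torch.randn(num_heads * head_dim, in_dim)
    k = torch.randn(num_query_groups * head_dim, in_dim)
    v = torch.randn(num_query_groups * head_dim, in_dim)

    q_per_kv = num_heads // num_query_groups  # query heads per KV group
    qs = torch.split(q, head_dim * q_per_kv)  # one chunk of rows per group
    ks = torch.split(k, head_dim)
    vs = torch.split(v, head_dim)
    # Interleave per group along dim 0: [q0 q1 | k0 | v0 | q2 q3 | k1 | v1].
    fused = torch.cat([t for group in zip(qs, ks, vs) for t in group])
    assert fused.shape == ((num_heads + 2 * num_query_groups) * head_dim, in_dim)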
ai_edge_torch/generative/utilities/t5_loader.py

@@ -279,7 +279,8 @@ class ModelLoader:
     prefix = additional_prefix + f"transformer_blocks.{idx}"
     if names.ff_up_proj is None or names.ff_down_proj is None:
       return
-    if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
+    ff_config = config.block_config(idx).ff_config
+    if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
       ff_up_proj_name = names.ff_up_proj.format(idx)
       ff_down_proj_name = names.ff_down_proj.format(idx)
       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(

@@ -288,7 +289,7 @@ class ModelLoader:
       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
           f"{ff_down_proj_name}.weight"
       )
-      if config.ff_config.use_bias:
+      if ff_config.use_bias:
        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
            f"{ff_up_proj_name}.bias"
        )

@@ -309,7 +310,7 @@ class ModelLoader:
       converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
           f"{ff_gate_proj_name}.weight"
       )
-      if config.ff_config.use_bias:
+      if ff_config.use_bias:
        converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
            f"{ff_up_proj_name}.bias"
        )

@@ -337,20 +338,21 @@ class ModelLoader:
     ):
       return
     prefix = additional_prefix + f"transformer_blocks.{idx}"
+    attn_config = config.block_config(idx).attn_config
     q_name = names.attn_query_proj.format(idx)
     k_name = names.attn_key_proj.format(idx)
     v_name = names.attn_value_proj.format(idx)
     # model.encoder.transformer_blocks[0].atten_func.q_projection.weight
     if fuse_attention:
       converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
-          config,
+          attn_config,
          state.pop(f"{q_name}.weight"),
          state.pop(f"{k_name}.weight"),
          state.pop(f"{v_name}.weight"),
      )
-      if config.attn_config.qkv_use_bias:
+      if attn_config.qkv_use_bias:
         converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
-            config,
+            attn_config,
            state.pop(f"{q_name}.bias"),
            state.pop(f"{k_name}.bias"),
            state.pop(f"{v_name}.bias"),

@@ -365,7 +367,7 @@ class ModelLoader:
       converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
           f"{v_name}.weight"
       )
-      if config.attn_config.qkv_use_bias:
+      if attn_config.qkv_use_bias:
        converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
            f"{q_name}.bias"
        )

@@ -380,7 +382,7 @@ class ModelLoader:
     converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
         state.pop(f"{o_name}.weight")
     )
-    if config.attn_config.output_proj_use_bias:
+    if attn_config.output_proj_use_bias:
       converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
           state.pop(f"{o_name}.bias")
       )

@@ -402,6 +404,7 @@ class ModelLoader:
     ):
       return
     prefix = additional_prefix + f"transformer_blocks.{idx}"
+    attn_config = config.block_config(idx).attn_config
     q_name = names.cross_attn_query_proj.format(idx)
     k_name = names.cross_attn_key_proj.format(idx)
     v_name = names.cross_attn_value_proj.format(idx)

@@ -409,16 +412,16 @@ class ModelLoader:
     if fuse_attention:
       converted_state[f"{prefix}.cross_atten_func.attn.weight"] = (
           self._fuse_qkv(
-              config,
+              attn_config,
              state.pop(f"{q_name}.weight"),
              state.pop(f"{k_name}.weight"),
              state.pop(f"{v_name}.weight"),
          )
      )
-      if config.attn_config.qkv_use_bias:
+      if attn_config.qkv_use_bias:
         converted_state[f"{prefix}.cross_atten_func.attn.bias"] = (
             self._fuse_qkv(
-                config,
+                attn_config,
                state.pop(f"{q_name}.bias"),
                state.pop(f"{k_name}.bias"),
                state.pop(f"{v_name}.bias"),

@@ -434,7 +437,7 @@ class ModelLoader:
       converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = (
           state.pop(f"{v_name}.weight")
       )
-      if config.attn_config.qkv_use_bias:
+      if attn_config.qkv_use_bias:
        converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = (
            state.pop(f"{q_name}.bias")
        )

@@ -449,7 +452,7 @@ class ModelLoader:
     converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = (
         state.pop(f"{o_name}.weight")
     )
-    if config.attn_config.output_proj_use_bias:
+    if attn_config.output_proj_use_bias:
       converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = (
           state.pop(f"{o_name}.bias")
       )

@@ -496,16 +499,14 @@ class ModelLoader:
 
   def _fuse_qkv(
       self,
-      config: model_config.ModelConfig,
+      attn_config: model_config.AttentionConfig,
       q: torch.Tensor,
       k: torch.Tensor,
       v: torch.Tensor,
   ) -> torch.Tensor:
-    q_per_kv = (
-        config.attn_config.num_heads // config.attn_config.num_query_groups
-    )
-    qs = torch.split(q, config.attn_config.head_dim * q_per_kv)
-    ks = torch.split(k, config.attn_config.head_dim)
-    vs = torch.split(v, config.attn_config.head_dim)
+    q_per_kv = attn_config.num_heads // attn_config.num_query_groups
+    qs = torch.split(q, attn_config.head_dim * q_per_kv)
+    ks = torch.split(k, attn_config.head_dim)
+    vs = torch.split(v, attn_config.head_dim)
     cycled = [t for group in zip(qs, ks, vs) for t in group]
     return torch.cat(cycled)

ai_edge_torch/odml_torch/lowerings/__init__.py

@@ -16,6 +16,7 @@ from . import _basic
 from . import _batch_norm
 from . import _convolution
 from . import _jax_lowerings
+from . import _layer_norm
 from . import context
 from . import registry
 from . import utils
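The lowerings package now imports the new `_layer_norm` module (+78 lines per the file list above); only the import shows in this hunk, so which op it lowers is an assumption, though `aten.native_layer_norm` is the natural candidate. For reference, the semantics any layer-norm lowering must reproduce, checked against `torch.nn.functional.layer_norm`:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 4, 8)
    weight, bias, eps = torch.ones(8), torch.zeros(8), 1e-5

    # Normalize over the trailing dimension, then apply affine parameters.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    manual = (x - mean) / torch.sqrt(var + eps) * weight + bias

    assert torch.allclose(manual, F.layer_norm(x, (8,), weight, bias, eps), atol=1e-6)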