ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. ai_edge_torch/_convert/test/test_convert.py +35 -13
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  4. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  7. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  8. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  9. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  10. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  11. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  12. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  14. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  16. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  17. ai_edge_torch/generative/layers/attention.py +77 -73
  18. ai_edge_torch/generative/layers/builder.py +5 -3
  19. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  20. ai_edge_torch/generative/layers/model_config.py +38 -19
  21. ai_edge_torch/generative/layers/normalization.py +158 -0
  22. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  23. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  24. ai_edge_torch/generative/test/test_loader.py +1 -1
  25. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  26. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  27. ai_edge_torch/generative/test/utils.py +54 -0
  28. ai_edge_torch/generative/utilities/loader.py +15 -15
  29. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  30. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  31. ai_edge_torch/odml_torch/lowerings/_convolution.py +196 -74
  32. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
  33. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  34. ai_edge_torch/version.py +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  36. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +41 -47
  37. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  38. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  40. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  41. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  42. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  43. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  44. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  45. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  46. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  47. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  49. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  50. {ai_edge_torch_nightly-0.3.0.dev20240909.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
--- a/ai_edge_torch/_convert/test/test_convert.py
+++ b/ai_edge_torch/_convert/test/test_convert.py
@@ -25,6 +25,7 @@ from ai_edge_torch.testing import model_coverage
 import numpy as np
 import tensorflow as tf
 import torch
+from torch import nn
 import torchvision
 
 from absl.testing import absltest as googletest
@@ -51,7 +52,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_add(self):
     """Tests conversion of a simple Add module."""
 
-    class Add(torch.nn.Module):
+    class Add(nn.Module):
 
       def forward(self, a, b):
         return a + b
@@ -70,7 +71,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_dot_add(self):
     """Tests conversion of a matrix multiplication followed by an add."""
 
-    class DotAdd(torch.nn.Module):
+    class DotAdd(nn.Module):
 
       def forward(self, a, b, c):
         return a @ b + c
@@ -99,7 +100,7 @@ class TestConvert(googletest.TestCase):
   def test_signature_args_ordering(self):
     """Tests conversion of a model with more than 10 arguments."""
 
-    class AddChainWith11Args(torch.nn.Module):
+    class AddChainWith11Args(nn.Module):
       """A model with 11 arguments."""
 
       def forward(
@@ -152,7 +153,7 @@ class TestConvert(googletest.TestCase):
   def test_multi_output_model(self):
     """Tests conversion of a model that returns multiple outputs."""
 
-    class BasicAddModelWithMultipleOutputs(torch.nn.Module):
+    class BasicAddModelWithMultipleOutputs(nn.Module):
       """A model that returns multiple outputs."""
 
       def forward(self, arg0, arg1):
@@ -176,7 +177,7 @@ class TestConvert(googletest.TestCase):
   def test_12_outputs_model(self):
     """Tests conversion of a model that returns more than 10 outputs."""
 
-    class BasicAddModelWithMultipleOutputs(torch.nn.Module):
+    class BasicAddModelWithMultipleOutputs(nn.Module):
       """A model that returns multiple outputs."""
 
      def forward(self, arg0, arg1):
@@ -245,7 +246,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_add_converter_flags(self):
     """Tests conversion of an add module setting a tflite converter flag."""
 
-    class Add(torch.nn.Module):
+    class Add(nn.Module):
 
       def forward(self, a, b):
         return a + b
@@ -267,6 +268,27 @@ class TestConvert(googletest.TestCase):
     )
     self.assertTrue(os.path.isdir(ir_dump_path))
 
+  def test_convert_conv_transpose_batch_norm(self):
+    """Tests conversion of a model with ConvTranspose2d and BatchNorm2d."""
+
+    channels = 2
+    size = 2
+    torch_model = nn.Sequential(
+        nn.ConvTranspose2d(
+            channels, channels, 1, stride=2, dilation=1, bias=False
+        ),
+        nn.BatchNorm2d(channels),
+    )
+
+    torch_model.eval()
+    sample_input = (torch.rand(1, channels, size, size),)
+    edge_model = ai_edge_torch.convert(torch_model, sample_input)
+
+    result = model_coverage.compare_tflite_torch(
+        edge_model, torch_model, sample_input
+    )
+    self.assertTrue(result)
+
   @googletest.skipIf(
       not config.Config.use_torch_xla,
       reason="Shape polymorphism is not yet support with odml_torch.",
@@ -274,7 +296,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_with_dynamic_batch(self):
     """Test converting a simple model with dynamic batch size."""
 
-    class SampleModel(torch.nn.Module):
+    class SampleModel(nn.Module):
 
       def __init__(self):
         super().__init__()
@@ -304,7 +326,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_with_kwargs(self):
     """Test converting a simple model with sample_kwargs."""
 
-    class SampleModel(torch.nn.Module):
+    class SampleModel(nn.Module):
 
       def forward(self, x, y):
         return x + y
@@ -323,7 +345,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_with_args_kwargs(self):
     """Test converting a simple model with both sample_args and sample_kwargs."""
 
-    class SampleModel(torch.nn.Module):
+    class SampleModel(nn.Module):
 
       def forward(self, x, y):
         return x + y
@@ -343,7 +365,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_with_args_nested_kwargs_1(self):
     """Test converting a simple model with both sample_args and nested sample_kwargs."""
 
-    class SampleModel(torch.nn.Module):
+    class SampleModel(nn.Module):
 
       def forward(self, x: torch.Tensor, y: torch.Tensor, z: TestContainer1):
         return x + y + z.data_1 + z.data_2[0] + z.data_2[1]
@@ -370,7 +392,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_with_args_nested_kwargs_2(self):
     """Test converting a simple model with both sample_args and nested sample_kwargs."""
 
-    class SampleModel(torch.nn.Module):
+    class SampleModel(nn.Module):
 
       def forward(self, x, y, z):
         return x + y + z.data_1 + z.data_2[0][0] + z.data_2[1]
@@ -397,7 +419,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_with_args_nested_kwargs_3(self):
     """Test converting a simple model with both sample_args and nested sample_kwargs."""
 
-    class SampleModel(torch.nn.Module):
+    class SampleModel(nn.Module):
 
       def forward(self, x, y, z):
         return x + y + z.data_1 + z.data_2[0]["foo"] + z.data_2[1]
@@ -424,7 +446,7 @@ class TestConvert(googletest.TestCase):
   def test_convert_model_non_flat_output_dict(self):
     """Test converting a model with non-flat output structure."""
 
-    class SampleModel(torch.nn.Module):
+    class SampleModel(nn.Module):
 
       def forward(self, x, y, z):
         return {"x": x, "y": TestContainer1(data_1=y, data_2=[y, z])}
--- a/ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
+++ b/ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -13,32 +13,35 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting a Gemma2 model to multi-signature tflite model."""
+
 import os
-from pathlib import Path
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
 
-def convert_gemma_to_tflite(
+def convert_gemma2_to_tflite(
     checkpoint_path: str,
     prefill_seq_len: int = 512,
     kv_cache_max_len: int = 1024,
     quantize: bool = True,
 ):
-  """Converting a Gemma 2 2B model to multi-signature
-  tflite model.
+  """Converts a Gemma2 2B model to multi-signature tflite model.
 
   Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
     prefill_seq_len (int, optional): The maximum size of prefill input tensor.
       Defaults to 512.
     kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
       including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized.
-      Defaults to True.
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
   """
   pytorch_model = gemma2.build_2b_model(
       checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
-          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
       )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
       .convert(quant_config=quant_config)
   )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/gemma2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma2-2b')
-  convert_gemma_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
+  convert_gemma2_to_tflite(path)
--- a/ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
+"""Example of converting a Gemma model to multi-signature tflite model."""
+
 import os
-from pathlib import Path
+import pathlib
 
 import ai_edge_torch
 from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch
 
@@ -48,20 +51,36 @@ def convert_gemma_to_tflite(
   prefill_input_pos = torch.arange(0, prefill_seq_len)
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
      ai_edge_torch.signature(
-          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
      )
-      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
+  quant_suffix = 'q8' if quantize else 'f32'
   edge_model.export(
-      f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
   )
 
 
 if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma-2b')
-  convert_gemma_to_tflite(checkpoint_path)
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
+  convert_gemma_to_tflite(path)
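Note: both converter scripts now follow the same shape: a KVCache is built once from the model config and passed through sample_kwargs, so 'prefill' and 'decode' are exported as two signatures over the same weights, with the cache as an explicit input and output. A hedged sketch of sanity-checking the two entry points in PyTorch before converting, reusing the sample tensors defined in the scripts (the zero-valued decode position is a tracing placeholder, not a real generation schedule; the dict return value is shown in the gemma.py diff below):

    # Prefill: run the whole prompt against a fresh cache.
    out = pytorch_model(
        tokens=prefill_tokens, input_pos=prefill_input_pos, kv_cache=kv
    )
    # Decode: one step, resuming from the cache that prefill returned.
    out = pytorch_model(
        tokens=decode_token, input_pos=decode_input_pos, kv_cache=out['kv_cache']
    )
    print(out['logits'].shape)  # (1, 1, vocab_size) for a single decode step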
--- a/ai_edge_torch/generative/examples/gemma/gemma.py
+++ b/ai_edge_torch/generative/examples/gemma/gemma.py
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-# Example of building a Gemma model.
+
+"""Example of building a Gemma model."""
 
 import os
-from pathlib import Path
+import pathlib
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
@@ -48,7 +50,6 @@ class Gemma(nn.Module):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
 
-    self.config = config
     # Construct model layers.
     self.tok_embedding = nn.Embedding(
         config.vocab_size, config.embedding_dim, padding_idx=0
@@ -60,18 +61,20 @@ class Gemma(nn.Module):
     )
     # Gemma re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data
+    # Gemma has only one block config.
+    block_config = config.block_config(0)
     self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
         config.final_norm_config,
     )
+    attn_config = block_config.attn_config
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -84,16 +87,22 @@ class Gemma(nn.Module):
     )
     self.config = config
 
-  # The model's forward function takes in additional k/v cache tensors
-  # and returns the updated k/v cache tensors to the caller.
-  # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
-  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-    _, seq_len = idx.size()
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
     assert self.config.max_seq_len >= seq_len, (
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
 
     cos, sin = self.rope_cache
     cos = cos.index_select(0, input_pos)
@@ -102,15 +111,20 @@ class Gemma(nn.Module):
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
     # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(idx)
+    x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)
 
-    for _, block in enumerate(self.transformer_blocks):
-      x = block(x, (cos, sin), mask, input_pos)
+    updated_kv_entires = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -139,18 +153,20 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       epsilon=1e-6,
       zero_centered=True,
   )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=18,
       max_seq_len=8192,
       embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
+      block_configs=block_config,
       final_norm_config=norm_config,
-      parallel_residual=False,
       lm_head_use_bias=False,
       enable_hlfb=True,
   )
@@ -159,7 +175,8 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config = get_model_config_2b(kv_cache_max_len)
-  config.ff_config.intermediate_size = 128
+  # Gemma has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
   config.vocab_size = 128
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
@@ -170,32 +187,35 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config_2b(**kwargs)
   model = Gemma(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-  # since embedding and lm-head use the same weight, we need to set strict
+  # Since embedding and lm-head use the same weight, we need to set strict
   # to False.
   loader.load(model, strict=False)
   model.eval()
   return model
 
 
-def define_and_run_2b() -> None:
+def define_and_run_2b(checkpoint_path: str) -> None:
   """Instantiates and runs a Gemma 2B model."""
 
-  current_dir = Path(__file__).parent.resolve()
+  current_dir = pathlib.Path(__file__).parent.resolve()
   gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
 
   kv_cache_max_len = 1024
-  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma-2b")
   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
   tokens[0, :4] = idx
   input_pos = torch.arange(0, kv_cache_max_len)
-  lm_logits = model.forward(tokens, input_pos)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
   print("comparing with goldens..")
   assert torch.allclose(
-      gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+      gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
   )
 
 
 if __name__ == "__main__":
-  define_and_run_2b()
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
+  )
+  define_and_run_2b(input_checkpoint_path)
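Note: with the reworked forward, the KV cache round-trips through every call as a value rather than living in mutable module state, so generation becomes a pure loop over (logits, kv_cache) pairs. A minimal greedy-decoding sketch against the new interface (checkpoint location as used elsewhere in this diff; the prompt, step count, and argmax sampling are illustrative):

    import os
    import pathlib

    import torch
    from ai_edge_torch.generative.examples.gemma import gemma
    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    checkpoint = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma-2b")
    model = gemma.build_2b_model(checkpoint, kv_cache_max_len=1024)

    # Prefill the prompt against a fresh cache, then feed each returned
    # cache into the next decode step.
    kv = kv_utils.KVCache.from_model_config(model.config)
    tokens = torch.tensor([[1, 2, 3, 4]], dtype=torch.long)
    out = model.forward(tokens, torch.arange(0, 4), kv)
    pos = 4
    for _ in range(8):
      next_token = out["logits"][:, -1, :].argmax(dim=-1, keepdim=True)
      out = model.forward(next_token, torch.tensor([pos]), out["kv_cache"])
      pos += 1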