lalamo 0.2.1__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {lalamo-0.2.1 → lalamo-0.2.3}/PKG-INFO +1 -1
  2. {lalamo-0.2.1 → lalamo-0.2.3}/lalamo/__init__.py +1 -1
  3. lalamo-0.2.3/lalamo/model_import/__init__.py +8 -0
  4. lalamo-0.2.3/lalamo/model_import/common.py +111 -0
  5. lalamo-0.2.3/lalamo/model_import/configs/__init__.py +24 -0
  6. lalamo-0.2.3/lalamo/model_import/configs/common.py +62 -0
  7. lalamo-0.2.3/lalamo/model_import/configs/executorch.py +166 -0
  8. lalamo-0.2.3/lalamo/model_import/configs/huggingface/__init__.py +18 -0
  9. lalamo-0.2.3/lalamo/model_import/configs/huggingface/common.py +72 -0
  10. lalamo-0.2.3/lalamo/model_import/configs/huggingface/gemma2.py +122 -0
  11. lalamo-0.2.3/lalamo/model_import/configs/huggingface/gemma3.py +187 -0
  12. lalamo-0.2.3/lalamo/model_import/configs/huggingface/llama.py +155 -0
  13. lalamo-0.2.3/lalamo/model_import/configs/huggingface/mistral.py +132 -0
  14. lalamo-0.2.3/lalamo/model_import/configs/huggingface/qwen2.py +144 -0
  15. lalamo-0.2.3/lalamo/model_import/configs/huggingface/qwen3.py +142 -0
  16. lalamo-0.2.3/lalamo/model_import/loaders/__init__.py +7 -0
  17. lalamo-0.2.3/lalamo/model_import/loaders/common.py +45 -0
  18. lalamo-0.2.3/lalamo/model_import/loaders/executorch.py +223 -0
  19. lalamo-0.2.3/lalamo/model_import/loaders/huggingface.py +304 -0
  20. lalamo-0.2.3/lalamo/model_import/model_specs/__init__.py +38 -0
  21. lalamo-0.2.3/lalamo/model_import/model_specs/common.py +118 -0
  22. lalamo-0.2.3/lalamo/model_import/model_specs/deepseek.py +28 -0
  23. lalamo-0.2.3/lalamo/model_import/model_specs/gemma.py +76 -0
  24. lalamo-0.2.3/lalamo/model_import/model_specs/huggingface.py +28 -0
  25. lalamo-0.2.3/lalamo/model_import/model_specs/llama.py +100 -0
  26. lalamo-0.2.3/lalamo/model_import/model_specs/mistral.py +59 -0
  27. lalamo-0.2.3/lalamo/model_import/model_specs/pleias.py +28 -0
  28. lalamo-0.2.3/lalamo/model_import/model_specs/polaris.py +22 -0
  29. lalamo-0.2.3/lalamo/model_import/model_specs/qwen.py +336 -0
  30. lalamo-0.2.3/lalamo/model_import/model_specs/reka.py +28 -0
  31. lalamo-0.2.3/lalamo/modules/__init__.py +85 -0
  32. lalamo-0.2.3/lalamo/modules/activations.py +30 -0
  33. lalamo-0.2.3/lalamo/modules/attention.py +326 -0
  34. lalamo-0.2.3/lalamo/modules/common.py +133 -0
  35. lalamo-0.2.3/lalamo/modules/decoder.py +244 -0
  36. lalamo-0.2.3/lalamo/modules/decoder_layer.py +240 -0
  37. lalamo-0.2.3/lalamo/modules/embedding.py +299 -0
  38. lalamo-0.2.3/lalamo/modules/kv_cache.py +196 -0
  39. lalamo-0.2.3/lalamo/modules/linear.py +603 -0
  40. lalamo-0.2.3/lalamo/modules/mlp.py +79 -0
  41. lalamo-0.2.3/lalamo/modules/normalization.py +77 -0
  42. lalamo-0.2.3/lalamo/modules/rope.py +255 -0
  43. lalamo-0.2.3/lalamo/modules/utils.py +13 -0
  44. {lalamo-0.2.1 → lalamo-0.2.3}/lalamo.egg-info/PKG-INFO +1 -1
  45. lalamo-0.2.3/lalamo.egg-info/SOURCES.txt +58 -0
  46. {lalamo-0.2.1 → lalamo-0.2.3}/pyproject.toml +4 -2
  47. lalamo-0.2.1/lalamo.egg-info/SOURCES.txt +0 -17
  48. {lalamo-0.2.1 → lalamo-0.2.3}/LICENSE +0 -0
  49. {lalamo-0.2.1 → lalamo-0.2.3}/README.md +0 -0
  50. {lalamo-0.2.1 → lalamo-0.2.3}/lalamo/common.py +0 -0
  51. {lalamo-0.2.1 → lalamo-0.2.3}/lalamo/language_model.py +0 -0
  52. {lalamo-0.2.1 → lalamo-0.2.3}/lalamo/main.py +0 -0
  53. {lalamo-0.2.1 → lalamo-0.2.3}/lalamo/quantization.py +0 -0
  54. {lalamo-0.2.1 → lalamo-0.2.3}/lalamo/utils.py +0 -0
  55. {lalamo-0.2.1 → lalamo-0.2.3}/lalamo.egg-info/dependency_links.txt +0 -0
  56. {lalamo-0.2.1 → lalamo-0.2.3}/lalamo.egg-info/entry_points.txt +0 -0
  57. {lalamo-0.2.1 → lalamo-0.2.3}/lalamo.egg-info/requires.txt +0 -0
  58. {lalamo-0.2.1 → lalamo-0.2.3}/lalamo.egg-info/top_level.txt +0 -0
  59. {lalamo-0.2.1 → lalamo-0.2.3}/setup.cfg +0 -0
  60. {lalamo-0.2.1 → lalamo-0.2.3}/tests/test_generation.py +0 -0
  61. {lalamo-0.2.1 → lalamo-0.2.3}/tests/test_huggingface_models.py +0 -0
{lalamo-0.2.1 → lalamo-0.2.3}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lalamo
- Version: 0.2.1
+ Version: 0.2.3
  Summary: JAX library for optimization and export of models for use with the UZU inference engine.
  Requires-Python: <4,>=3.12
  Description-Content-Type: text/markdown
{lalamo-0.2.1 → lalamo-0.2.3}/lalamo/__init__.py
@@ -1,7 +1,7 @@
  from lalamo.model_import import REPO_TO_MODEL, ModelSpec, import_model
  from lalamo.modules import Decoder

- __version__ = "0.2.1"
+ __version__ = "0.2.3"

  __all__ = [
      "REPO_TO_MODEL",
lalamo-0.2.3/lalamo/model_import/__init__.py
@@ -0,0 +1,8 @@
+ from .common import REPO_TO_MODEL, ModelMetadata, ModelSpec, import_model
+
+ __all__ = [
+     "REPO_TO_MODEL",
+     "ModelMetadata",
+     "ModelSpec",
+     "import_model",
+ ]
lalamo-0.2.3/lalamo/model_import/common.py
@@ -0,0 +1,111 @@
+ import importlib.metadata
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import NamedTuple
+
+ import huggingface_hub
+ import jax.numpy as jnp
+ from jaxtyping import DTypeLike
+
+ from lalamo.modules import Decoder, DecoderConfig
+ from lalamo.quantization import QuantizationMode
+
+ from .model_specs import REPO_TO_MODEL, ModelSpec, UseCase
+
+ __all__ = [
+     "REPO_TO_MODEL",
+     "ModelMetadata",
+     "ModelSpec",
+     "import_model",
+ ]
+
+
+ LALAMO_VERSION = importlib.metadata.version("lalamo")
+
+
+ @dataclass(frozen=True)
+ class ModelMetadata:
+     toolchain_version: str
+     vendor: str
+     family: str
+     name: str
+     size: str
+     quantization: QuantizationMode | None
+     repo: str
+     use_cases: tuple[UseCase, ...]
+     model_config: DecoderConfig
+     tokenizer_file_names: tuple[str, ...]
+
+
+ def download_weights(model_spec: ModelSpec, output_dir: Path | str | None = None) -> list[Path]:
+     result = [
+         huggingface_hub.hf_hub_download(
+             repo_id=model_spec.repo,
+             local_dir=output_dir,
+             filename=filename,
+         )
+         for filename in model_spec.weights_file_names
+     ]
+     return [Path(path) for path in result]
+
+
+ def download_config_file(model_spec: ModelSpec, output_dir: Path | str | None = None) -> Path:
+     result = huggingface_hub.hf_hub_download(
+         repo_id=model_spec.repo,
+         local_dir=output_dir,
+         filename=model_spec.config_file_name,
+     )
+     return Path(result)
+
+
+ def download_tokenizer_files(model_spec: ModelSpec, output_dir: Path | str | None = None) -> tuple[Path, ...]:
+     result = [
+         huggingface_hub.hf_hub_download(
+             repo_id=tokenizer_file_spec.repo or model_spec.repo,
+             local_dir=output_dir,
+             filename=tokenizer_file_spec.filename,
+         )
+         for tokenizer_file_spec in model_spec.tokenizer_files
+     ]
+     return tuple(Path(path) for path in result)
+
+
+ class ImportResults(NamedTuple):
+     model: Decoder
+     metadata: ModelMetadata
+     tokenizer_file_paths: tuple[Path, ...]
+
+
+ def import_model(
+     model_spec: ModelSpec,
+     *,
+     context_length: int | None = None,
+     precision: DTypeLike | None = None,
+     accumulation_precision: DTypeLike = jnp.float32,
+ ) -> ImportResults:
+     foreign_config_file = download_config_file(model_spec)
+     foreign_config = model_spec.config_type.from_json(foreign_config_file)
+
+     tokenizer_file_paths = download_tokenizer_files(model_spec)
+     if precision is None:
+         precision = foreign_config.default_precision
+
+     weights_paths = download_weights(model_spec)
+     weights_dict = {}
+     for weights_path in weights_paths:
+         weights_dict.update(model_spec.weights_type.load(weights_path, precision))
+
+     model = foreign_config.load_model(context_length, precision, accumulation_precision, weights_dict)
+     metadata = ModelMetadata(
+         toolchain_version=LALAMO_VERSION,
+         vendor=model_spec.vendor,
+         family=model_spec.family,
+         name=model_spec.name,
+         size=model_spec.size,
+         quantization=model_spec.quantization,
+         repo=model_spec.repo,
+         use_cases=model_spec.use_cases,
+         model_config=model.config,
+         tokenizer_file_names=tuple(p.name for p in tokenizer_file_paths),
+     )
+     return ImportResults(model, metadata, tokenizer_file_paths)
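For orientation, the new import_model entry point is driven by the REPO_TO_MODEL registry re-exported above. A minimal usage sketch, using a hypothetical registry key (the actual keys are defined in the model_specs modules added in this release):

from lalamo.model_import import REPO_TO_MODEL, import_model

# Hypothetical repo id, for illustration only; look up real keys in
# lalamo/model_import/model_specs/.
spec = REPO_TO_MODEL["meta-llama/Llama-3.2-1B-Instruct"]

# Downloads the config, tokenizer files, and weights, then builds a Decoder.
model, metadata, tokenizer_paths = import_model(spec, context_length=4096)
print(metadata.name, metadata.quantization)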
lalamo-0.2.3/lalamo/model_import/configs/__init__.py
@@ -0,0 +1,24 @@
+ from .common import ForeignConfig
+
+ # from .executorch import ETLlamaConfig
+ from .huggingface import (
+     HFGemma2Config,
+     HFGemma3Config,
+     HFGemma3TextConfig,
+     HFLlamaConfig,
+     HFMistralConfig,
+     HFQwen2Config,
+     HFQwen3Config,
+ )
+
+ __all__ = [
+     # "ETLlamaConfig",
+     "ForeignConfig",
+     "HFGemma2Config",
+     "HFGemma3Config",
+     "HFGemma3TextConfig",
+     "HFLlamaConfig",
+     "HFMistralConfig",
+     "HFQwen2Config",
+     "HFQwen3Config",
+ ]
lalamo-0.2.3/lalamo/model_import/configs/common.py
@@ -0,0 +1,62 @@
+ import json
+ from abc import abstractmethod
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import ClassVar, Self
+
+ import cattrs
+ import jax
+ from jaxtyping import Array, DTypeLike
+
+ from lalamo.modules import Decoder, DecoderConfig
+
+ __all__ = ["ForeignConfig"]
+
+
+ @dataclass(frozen=True)
+ class ForeignConfig:
+     _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+     _converter.register_structure_hook(int | list[int], lambda v, _: v)
+
+     @property
+     @abstractmethod
+     def default_precision(self) -> DTypeLike: ...
+
+     @classmethod
+     def from_json(cls, json_path: Path | str) -> Self:
+         json_path = Path(json_path)
+         with open(json_path) as f:
+             config = json.load(f)
+         return cls._converter.structure(config, cls)
+
+     def to_json(self, json_path: Path | str) -> None:
+         json_path = Path(json_path)
+         with open(json_path, "w") as f:
+             json.dump(self._converter.unstructure(self), f, indent=2)
+
+     def to_decoder_config(
+         self,
+         context_length: int | None,
+         activation_precision: DTypeLike,
+         accumulation_precision: DTypeLike,
+     ) -> DecoderConfig:
+         raise NotImplementedError
+
+     @classmethod
+     def _load_weights(
+         cls,
+         model: Decoder,
+         weights_dict: dict[str, Array],
+     ) -> Decoder:
+         raise NotImplementedError
+
+     def load_model(
+         self,
+         context_length: int | None,
+         activation_precision: DTypeLike,
+         accumulation_precision: DTypeLike,
+         weights_dict: dict[str, Array],
+     ) -> Decoder:
+         config = self.to_decoder_config(context_length, activation_precision, accumulation_precision)
+         model = config.random_init(key=jax.random.PRNGKey(0))
+         return self._load_weights(model, weights_dict)
lalamo-0.2.3/lalamo/model_import/configs/executorch.py
@@ -0,0 +1,166 @@
+ from dataclasses import dataclass
+
+ import jax.numpy as jnp
+ from jaxtyping import Array, DTypeLike
+
+ from lalamo.model_import.loaders.executorch import load_executorch
+ from lalamo.modules import (
+     Activation,
+     AttentionConfig,
+     Decoder,
+     DecoderConfig,
+     DecoderLayerConfig,
+     LlamaRoPEConfig,
+     MLPConfig,
+     QLoRALinearConfig,
+     QuantizedTiedEmbeddingConfig,
+     RMSNormConfig,
+     UpcastMode,
+ )
+ from lalamo.quantization import QuantizationMode
+
+ from .common import ForeignConfig
+
+ __all__ = ["ETLlamaConfig"]
+
+
+ # These parameters are not present in the config file, and are extracted from the executorch implementation
+ LOW_FREQ_FACTOR = 1.0
+ HIGH_FREQ_FACTOR = 4.0
+ OLD_CONTEXT_LENGTH = 8192
+ MAX_SEQUENCE_LENGTH = 8192 * 32
+
+ ROPE_SCALING_FACTOR = 32.0
+
+ EMBEDDING_QUANTIZATION_MODE = QuantizationMode.INT8
+ ACTIVATION_QUANTIZATION_MODE = QuantizationMode.INT8
+ WEIGHT_QUANTIZATION_MODE = QuantizationMode.UINT4
+
+
+ @dataclass(frozen=True)
+ class QuantizationConfig:
+     group_size: int
+
+
+ @dataclass(frozen=True)
+ class LoraConfig:
+     rank: int
+     scale: float
+
+
+ @dataclass(frozen=True)
+ class ExecutorchConfig(ForeignConfig):
+     @property
+     def default_precision(self) -> DTypeLike:
+         return jnp.bfloat16
+
+     @classmethod
+     def _load_weights(
+         cls,
+         model: Decoder,
+         weights_dict: dict[str, Array],
+     ) -> Decoder:
+         return load_executorch(model, weights_dict)
+
+
+ @dataclass(frozen=True)
+ class ETLlamaConfig(ExecutorchConfig):
+     dim: int
+     n_layers: int
+     n_heads: int
+     n_kv_heads: int
+     vocab_size: int
+     ffn_dim_multiplier: float
+     multiple_of: int
+     norm_eps: float
+     rope_theta: float
+     use_scaled_rope: bool
+     quantization_args: QuantizationConfig | None = None
+     lora_args: LoraConfig | None = None
+
+     def _find_hidden_size(self) -> int:
+         # Magic formula from executorch
+         size_candidate = int(8 / 3 * self.dim * self.ffn_dim_multiplier)
+         return size_candidate // self.multiple_of * self.multiple_of
+
+     def to_decoder_config(
+         self,
+         context_length: int | None,
+         activation_precision: DTypeLike,
+         accumulation_precision: DTypeLike,
+     ) -> DecoderConfig:
+         if self.lora_args is None:
+             raise ValueError("We only support QLoRA models for now.")
+
+         if self.quantization_args is None:
+             raise ValueError("Quantization arguments are required for QLoRA models.")
+
+         embedding_config = QuantizedTiedEmbeddingConfig(
+             input_scale=None,
+             logits_soft_cap=None,
+             embedding_quantization_mode=EMBEDDING_QUANTIZATION_MODE,
+             activation_quantization_mode=ACTIVATION_QUANTIZATION_MODE,
+             activation_precision=activation_precision,
+         )
+         rope_config = LlamaRoPEConfig(
+             precision=activation_precision,
+             base=self.rope_theta,
+             max_sequence_length=MAX_SEQUENCE_LENGTH,
+             scaling_factor=ROPE_SCALING_FACTOR,
+             original_context_length=OLD_CONTEXT_LENGTH,
+             low_frequency_factor=LOW_FREQ_FACTOR,
+             high_frequency_factor=HIGH_FREQ_FACTOR,
+         )
+         rmsnorm_config = RMSNormConfig(
+             scale_precision=activation_precision,
+             accumulation_precision=accumulation_precision,
+             epsilon=self.norm_eps,
+             scale_offset=None,
+             upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+         )
+         linear_config = QLoRALinearConfig(
+             group_size=self.quantization_args.group_size,
+             weight_quantization_mode=WEIGHT_QUANTIZATION_MODE,
+             activation_quantization_mode=ACTIVATION_QUANTIZATION_MODE,
+             activation_precision=activation_precision,
+             lora_rank=self.lora_args.rank,
+             lora_scale=self.lora_args.scale,
+         )
+         attention_config = AttentionConfig(
+             qkv_projection_config=linear_config,
+             out_projection_config=linear_config,
+             query_norm_config=None,
+             key_norm_config=None,
+             logit_soft_cap=None,
+             has_qkv_biases=False,
+             has_out_biases=False,
+         )
+         mlp_config = MLPConfig(
+             linear_config=linear_config,
+             activation=Activation.SILU,
+         )
+         decoder_layer_config = DecoderLayerConfig(
+             pre_attention_norm_config=rmsnorm_config,
+             attention_config=attention_config,
+             post_attention_norm_config=None,
+             pre_mlp_norm_config=rmsnorm_config,
+             mlp_config=mlp_config,
+             post_mlp_norm_config=None,
+         )
+         return DecoderConfig(
+             embedding_config=embedding_config,
+             global_rope_config=rope_config,
+             local_rope_config=None,
+             layer_config=decoder_layer_config,
+             output_norm_config=rmsnorm_config,
+             vocab_size=self.vocab_size,
+             model_dim=self.dim,
+             hidden_dim=self._find_hidden_size(),
+             num_heads=self.n_heads,
+             num_groups=self.n_kv_heads,
+             head_dim=self.dim // self.n_heads,
+             attention_scale=None,
+             num_layers=self.n_layers,
+             sliding_window_sizes=None,
+             context_length=context_length or MAX_SEQUENCE_LENGTH,
+         )
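The _find_hidden_size heuristic above reproduces executorch's FFN sizing rule: scale the model dimension by 8/3 and the multiplier, then round down to a multiple of multiple_of. A worked example with illustrative values (the real ones come from the imported config file):

# Illustrative values only, not taken from any shipped checkpoint.
dim, ffn_dim_multiplier, multiple_of = 2048, 1.5, 256

size_candidate = int(8 / 3 * dim * ffn_dim_multiplier)     # int(8192.0) -> 8192
hidden_size = size_candidate // multiple_of * multiple_of  # rounds down to a multiple of 256 -> 8192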
lalamo-0.2.3/lalamo/model_import/configs/huggingface/__init__.py
@@ -0,0 +1,18 @@
+ from .common import HuggingFaceConfig
+ from .gemma2 import HFGemma2Config
+ from .gemma3 import HFGemma3Config, HFGemma3TextConfig
+ from .llama import HFLlamaConfig
+ from .mistral import HFMistralConfig
+ from .qwen2 import HFQwen2Config
+ from .qwen3 import HFQwen3Config
+
+ __all__ = [
+     "HFGemma2Config",
+     "HFGemma3Config",
+     "HFGemma3TextConfig",
+     "HFLlamaConfig",
+     "HFMistralConfig",
+     "HFQwen2Config",
+     "HFQwen3Config",
+     "HuggingFaceConfig",
+ ]
lalamo-0.2.3/lalamo/model_import/configs/huggingface/common.py
@@ -0,0 +1,72 @@
+ from dataclasses import dataclass
+ from typing import Literal
+
+ import jax.numpy as jnp
+ from jaxtyping import Array, DTypeLike
+
+ from lalamo.model_import.configs import ForeignConfig
+ from lalamo.model_import.loaders import load_huggingface
+ from lalamo.modules import Decoder
+
+ __all__ = [
+     "HuggingFaceConfig",
+     "AWQQuantizationConfig",
+     "GPTQMetaConfig",
+     "GPTQQuantizationConfig"
+ ]
+
+
+ @dataclass(frozen=True)
+ class AWQQuantizationConfig:
+     backend: Literal["autoawq"] = "autoawq"
+     bits: Literal[4, 8] = 4
+     do_fuse: Literal[False] = False
+     exllama_config: None = None
+     fuse_max_seq_len: None = None
+     group_size: int = 128
+     modules_to_fuse: None = None
+     modules_to_not_convert: None = None
+     quant_method: Literal["awq"] = "awq"
+     version: Literal["gemm"] = "gemm"
+     zero_point: bool = True
+
+
+ @dataclass(frozen=True)
+ class GPTQMetaConfig:
+     damp_auto_increment: float
+     damp_percent: float
+     mse: float
+     quantizer: list[str]
+     static_groups: bool
+     true_sequential: bool
+     uri: str
+
+
+ @dataclass(frozen=True)
+ class GPTQQuantizationConfig:
+     bits: int
+     checkpoint_format: str
+     desc_act: bool
+     group_size: int
+     lm_head: bool
+     meta: GPTQMetaConfig
+     pack_dtype: str
+     quant_method: Literal["gptq"]
+     sym: bool
+
+
+ @dataclass(frozen=True)
+ class HuggingFaceConfig(ForeignConfig):
+     torch_dtype: Literal["bfloat16", "float16", "float32"]
+
+     @property
+     def default_precision(self) -> DTypeLike:
+         return jnp.dtype(self.torch_dtype)
+
+     @classmethod
+     def _load_weights(
+         cls,
+         model: Decoder,
+         weights_dict: dict[str, Array],
+     ) -> Decoder:
+         return load_huggingface(model, weights_dict)
lalamo-0.2.3/lalamo/model_import/configs/huggingface/gemma2.py
@@ -0,0 +1,122 @@
+ from dataclasses import dataclass
+ from typing import Literal
+
+ from jaxtyping import DTypeLike
+
+ from lalamo.modules import (
+     Activation,
+     AttentionConfig,
+     DecoderConfig,
+     DecoderLayerConfig,
+     FullPrecisionLinearConfig,
+     MLPConfig,
+     RMSNormConfig,
+     TiedEmbeddingConfig,
+     UnscaledRoPEConfig,
+     UpcastMode,
+ )
+
+ from .common import HuggingFaceConfig
+
+ __all__ = ["HFGemma2Config"]
+
+
+ @dataclass(frozen=True)
+ class HFGemma2Config(HuggingFaceConfig):
+     architectures: list[Literal["Gemma2ForCausalLM"]]
+     attention_bias: bool
+     attention_dropout: float
+     attn_logit_softcapping: float
+     bos_token_id: int | list[int]
+     cache_implementation: Literal["hybrid"]
+     eos_token_id: int | list[int]
+     final_logit_softcapping: float
+     head_dim: int
+     hidden_act: Literal["gelu_pytorch_tanh"]
+     hidden_activation: Literal["gelu_pytorch_tanh"]
+     hidden_size: int
+     initializer_range: float
+     intermediate_size: int
+     max_position_embeddings: int
+     model_type: Literal["gemma2"]
+     num_attention_heads: int
+     num_hidden_layers: int
+     num_key_value_heads: int
+     pad_token_id: int
+     query_pre_attn_scalar: float
+     rms_norm_eps: float
+     rope_theta: float
+     sliding_window: int
+     transformers_version: str
+     use_cache: bool
+     vocab_size: int
+
+     def to_decoder_config(
+         self,
+         context_length: int | None,
+         activation_precision: DTypeLike,
+         accumulation_precision: DTypeLike,
+     ) -> DecoderConfig:
+         sliding_window_sizes = tuple(
+             self.sliding_window if not bool(i % 2) else None for i in range(self.num_hidden_layers)
+         )
+         embedding_input_scale = self.hidden_size**0.5
+         attention_scale = self.query_pre_attn_scalar**-0.5
+         embedding_config = TiedEmbeddingConfig(
+             input_scale=embedding_input_scale,
+             logits_soft_cap=self.final_logit_softcapping,
+             precision=activation_precision,
+         )
+         rope_config = UnscaledRoPEConfig(
+             precision=activation_precision,
+             base=self.rope_theta,
+             max_sequence_length=self.max_position_embeddings,
+         )
+         rmsnorm_config = RMSNormConfig(
+             scale_precision=activation_precision,
+             accumulation_precision=accumulation_precision,
+             epsilon=self.rms_norm_eps,
+             scale_offset=1.0,
+             upcast_mode=UpcastMode.FULL_LAYER,
+         )
+         linear_config = FullPrecisionLinearConfig(
+             precision=activation_precision,
+         )
+         attention_config = AttentionConfig(
+             qkv_projection_config=linear_config,
+             out_projection_config=linear_config,
+             query_norm_config=None,
+             key_norm_config=None,
+             logit_soft_cap=self.attn_logit_softcapping,
+             has_qkv_biases=self.attention_bias,
+             has_out_biases=False,
+         )
+         mlp_config = MLPConfig(
+             linear_config=linear_config,
+             activation=Activation.GELU,
+         )
+         decoder_layer_config = DecoderLayerConfig(
+             pre_attention_norm_config=rmsnorm_config,
+             attention_config=attention_config,
+             post_attention_norm_config=rmsnorm_config,
+             pre_mlp_norm_config=rmsnorm_config,
+             mlp_config=mlp_config,
+             post_mlp_norm_config=rmsnorm_config,
+         )
+         return DecoderConfig(
+             embedding_config=embedding_config,
+             global_rope_config=rope_config,
+             local_rope_config=None,
+             layer_config=decoder_layer_config,
+             output_norm_config=rmsnorm_config,
+             vocab_size=self.vocab_size,
+             model_dim=self.hidden_size,
+             hidden_dim=self.intermediate_size,
+             num_heads=self.num_attention_heads,
+             num_groups=self.num_key_value_heads,
+             head_dim=self.head_dim,
+             attention_scale=attention_scale,
+             num_layers=self.num_hidden_layers,
+             sliding_window_sizes=sliding_window_sizes,
+             context_length=context_length or self.max_position_embeddings,
+         )
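The sliding_window_sizes expression above encodes Gemma 2's alternating attention pattern: even-indexed layers attend within the sliding window, odd-indexed layers attend globally (None). A small illustration with made-up sizes:

# Illustrative numbers only; real values come from the Hugging Face config.
sliding_window, num_hidden_layers = 4096, 6

sliding_window_sizes = tuple(
    sliding_window if not bool(i % 2) else None for i in range(num_hidden_layers)
)
# -> (4096, None, 4096, None, 4096, None)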