lalamo-0.2.1-py3-none-any.whl → lalamo-0.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/model_import/__init__.py +8 -0
- lalamo/model_import/common.py +111 -0
- lalamo/model_import/configs/__init__.py +23 -0
- lalamo/model_import/configs/common.py +62 -0
- lalamo/model_import/configs/executorch.py +166 -0
- lalamo/model_import/configs/huggingface/__init__.py +18 -0
- lalamo/model_import/configs/huggingface/common.py +72 -0
- lalamo/model_import/configs/huggingface/gemma2.py +122 -0
- lalamo/model_import/configs/huggingface/gemma3.py +187 -0
- lalamo/model_import/configs/huggingface/llama.py +155 -0
- lalamo/model_import/configs/huggingface/mistral.py +132 -0
- lalamo/model_import/configs/huggingface/qwen2.py +144 -0
- lalamo/model_import/configs/huggingface/qwen3.py +142 -0
- lalamo/model_import/loaders/__init__.py +7 -0
- lalamo/model_import/loaders/common.py +45 -0
- lalamo/model_import/loaders/executorch.py +223 -0
- lalamo/model_import/loaders/huggingface.py +304 -0
- lalamo/model_import/model_specs/__init__.py +38 -0
- lalamo/model_import/model_specs/common.py +118 -0
- lalamo/model_import/model_specs/deepseek.py +28 -0
- lalamo/model_import/model_specs/gemma.py +76 -0
- lalamo/model_import/model_specs/huggingface.py +28 -0
- lalamo/model_import/model_specs/llama.py +101 -0
- lalamo/model_import/model_specs/mistral.py +59 -0
- lalamo/model_import/model_specs/pleias.py +28 -0
- lalamo/model_import/model_specs/polaris.py +22 -0
- lalamo/model_import/model_specs/qwen.py +336 -0
- lalamo/model_import/model_specs/reka.py +28 -0
- lalamo/modules/__init__.py +85 -0
- lalamo/modules/activations.py +30 -0
- lalamo/modules/attention.py +326 -0
- lalamo/modules/common.py +133 -0
- lalamo/modules/decoder.py +244 -0
- lalamo/modules/decoder_layer.py +240 -0
- lalamo/modules/embedding.py +299 -0
- lalamo/modules/kv_cache.py +196 -0
- lalamo/modules/linear.py +603 -0
- lalamo/modules/mlp.py +79 -0
- lalamo/modules/normalization.py +77 -0
- lalamo/modules/rope.py +255 -0
- lalamo/modules/utils.py +13 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/METADATA +1 -1
- lalamo-0.2.2.dist-info/RECORD +53 -0
- lalamo-0.2.1.dist-info/RECORD +0 -12
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/WHEEL +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/entry_points.txt +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.2.1.dist-info → lalamo-0.2.2.dist-info}/top_level.txt +0 -0
lalamo/__init__.py
CHANGED
lalamo/model_import/common.py
ADDED
@@ -0,0 +1,111 @@
+import importlib.metadata
+from dataclasses import dataclass
+from pathlib import Path
+from typing import NamedTuple
+
+import huggingface_hub
+import jax.numpy as jnp
+from jaxtyping import DTypeLike
+
+from lalamo.modules import Decoder, DecoderConfig
+from lalamo.quantization import QuantizationMode
+
+from .model_specs import REPO_TO_MODEL, ModelSpec, UseCase
+
+__all__ = [
+    "REPO_TO_MODEL",
+    "ModelMetadata",
+    "ModelSpec",
+    "import_model",
+]
+
+
+LALAMO_VERSION = importlib.metadata.version("lalamo")
+
+
+@dataclass(frozen=True)
+class ModelMetadata:
+    toolchain_version: str
+    vendor: str
+    family: str
+    name: str
+    size: str
+    quantization: QuantizationMode | None
+    repo: str
+    use_cases: tuple[UseCase, ...]
+    model_config: DecoderConfig
+    tokenizer_file_names: tuple[str, ...]
+
+
+def download_weights(model_spec: ModelSpec, output_dir: Path | str | None = None) -> list[Path]:
+    result = [
+        huggingface_hub.hf_hub_download(
+            repo_id=model_spec.repo,
+            local_dir=output_dir,
+            filename=filename,
+        )
+        for filename in model_spec.weights_file_names
+    ]
+    return [Path(path) for path in result]
+
+
+def download_config_file(model_spec: ModelSpec, output_dir: Path | str | None = None) -> Path:
+    result = huggingface_hub.hf_hub_download(
+        repo_id=model_spec.repo,
+        local_dir=output_dir,
+        filename=model_spec.config_file_name,
+    )
+    return Path(result)
+
+
+def download_tokenizer_files(model_spec: ModelSpec, output_dir: Path | str | None = None) -> tuple[Path, ...]:
+    result = [
+        huggingface_hub.hf_hub_download(
+            repo_id=tokenizer_file_spec.repo or model_spec.repo,
+            local_dir=output_dir,
+            filename=tokenizer_file_spec.filename,
+        )
+        for tokenizer_file_spec in model_spec.tokenizer_files
+    ]
+    return tuple(Path(path) for path in result)
+
+
+class ImportResults(NamedTuple):
+    model: Decoder
+    metadata: ModelMetadata
+    tokenizer_file_paths: tuple[Path, ...]
+
+
+def import_model(
+    model_spec: ModelSpec,
+    *,
+    context_length: int | None = None,
+    precision: DTypeLike | None = None,
+    accumulation_precision: DTypeLike = jnp.float32,
+) -> ImportResults:
+    foreign_config_file = download_config_file(model_spec)
+    foreign_config = model_spec.config_type.from_json(foreign_config_file)
+
+    tokenizer_file_paths = download_tokenizer_files(model_spec)
+    if precision is None:
+        precision = foreign_config.default_precision
+
+    weights_paths = download_weights(model_spec)
+    weights_dict = {}
+    for weights_path in weights_paths:
+        weights_dict.update(model_spec.weights_type.load(weights_path, precision))
+
+    model = foreign_config.load_model(context_length, precision, accumulation_precision, weights_dict)
+    metadata = ModelMetadata(
+        toolchain_version=LALAMO_VERSION,
+        vendor=model_spec.vendor,
+        family=model_spec.family,
+        name=model_spec.name,
+        size=model_spec.size,
+        quantization=model_spec.quantization,
+        repo=model_spec.repo,
+        use_cases=model_spec.use_cases,
+        model_config=model.config,
+        tokenizer_file_names=tuple(p.name for p in tokenizer_file_paths),
+    )
+    return ImportResults(model, metadata, tokenizer_file_paths)
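The new import_model entry point ties the download helpers, config classes, and weight loaders together. A minimal usage sketch, assuming import_model and REPO_TO_MODEL are re-exported at the lalamo.model_import package level (as the __all__ above suggests) and using an illustrative repository key:

# Hypothetical usage sketch; the repo id key is illustrative.
from lalamo.model_import import REPO_TO_MODEL, import_model

# Look up a registered ModelSpec by its Hugging Face repo id.
spec = REPO_TO_MODEL["Qwen/Qwen2.5-0.5B-Instruct"]

# Downloads config, tokenizer files, and weights, then builds the Decoder.
model, metadata, tokenizer_paths = import_model(spec, context_length=4096)
print(metadata.toolchain_version, metadata.quantization, [p.name for p in tokenizer_paths])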
lalamo/model_import/configs/__init__.py
ADDED
@@ -0,0 +1,23 @@
+from .common import ForeignConfig
+from .executorch import ETLlamaConfig
+from .huggingface import (
+    HFGemma2Config,
+    HFGemma3Config,
+    HFGemma3TextConfig,
+    HFLlamaConfig,
+    HFMistralConfig,
+    HFQwen2Config,
+    HFQwen3Config,
+)
+
+__all__ = [
+    "ETLlamaConfig",
+    "ForeignConfig",
+    "HFGemma2Config",
+    "HFGemma3Config",
+    "HFGemma3TextConfig",
+    "HFLlamaConfig",
+    "HFMistralConfig",
+    "HFQwen2Config",
+    "HFQwen3Config",
+]
lalamo/model_import/configs/common.py
ADDED
@@ -0,0 +1,62 @@
+import json
+from abc import abstractmethod
+from dataclasses import dataclass
+from pathlib import Path
+from typing import ClassVar, Self
+
+import cattrs
+import jax
+from jaxtyping import Array, DTypeLike
+
+from lalamo.modules import Decoder, DecoderConfig
+
+__all__ = ["ForeignConfig"]
+
+
+@dataclass(frozen=True)
+class ForeignConfig:
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+    _converter.register_structure_hook(int | list[int], lambda v, _: v)
+
+    @property
+    @abstractmethod
+    def default_precision(self) -> DTypeLike: ...
+
+    @classmethod
+    def from_json(cls, json_path: Path | str) -> Self:
+        json_path = Path(json_path)
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls._converter.structure(config, cls)
+
+    def to_json(self, json_path: Path | str) -> None:
+        json_path = Path(json_path)
+        with open(json_path, "w") as f:
+            json.dump(self._converter.unstructure(self), f, indent=2)
+
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+    ) -> DecoderConfig:
+        raise NotImplementedError
+
+    @classmethod
+    def _load_weights(
+        cls,
+        model: Decoder,
+        weights_dict: dict[str, Array],
+    ) -> Decoder:
+        raise NotImplementedError
+
+    def load_model(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+        weights_dict: dict[str, Array],
+    ) -> Decoder:
+        config = self.to_decoder_config(context_length, activation_precision, accumulation_precision)
+        model = config.random_init(key=jax.random.PRNGKey(0))
+        return self._load_weights(model, weights_dict)
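ForeignConfig gives every vendor config the same interface: from_json structures the raw config.json into a typed dataclass, to_decoder_config translates it into lalamo's native DecoderConfig, and load_model initializes a Decoder and fills in converted weights. A sketch of the first two steps, assuming HFLlamaConfig (exported above) overrides to_decoder_config like the other configs in this diff; the file path is illustrative:

import jax.numpy as jnp

from lalamo.model_import.configs import HFLlamaConfig

# Parse a downloaded Hugging Face config.json into a typed, frozen dataclass.
config = HFLlamaConfig.from_json("config.json")

# Translate it into lalamo's native DecoderConfig with explicit precisions.
decoder_config = config.to_decoder_config(
    context_length=None,
    activation_precision=config.default_precision,
    accumulation_precision=jnp.float32,
)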
lalamo/model_import/configs/executorch.py
ADDED
@@ -0,0 +1,166 @@
+from dataclasses import dataclass
+
+import jax.numpy as jnp
+from jaxtyping import Array, DTypeLike
+
+from lalamo.model_import.loaders import load_executorch
+from lalamo.modules import (
+    Activation,
+    AttentionConfig,
+    Decoder,
+    DecoderConfig,
+    DecoderLayerConfig,
+    LlamaRoPEConfig,
+    MLPConfig,
+    QLoRALinearConfig,
+    QuantizedTiedEmbeddingConfig,
+    RMSNormConfig,
+    UpcastMode,
+)
+from lalamo.quantization import QuantizationMode
+
+from .common import ForeignConfig
+
+__all__ = ["ETLlamaConfig"]
+
+
+# These parameters are not present in the config file, and are extracted from the executorch implementation
+LOW_FREQ_FACTOR = 1.0
+HIGH_FREQ_FACTOR = 4.0
+OLD_CONTEXT_LENGTH = 8192
+MAX_SEQUENCE_LENGTH = 8192 * 32
+
+ROPE_SCALING_FACTOR = 32.0
+
+EMBEDDING_QUANTIZATION_MODE = QuantizationMode.INT8
+ACTIVATION_QUANTIZATION_MODE = QuantizationMode.INT8
+WEIGHT_QUANTIZATION_MODE = QuantizationMode.UINT4
+
+
+@dataclass(frozen=True)
+class QuantizationConfig:
+    group_size: int
+
+
+@dataclass(frozen=True)
+class LoraConfig:
+    rank: int
+    scale: float
+
+
+@dataclass(frozen=True)
+class ExecutorchConfig(ForeignConfig):
+    @property
+    def default_precision(self) -> DTypeLike:
+        return jnp.bfloat16
+
+    @classmethod
+    def _load_weights(
+        cls,
+        model: Decoder,
+        weights_dict: dict[str, Array],
+    ) -> Decoder:
+        return load_executorch(model, weights_dict)
+
+
+@dataclass(frozen=True)
+class ETLlamaConfig(ExecutorchConfig):
+    dim: int
+    n_layers: int
+    n_heads: int
+    n_kv_heads: int
+    vocab_size: int
+    ffn_dim_multiplier: float
+    multiple_of: int
+    norm_eps: float
+    rope_theta: float
+    use_scaled_rope: bool
+    quantization_args: QuantizationConfig | None = None
+    lora_args: LoraConfig | None = None
+
+    def _find_hidden_size(self) -> int:
+        # Magic formula from executorch
+        size_candidate = int(8 / 3 * self.dim * self.ffn_dim_multiplier)
+        return size_candidate // self.multiple_of * self.multiple_of
+
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+    ) -> DecoderConfig:
+        if self.lora_args is None:
+            raise ValueError("We only support QLoRA models for now.")
+
+        if self.quantization_args is None:
+            raise ValueError("Quantization arguments are required for QLoRA models.")
+
+        embedding_config = QuantizedTiedEmbeddingConfig(
+            input_scale=None,
+            logits_soft_cap=None,
+            embedding_quantization_mode=EMBEDDING_QUANTIZATION_MODE,
+            activation_quantization_mode=ACTIVATION_QUANTIZATION_MODE,
+            activation_precision=activation_precision,
+        )
+        rope_config = LlamaRoPEConfig(
+            precision=activation_precision,
+            base=self.rope_theta,
+            max_sequence_length=MAX_SEQUENCE_LENGTH,
+            scaling_factor=ROPE_SCALING_FACTOR,
+            original_context_length=OLD_CONTEXT_LENGTH,
+            low_frequency_factor=LOW_FREQ_FACTOR,
+            high_frequency_factor=HIGH_FREQ_FACTOR,
+        )
+        rmsnorm_config = RMSNormConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.norm_eps,
+            scale_offset=None,
+            upcast_mode=UpcastMode.ONLY_NORMALIZATION,
+        )
+        linear_config = QLoRALinearConfig(
+            group_size=self.quantization_args.group_size,
+            weight_quantization_mode=WEIGHT_QUANTIZATION_MODE,
+            activation_quantization_mode=ACTIVATION_QUANTIZATION_MODE,
+            activation_precision=activation_precision,
+            lora_rank=self.lora_args.rank,
+            lora_scale=self.lora_args.scale,
+        )
+        attention_config = AttentionConfig(
+            qkv_projection_config=linear_config,
+            out_projection_config=linear_config,
+            query_norm_config=None,
+            key_norm_config=None,
+            logit_soft_cap=None,
+            has_qkv_biases=False,
+            has_out_biases=False,
+        )
+        mlp_config = MLPConfig(
+            linear_config=linear_config,
+            activation=Activation.SILU,
+        )
+        decoder_layer_config = DecoderLayerConfig(
+            pre_attention_norm_config=rmsnorm_config,
+            attention_config=attention_config,
+            post_attention_norm_config=None,
+            pre_mlp_norm_config=rmsnorm_config,
+            mlp_config=mlp_config,
+            post_mlp_norm_config=None,
+        )
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            global_rope_config=rope_config,
+            local_rope_config=None,
+            layer_config=decoder_layer_config,
+            output_norm_config=rmsnorm_config,
+            vocab_size=self.vocab_size,
+            model_dim=self.dim,
+            hidden_dim=self._find_hidden_size(),
+            num_heads=self.n_heads,
+            num_groups=self.n_kv_heads,
+            head_dim=self.dim // self.n_heads,
+            attention_scale=None,
+            num_layers=self.n_layers,
+            sliding_window_sizes=None,
+            context_length=context_length or MAX_SEQUENCE_LENGTH,
+        )
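The MLP hidden size is not stored in the executorch config, so ETLlamaConfig derives it from dim, ffn_dim_multiplier, and multiple_of. A standalone sketch of that formula with illustrative numbers:

def find_hidden_size(dim: int, ffn_dim_multiplier: float, multiple_of: int) -> int:
    # Same "magic formula" as ETLlamaConfig._find_hidden_size: scale 8/3 * dim
    # by the multiplier, then round down to the nearest multiple of `multiple_of`.
    size_candidate = int(8 / 3 * dim * ffn_dim_multiplier)
    return size_candidate // multiple_of * multiple_of

# 8/3 * 4096 * 1.3 ≈ 14199, rounded down to a multiple of 256 -> 14080
print(find_hidden_size(dim=4096, ffn_dim_multiplier=1.3, multiple_of=256))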
lalamo/model_import/configs/huggingface/__init__.py
ADDED
@@ -0,0 +1,18 @@
+from .common import HuggingFaceConfig
+from .gemma2 import HFGemma2Config
+from .gemma3 import HFGemma3Config, HFGemma3TextConfig
+from .llama import HFLlamaConfig
+from .mistral import HFMistralConfig
+from .qwen2 import HFQwen2Config
+from .qwen3 import HFQwen3Config
+
+__all__ = [
+    "HFGemma2Config",
+    "HFGemma3Config",
+    "HFGemma3TextConfig",
+    "HFLlamaConfig",
+    "HFMistralConfig",
+    "HFQwen2Config",
+    "HFQwen3Config",
+    "HuggingFaceConfig",
+]
lalamo/model_import/configs/huggingface/common.py
ADDED
@@ -0,0 +1,72 @@
+from dataclasses import dataclass
+from typing import Literal
+
+import jax.numpy as jnp
+from jaxtyping import Array, DTypeLike
+
+from lalamo.model_import.configs import ForeignConfig
+from lalamo.model_import.loaders import load_huggingface
+from lalamo.modules import Decoder
+
+__all__ = [
+    "HuggingFaceConfig",
+    "AWQQuantizationConfig",
+    "GPTQMetaConfig",
+    "GPTQQuantizationConfig"
+]
+
+
+@dataclass(frozen=True)
+class AWQQuantizationConfig:
+    backend: Literal["autoawq"] = "autoawq"
+    bits: Literal[4, 8] = 4
+    do_fuse: Literal[False] = False
+    exllama_config: None = None
+    fuse_max_seq_len: None = None
+    group_size: int = 128
+    modules_to_fuse: None = None
+    modules_to_not_convert: None = None
+    quant_method: Literal["awq"] = "awq"
+    version: Literal["gemm"] = "gemm"
+    zero_point: bool = True
+
+
+@dataclass(frozen=True)
+class GPTQMetaConfig:
+    damp_auto_increment: float
+    damp_percent: float
+    mse: float
+    quantizer: list[str]
+    static_groups: bool
+    true_sequential: bool
+    uri: str
+
+
+@dataclass(frozen=True)
+class GPTQQuantizationConfig:
+    bits: int
+    checkpoint_format: str
+    desc_act: bool
+    group_size: int
+    lm_head: bool
+    meta: GPTQMetaConfig
+    pack_dtype: str
+    quant_method: Literal["gptq"]
+    sym: bool
+
+
+@dataclass(frozen=True)
+class HuggingFaceConfig(ForeignConfig):
+    torch_dtype: Literal["bfloat16", "float16", "float32"]
+
+    @property
+    def default_precision(self) -> DTypeLike:
+        return jnp.dtype(self.torch_dtype)
+
+    @classmethod
+    def _load_weights(
+        cls,
+        model: Decoder,
+        weights_dict: dict[str, Array],
+    ) -> Decoder:
+        return load_huggingface(model, weights_dict)
lalamo/model_import/configs/huggingface/gemma2.py
ADDED
@@ -0,0 +1,122 @@
+from dataclasses import dataclass
+from typing import Literal
+
+from jaxtyping import DTypeLike
+
+from lalamo.modules import (
+    Activation,
+    AttentionConfig,
+    DecoderConfig,
+    DecoderLayerConfig,
+    FullPrecisionLinearConfig,
+    MLPConfig,
+    RMSNormConfig,
+    TiedEmbeddingConfig,
+    UnscaledRoPEConfig,
+    UpcastMode,
+)
+
+from .common import HuggingFaceConfig
+
+__all__ = ["HFGemma2Config"]
+
+
+@dataclass(frozen=True)
+class HFGemma2Config(HuggingFaceConfig):
+    architectures: list[Literal["Gemma2ForCausalLM"]]
+    attention_bias: bool
+    attention_dropout: float
+    attn_logit_softcapping: float
+    bos_token_id: int | list[int]
+    cache_implementation: Literal["hybrid"]
+    eos_token_id: int | list[int]
+    final_logit_softcapping: float
+    head_dim: int
+    hidden_act: Literal["gelu_pytorch_tanh"]
+    hidden_activation: Literal["gelu_pytorch_tanh"]
+    hidden_size: int
+    initializer_range: float
+    intermediate_size: int
+    max_position_embeddings: int
+    model_type: Literal["gemma2"]
+    num_attention_heads: int
+    num_hidden_layers: int
+    num_key_value_heads: int
+    pad_token_id: int
+    query_pre_attn_scalar: float
+    rms_norm_eps: float
+    rope_theta: float
+    sliding_window: int
+    transformers_version: str
+    use_cache: bool
+    vocab_size: int
+
+    def to_decoder_config(
+        self,
+        context_length: int | None,
+        activation_precision: DTypeLike,
+        accumulation_precision: DTypeLike,
+    ) -> DecoderConfig:
+        sliding_window_sizes = tuple(
+            self.sliding_window if not bool(i % 2) else None for i in range(self.num_hidden_layers)
+        )
+        embedding_input_scale = self.hidden_size**0.5
+        attention_scale = self.query_pre_attn_scalar**-0.5
+        embedding_config = TiedEmbeddingConfig(
+            input_scale=embedding_input_scale,
+            logits_soft_cap=self.final_logit_softcapping,
+            precision=activation_precision,
+        )
+        rope_config = UnscaledRoPEConfig(
+            precision=activation_precision,
+            base=self.rope_theta,
+            max_sequence_length=self.max_position_embeddings,
+        )
+        rmsnorm_config = RMSNormConfig(
+            scale_precision=activation_precision,
+            accumulation_precision=accumulation_precision,
+            epsilon=self.rms_norm_eps,
+            scale_offset=1.0,
+            upcast_mode=UpcastMode.FULL_LAYER,
+        )
+        linear_config = FullPrecisionLinearConfig(
+            precision=activation_precision,
+        )
+        attention_config = AttentionConfig(
+            qkv_projection_config=linear_config,
+            out_projection_config=linear_config,
+            query_norm_config=None,
+            key_norm_config=None,
+            logit_soft_cap=self.attn_logit_softcapping,
+            has_qkv_biases=self.attention_bias,
+            has_out_biases=False,
+        )
+        mlp_config = MLPConfig(
+            linear_config=linear_config,
+            activation=Activation.GELU,
+        )
+        decoder_layer_config = DecoderLayerConfig(
+            pre_attention_norm_config=rmsnorm_config,
+            attention_config=attention_config,
+            post_attention_norm_config=rmsnorm_config,
+            pre_mlp_norm_config=rmsnorm_config,
+            mlp_config=mlp_config,
+            post_mlp_norm_config=rmsnorm_config,
+        )
+        return DecoderConfig(
+            embedding_config=embedding_config,
+            global_rope_config=rope_config,
+            local_rope_config=None,
+            layer_config=decoder_layer_config,
+            output_norm_config=rmsnorm_config,
+            vocab_size=self.vocab_size,
+            model_dim=self.hidden_size,
+            hidden_dim=self.intermediate_size,
+            num_heads=self.num_attention_heads,
+            num_groups=self.num_key_value_heads,
+            head_dim=self.head_dim,
+            attention_scale=attention_scale,
+            num_layers=self.num_hidden_layers,
+            sliding_window_sizes=sliding_window_sizes,
+            context_length=context_length or self.max_position_embeddings,
+        )