lalamo 0.2.7__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. lalamo/__init__.py +1 -1
  2. lalamo/common.py +79 -29
  3. lalamo/language_model.py +106 -83
  4. lalamo/main.py +91 -18
  5. lalamo/message_processor.py +170 -0
  6. lalamo/model_import/common.py +159 -43
  7. lalamo/model_import/{configs → decoder_configs}/__init__.py +0 -1
  8. lalamo/model_import/{configs → decoder_configs}/common.py +11 -10
  9. lalamo/model_import/{configs → decoder_configs}/huggingface/common.py +9 -4
  10. lalamo/model_import/{configs → decoder_configs}/huggingface/gemma3.py +2 -2
  11. lalamo/model_import/{configs → decoder_configs}/huggingface/llama.py +2 -2
  12. lalamo/model_import/{configs → decoder_configs}/huggingface/mistral.py +1 -1
  13. lalamo/model_import/{configs → decoder_configs}/huggingface/qwen2.py +1 -1
  14. lalamo/model_import/{configs → decoder_configs}/huggingface/qwen3.py +1 -1
  15. lalamo/model_import/huggingface_generation_config.py +44 -0
  16. lalamo/model_import/huggingface_tokenizer_config.py +85 -0
  17. lalamo/model_import/loaders/common.py +2 -1
  18. lalamo/model_import/loaders/huggingface.py +12 -10
  19. lalamo/model_import/model_specs/__init__.py +3 -2
  20. lalamo/model_import/model_specs/common.py +32 -34
  21. lalamo/model_import/model_specs/deepseek.py +1 -10
  22. lalamo/model_import/model_specs/gemma.py +2 -25
  23. lalamo/model_import/model_specs/huggingface.py +2 -12
  24. lalamo/model_import/model_specs/llama.py +2 -58
  25. lalamo/model_import/model_specs/mistral.py +9 -19
  26. lalamo/model_import/model_specs/pleias.py +3 -13
  27. lalamo/model_import/model_specs/polaris.py +5 -7
  28. lalamo/model_import/model_specs/qwen.py +12 -111
  29. lalamo/model_import/model_specs/reka.py +4 -13
  30. lalamo/modules/__init__.py +2 -1
  31. lalamo/modules/attention.py +90 -10
  32. lalamo/modules/common.py +51 -4
  33. lalamo/modules/decoder.py +90 -8
  34. lalamo/modules/decoder_layer.py +85 -8
  35. lalamo/modules/embedding.py +95 -29
  36. lalamo/modules/kv_cache.py +3 -3
  37. lalamo/modules/linear.py +170 -130
  38. lalamo/modules/mlp.py +40 -7
  39. lalamo/modules/normalization.py +24 -6
  40. lalamo/modules/rope.py +24 -6
  41. lalamo/sampling.py +99 -0
  42. lalamo/utils.py +86 -1
  43. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/METADATA +6 -6
  44. lalamo-0.3.0.dist-info/RECORD +58 -0
  45. lalamo-0.2.7.dist-info/RECORD +0 -54
  46. /lalamo/model_import/{configs → decoder_configs}/executorch.py +0 -0
  47. /lalamo/model_import/{configs → decoder_configs}/huggingface/__init__.py +0 -0
  48. /lalamo/model_import/{configs → decoder_configs}/huggingface/gemma2.py +0 -0
  49. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/WHEEL +0 -0
  50. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/entry_points.txt +0 -0
  51. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/licenses/LICENSE +0 -0
  52. {lalamo-0.2.7.dist-info → lalamo-0.3.0.dist-info}/top_level.txt +0 -0
lalamo/message_processor.py
@@ -0,0 +1,170 @@
+import re
+from collections.abc import Iterable
+from dataclasses import dataclass
+from functools import cached_property
+from re import Pattern
+from typing import NotRequired, TypedDict
+
+from jinja2 import Template
+from tokenizers import Tokenizer
+
+__all__ = [
+    "AssistantMessage",
+    "ContentBlock",
+    "Image",
+    "Message",
+    "MessageProcessor",
+    "MessageProcessorConfig",
+    "SystemMessage",
+    "ToolSchema",
+    "UserMessage",
+]
+
+type ToolSchema = None  # WIP
+type Image = None  # WIP
+
+
+class HuggingFaceMessage(TypedDict):
+    role: str
+    content: str
+    tool_calls: NotRequired[list[dict]]
+    reasoning_content: NotRequired[str]
+
+
+class HuggingFaceRequest(TypedDict):
+    add_generation_prompt: bool
+    bos_token: str | None
+    messages: list[HuggingFaceMessage]
+    enable_thinking: NotRequired[bool]
+    tools: NotRequired[dict]
+
+
+@dataclass(frozen=True)
+class Message:
+    pass
+
+
+type ContentBlock = str | Image
+
+
+@dataclass(frozen=True)
+class UserMessage(Message):
+    content: tuple[ContentBlock, ...] | ContentBlock
+
+
+@dataclass(frozen=True)
+class SystemMessage(UserMessage):
+    content: tuple[ContentBlock, ...] | ContentBlock
+
+
+@dataclass(frozen=True)
+class AssistantMessage(Message):
+    chain_of_thought: str | None
+    response: str
+
+
+@dataclass(frozen=True)
+class MessageProcessorConfig:
+    prompt_template: str
+    output_parser_regex: str | None
+    system_role_name: str
+    user_role_name: str
+    assistant_role_name: str
+    bos_token: str | None
+
+    def init(self, tokenizer: Tokenizer) -> "MessageProcessor":
+        return MessageProcessor(
+            config=self,
+            tokenizer=tokenizer,
+        )
+
+
+@dataclass(frozen=True)
+class MessageProcessor:
+    config: MessageProcessorConfig
+    tokenizer: Tokenizer
+
+    @cached_property
+    def prompt_template(self) -> Template:
+        return Template(self.config.prompt_template)
+
+    @cached_property
+    def output_parser_regex(self) -> Pattern | None:
+        if self.config.output_parser_regex is None:
+            return None
+        return re.compile(self.config.output_parser_regex)
+
+    @property
+    def system_role_name(self) -> str:
+        return self.config.system_role_name
+
+    @property
+    def user_role_name(self) -> str:
+        return self.config.user_role_name
+
+    @property
+    def assistant_role_name(self) -> str:
+        return self.config.assistant_role_name
+
+    @property
+    def bos_token(self) -> str | None:
+        return self.config.bos_token
+
+    def message_to_dict(self, message: Message) -> HuggingFaceMessage:
+        match message:
+            case UserMessage(content=content):
+                assert isinstance(content, str)
+                return HuggingFaceMessage(role=self.user_role_name, content=content)
+            case SystemMessage(content=content):
+                assert isinstance(content, str)
+                return HuggingFaceMessage(role=self.system_role_name, content=content)
+            case AssistantMessage(chain_of_thought=chain_of_thought, response=response):
+                result = HuggingFaceMessage(role=self.assistant_role_name, content=response)
+                if chain_of_thought:
+                    result["reasoning_content"] = chain_of_thought
+                return result
+        raise ValueError(f"Unsupported message type: {type(message)}")
+
+    def request_to_dict(
+        self,
+        messages: Iterable[Message],
+        tools: Iterable[ToolSchema] | None = None,
+        enable_thinking: bool | None = None,
+    ) -> HuggingFaceRequest:
+        converted_messages = [self.message_to_dict(message) for message in messages]
+        result = HuggingFaceRequest(add_generation_prompt=True, messages=converted_messages, bos_token=self.bos_token)
+        if enable_thinking is not None:
+            result["enable_thinking"] = enable_thinking
+        if tools is not None:
+            raise NotImplementedError("Tools are not supported yet.")
+        return result
+
+    def render_request(self, messages: Iterable[Message]) -> str:
+        request_dict = self.request_to_dict(messages)
+        return self.prompt_template.render(request_dict)
+
+    def parse_response(self, response: str) -> AssistantMessage:
+        if self.output_parser_regex is None:
+            return AssistantMessage(chain_of_thought=None, response=response)
+        match = self.output_parser_regex.match(response)
+        if match is None:
+            raise ValueError(f"Invalid response format: {response}")
+        return AssistantMessage(**match.groupdict())
+
+    def tokenize(self, text: str) -> list[int]:
+        return self.tokenizer.encode(text, add_special_tokens=False).ids
+
+    def detokenize(self, tokens: list[int]) -> str:
+        return self.tokenizer.decode(tokens, skip_special_tokens=False)
+
+    def __post_init__(self) -> None:
+        if self.output_parser_regex is not None:
+            all_fields = AssistantMessage.__dataclass_fields__
+            required_fields = {k: v for k, v in all_fields.items() if v.type == v.type | None}
+            named_groups = self.output_parser_regex.groupindex
+            invalid_groups = set(named_groups) - set(all_fields)
+            if invalid_groups:
+                raise ValueError(f"Unsupported output fields: {list(invalid_groups)}")
+            for group_name in required_fields:
+                if group_name not in named_groups:
+                    raise ValueError(f"Missing required output field: {group_name}")
lalamo/model_import/common.py
@@ -1,21 +1,32 @@
 import importlib.metadata
+from collections import ChainMap
+from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
 from typing import NamedTuple
 
 import huggingface_hub
 import jax.numpy as jnp
+from jax import Array
 from jaxtyping import DTypeLike
+from tokenizers import Tokenizer
 
-from lalamo.modules import Decoder, DecoderConfig
+from lalamo.language_model import GenerationConfig, LanguageModel, LanguageModelConfig
+from lalamo.message_processor import MessageProcessor, MessageProcessorConfig
 from lalamo.quantization import QuantizationMode
 
-from .model_specs import REPO_TO_MODEL, ModelSpec, UseCase
+from .huggingface_generation_config import HFGenerationConfig
+from .huggingface_tokenizer_config import HFTokenizerConfig
+from .model_specs import REPO_TO_MODEL, FileSpec, ModelSpec, UseCase
 
 __all__ = [
     "REPO_TO_MODEL",
+    "DownloadingFileEvent",
+    "FinishedDownloadingFileEvent",
+    "InitializingModelEvent",
     "ModelMetadata",
     "ModelSpec",
+    "StatusEvent",
     "import_model",
 ]
 
@@ -23,6 +34,27 @@ __all__ = [
 LALAMO_VERSION = importlib.metadata.version("lalamo")
 
 
+class DownloadingFileEvent(NamedTuple):
+    file: FileSpec
+
+
+class FinishedDownloadingFileEvent(NamedTuple):
+    file: FileSpec
+
+
+class InitializingModelEvent(NamedTuple):
+    pass
+
+
+class FinishedInitializingModelEvent(NamedTuple):
+    pass
+
+
+type StatusEvent = (
+    DownloadingFileEvent | FinishedDownloadingFileEvent | InitializingModelEvent | FinishedInitializingModelEvent
+)
+
+
 @dataclass(frozen=True)
 class ModelMetadata:
     toolchain_version: str
@@ -33,69 +65,154 @@ class ModelMetadata:
     quantization: QuantizationMode | None
     repo: str
     use_cases: tuple[UseCase, ...]
-    model_config: DecoderConfig
-    tokenizer_file_names: tuple[str, ...]
+    model_config: LanguageModelConfig
 
 
-def download_weights(model_spec: ModelSpec, output_dir: Path | str | None = None) -> list[Path]:
-    result = [
-        huggingface_hub.hf_hub_download(
-            repo_id=model_spec.repo,
-            local_dir=output_dir,
-            filename=filename,
-        )
-        for filename in model_spec.weights_file_names
-    ]
-    return [Path(path) for path in result]
-
-
-def download_config_file(model_spec: ModelSpec, output_dir: Path | str | None = None) -> Path:
+def download_file(
+    file_spec: FileSpec,
+    model_repo: str,
+    output_dir: Path | str | None = None,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> Path:
+    if progress_callback is not None:
+        progress_callback(DownloadingFileEvent(file_spec))
     result = huggingface_hub.hf_hub_download(
-        repo_id=model_spec.repo,
+        repo_id=file_spec.repo or model_repo,
         local_dir=output_dir,
-        filename=model_spec.config_file_name,
+        filename=file_spec.filename,
     )
+    if progress_callback is not None:
+        progress_callback(FinishedDownloadingFileEvent(file_spec))
     return Path(result)
 
 
-def download_tokenizer_files(model_spec: ModelSpec, output_dir: Path | str | None = None) -> tuple[Path, ...]:
-    result = [
-        huggingface_hub.hf_hub_download(
-            repo_id=tokenizer_file_spec.repo or model_spec.repo,
-            local_dir=output_dir,
-            filename=tokenizer_file_spec.filename,
-        )
-        for tokenizer_file_spec in model_spec.tokenizer_files
+def list_weight_files(model_repo: str) -> list[FileSpec]:
+    all_files = huggingface_hub.list_repo_files(model_repo)
+    return [FileSpec(filename) for filename in all_files if filename.endswith(".safetensors")]
+
+
+def download_weights(
+    model_spec: ModelSpec,
+    output_dir: Path | str | None = None,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> list[Path]:
+    return [
+        download_file(file_spec, model_spec.repo, output_dir, progress_callback)
+        for file_spec in list_weight_files(model_spec.repo)
     ]
-    return tuple(Path(path) for path in result)
+
+
+def download_config_file(
+    model_spec: ModelSpec,
+    output_dir: Path | str | None = None,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> Path:
+    return download_file(model_spec.configs.model_config, model_spec.repo, output_dir, progress_callback)
 
 
 class ImportResults(NamedTuple):
-    model: Decoder
+    model: LanguageModel
     metadata: ModelMetadata
-    tokenizer_file_paths: tuple[Path, ...]
 
 
-def import_model(
+def import_message_processor(
     model_spec: ModelSpec,
+    output_dir: Path | str | None = None,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> MessageProcessor:
+    tokenizer_file = download_file(model_spec.configs.tokenizer, model_spec.repo, output_dir, progress_callback)
+    tokenizer_config_file = download_file(
+        model_spec.configs.tokenizer_config,
+        model_spec.repo,
+        output_dir,
+        progress_callback,
+    )
+    tokenizer_config = HFTokenizerConfig.from_json(tokenizer_config_file)
+    if tokenizer_config.chat_template is None:
+        if model_spec.configs.chat_template is None:
+            raise ValueError("Missing chat template.")
+        chat_template_file = download_file(model_spec.configs.chat_template, model_spec.repo, output_dir)
+        prompt_template = chat_template_file.read_text()
+    else:
+        if model_spec.configs.chat_template is not None:
+            raise ValueError("Conflicting chat template specifications.")
+        prompt_template = tokenizer_config.chat_template
+    tokenizer = Tokenizer.from_file(str(tokenizer_file))
+    tokenizer.add_special_tokens(tokenizer_config.added_tokens())
+    message_processor_config = MessageProcessorConfig(
+        prompt_template=prompt_template,
+        output_parser_regex=model_spec.output_parser_regex,
+        system_role_name=model_spec.system_role_name,
+        user_role_name=model_spec.user_role_name,
+        assistant_role_name=model_spec.assistant_role_name,
+        bos_token=tokenizer_config.bos_token,
+    )
+    return MessageProcessor(config=message_processor_config, tokenizer=tokenizer)
+
+
+def import_model(
+    model_spec: ModelSpec | str,
     *,
     context_length: int | None = None,
     precision: DTypeLike | None = None,
     accumulation_precision: DTypeLike = jnp.float32,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
 ) -> ImportResults:
-    foreign_config_file = download_config_file(model_spec)
-    foreign_config = model_spec.config_type.from_json(foreign_config_file)
+    if isinstance(model_spec, str):
+        try:
+            model_spec = REPO_TO_MODEL[model_spec]
+        except KeyError as e:
+            raise ValueError(f"Unknown model: {model_spec}") from e
+
+    foreign_decoder_config_file = download_config_file(model_spec)
+    foreign_decoder_config = model_spec.config_type.from_json(foreign_decoder_config_file)
 
-    tokenizer_file_paths = download_tokenizer_files(model_spec)
     if precision is None:
-        precision = foreign_config.default_precision
+        precision = foreign_decoder_config.default_precision
+
+    weights_paths = download_weights(model_spec, progress_callback=progress_callback)
+    weights_dict: ChainMap[str, Array] = ChainMap(
+        *[model_spec.weights_type.load(weights_path, precision) for weights_path in weights_paths],  # type: ignore
+    )
 
-    weights_paths = download_weights(model_spec)
-    weights_dict = {}
-    for weights_path in weights_paths:
-        weights_dict.update(model_spec.weights_type.load(weights_path, precision))
+    if progress_callback is not None:
+        progress_callback(InitializingModelEvent())
+
+    decoder = foreign_decoder_config.load_decoder(context_length, precision, accumulation_precision, weights_dict)
+
+    if progress_callback is not None:
+        progress_callback(FinishedInitializingModelEvent())
+
+    message_processor = import_message_processor(model_spec)
+
+    stop_token_ids = tuple(foreign_decoder_config.eos_token_ids)
+
+    if model_spec.configs.generation_config is not None:
+        hf_generation_config_file = download_file(model_spec.configs.generation_config, model_spec.repo)
+        hf_generation_config = HFGenerationConfig.from_json(hf_generation_config_file)
+        generation_config = GenerationConfig(
+            stop_token_ids=stop_token_ids,
+            temperature=hf_generation_config.temperature,
+            top_p=hf_generation_config.top_p,
+            top_k=hf_generation_config.top_k,
+            banned_tokens=None,
+        )
+    else:
+        generation_config = GenerationConfig(
+            stop_token_ids=stop_token_ids,
+            temperature=None,
+            top_p=None,
+            top_k=None,
+            banned_tokens=None,
+        )
+
+    language_model_config = LanguageModelConfig(
+        decoder_config=decoder.config,
+        message_processor_config=message_processor.config,
+        generation_config=generation_config,
+    )
 
-    model = foreign_config.load_model(context_length, precision, accumulation_precision, weights_dict)
+    language_model = LanguageModel(language_model_config, decoder, message_processor)
     metadata = ModelMetadata(
         toolchain_version=LALAMO_VERSION,
         vendor=model_spec.vendor,
@@ -105,7 +222,6 @@ def import_model(
         quantization=model_spec.quantization,
         repo=model_spec.repo,
         use_cases=model_spec.use_cases,
-        model_config=model.config,
-        tokenizer_file_names=tuple(p.name for p in tokenizer_file_paths),
+        model_config=language_model_config,
     )
-    return ImportResults(model, metadata, tokenizer_file_paths)
+    return ImportResults(language_model, metadata)
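import_model now accepts either a ModelSpec or a bare repo string, and surfaces download and initialization progress through StatusEvent callbacks. A sketch of the new call shape, assuming these names are re-exported from lalamo.model_import as the __all__ list above suggests, with an illustrative repo id that must be a key of REPO_TO_MODEL:

    from lalamo.model_import import DownloadingFileEvent, StatusEvent, import_model

    def report(event: StatusEvent) -> None:
        # One event when each file starts/finishes downloading, plus a pair
        # of events bracketing decoder initialization.
        if isinstance(event, DownloadingFileEvent):
            print(f"downloading {event.file.filename}...")
        else:
            print(type(event).__name__)

    # Unknown repo strings raise ValueError via the REPO_TO_MODEL lookup.
    model, metadata = import_model("Qwen/Qwen3-0.6B", progress_callback=report)
    print(metadata.repo, metadata.quantization)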
lalamo/model_import/{configs → decoder_configs}/__init__.py
@@ -1,5 +1,4 @@
 from .common import ForeignConfig
-
 # from .executorch import ETLlamaConfig
 from .huggingface import (
     HFGemma2Config,
lalamo/model_import/{configs → decoder_configs}/common.py
@@ -1,11 +1,11 @@
 import json
 from abc import abstractmethod
+from collections.abc import Mapping
 from dataclasses import dataclass
 from pathlib import Path
 from typing import ClassVar, Self
 
 import cattrs
-import jax
 from jaxtyping import Array, DTypeLike
 
 from lalamo.modules import Decoder, DecoderConfig
@@ -18,6 +18,12 @@ class ForeignConfig:
     _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
     _converter.register_structure_hook(int | list[int], lambda v, _: v)
 
+    eos_token_id: int | list[int]
+
+    @property
+    def eos_token_ids(self) -> list[int]:
+        return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id
+
     @property
     @abstractmethod
     def default_precision(self) -> DTypeLike: ...
@@ -29,11 +35,6 @@ class ForeignConfig:
             config = json.load(f)
         return cls._converter.structure(config, cls)
 
-    def to_json(self, json_path: Path | str) -> None:
-        json_path = Path(json_path)
-        with open(json_path, "w") as f:
-            json.dump(self._converter.unstructure(self), f, indent=2)
-
     def to_decoder_config(
         self,
         context_length: int | None,
@@ -46,17 +47,17 @@ class ForeignConfig:
     def _load_weights(
         cls,
         model: Decoder,
-        weights_dict: dict[str, Array],
+        weights_dict: Mapping[str, Array],
     ) -> Decoder:
         raise NotImplementedError
 
-    def load_model(
+    def load_decoder(
        self,
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
-        weights_dict: dict[str, Array],
+        weights_dict: Mapping[str, Array],
     ) -> Decoder:
         config = self.to_decoder_config(context_length, activation_precision, accumulation_precision)
-        model = config.random_init(key=jax.random.PRNGKey(0))
+        model = config.empty()
         return self._load_weights(model, weights_dict)
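Two quiet but significant changes land here: load_model becomes load_decoder and builds its module tree with config.empty() instead of config.random_init(key=jax.random.PRNGKey(0)), so weights that are about to be overwritten are never randomly initialized (which is also why the jax import above is dropped). And widening weights_dict from dict to Mapping is what lets import_model pass a ChainMap over the per-shard weight dictionaries without merging them into a single dict.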
lalamo/model_import/{configs → decoder_configs}/huggingface/common.py
@@ -1,18 +1,19 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal
 
 import jax.numpy as jnp
 from jaxtyping import Array, DTypeLike
 
-from lalamo.model_import.configs import ForeignConfig
+from lalamo.model_import.decoder_configs import ForeignConfig
 from lalamo.model_import.loaders import load_huggingface
 from lalamo.modules import Decoder
 
 __all__ = [
-    "HuggingFaceConfig",
     "AWQQuantizationConfig",
     "GPTQMetaConfig",
-    "GPTQQuantizationConfig"
+    "GPTQQuantizationConfig",
+    "HuggingFaceConfig",
 ]
 
 
@@ -59,6 +60,10 @@ class GPTQQuantizationConfig:
 class HuggingFaceConfig(ForeignConfig):
     torch_dtype: Literal["bfloat16", "float16", "float32"]
 
+    @property
+    def eos_token_ids(self) -> list[int]:
+        return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id
+
     @property
     def default_precision(self) -> DTypeLike:
         return jnp.dtype(self.torch_dtype)
@@ -67,6 +72,6 @@ class HuggingFaceConfig(ForeignConfig):
     def _load_weights(
         cls,
         model: Decoder,
-        weights_dict: dict[str, Array],
+        weights_dict: Mapping[str, Array],
     ) -> Decoder:
         return load_huggingface(model, weights_dict)
lalamo/model_import/{configs → decoder_configs}/huggingface/gemma3.py
@@ -97,12 +97,12 @@ class HFGemma3TextConfigRaw:
         global_rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_theta,
-            max_sequence_length=self.max_position_embeddings,
+            max_sequence_length=context_length or self.max_position_embeddings,
         )
         local_rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_local_base_freq,
-            max_sequence_length=self.max_position_embeddings,
+            max_sequence_length=context_length or self.max_position_embeddings,
         )
 
         linear_config = FullPrecisionLinearConfig(precision=activation_precision)
lalamo/model_import/{configs → decoder_configs}/huggingface/llama.py
@@ -85,13 +85,13 @@ class HFLlamaConfig(HuggingFaceConfig):
             rope_config = UnscaledRoPEConfig(
                 precision=activation_precision,
                 base=self.rope_theta,
-                max_sequence_length=self.max_position_embeddings,
+                max_sequence_length=context_length or self.max_position_embeddings,
             )
         else:
             rope_config = LlamaRoPEConfig(
                 precision=activation_precision,
                 base=self.rope_theta,
-                max_sequence_length=self.max_position_embeddings,
+                max_sequence_length=context_length or self.max_position_embeddings,
                 scaling_factor=self.rope_scaling.factor,
                 original_context_length=self.rope_scaling.original_max_position_embeddings,
                 low_frequency_factor=self.rope_scaling.low_freq_factor,
lalamo/model_import/{configs → decoder_configs}/huggingface/mistral.py
@@ -70,7 +70,7 @@ class HFMistralConfig(HuggingFaceConfig):
         rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_theta,
-            max_sequence_length=self.max_position_embeddings,
+            max_sequence_length=context_length or self.max_position_embeddings,
         )
 
         rmsnorm_config = RMSNormConfig(
lalamo/model_import/{configs → decoder_configs}/huggingface/qwen2.py
@@ -84,7 +84,7 @@ class HFQwen2Config(HuggingFaceConfig):
         rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_theta,
-            max_sequence_length=self.max_position_embeddings,
+            max_sequence_length=context_length or self.max_position_embeddings,
         )
         rmsnorm_config = RMSNormConfig(
             scale_precision=activation_precision,
lalamo/model_import/{configs → decoder_configs}/huggingface/qwen3.py
@@ -82,7 +82,7 @@ class HFQwen3Config(HuggingFaceConfig):
         rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_theta,
-            max_sequence_length=self.max_position_embeddings,
+            max_sequence_length=context_length or self.max_position_embeddings,
         )
         rmsnorm_config = RMSNormConfig(
             scale_precision=activation_precision,
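The same one-line fix recurs across gemma3, llama, mistral, qwen2, and qwen3 above: when a caller passes an explicit context_length, the RoPE tables are now sized to it, falling back to the checkpoint's max_position_embeddings only when context_length is None.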
lalamo/model_import/huggingface_generation_config.py
@@ -0,0 +1,44 @@
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import ClassVar
+
+import cattrs
+
+__all__ = ["HFGenerationConfig"]
+
+
+@dataclass(frozen=True)
+class HFGenerationConfig:
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+    _converter.register_structure_hook(int | list[int], lambda v, _: v)
+    _converter.register_structure_hook(int | list[int] | None, lambda v, _: v)
+
+    # -------- identity / bookkeeping --------
+    _from_model_config: bool | None = None  # some Mistral & DeepSeek models
+    transformers_version: str | None = None  # library version that saved the file
+
+    # -------- special-token ids -------------
+    bos_token_id: int | None = None
+    eos_token_id: int | list[int] | None = None
+    pad_token_id: int | None = None
+
+    # -------- backend hints -----------------
+    cache_implementation: str | None = None  # “hybrid” for Gemma 3/2
+
+    # -------- sampling strategy -------------
+    do_sample: bool | None = None
+    temperature: float | None = None
+    top_p: float | None = None
+    top_k: int | None = None
+    repetition_penalty: float | None = None
+
+    # -------- length limits -----------------
+    max_length: int | None = None  # seen in Llama 3, Gemma 2/3
+
+    @classmethod
+    def from_json(cls, json_path: Path | str) -> "HFGenerationConfig":
+        json_path = Path(json_path)
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls._converter.structure(config, cls)
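HFGenerationConfig mirrors the fields HuggingFace writes to generation_config.json, with everything optional so partial files still parse. A small sketch of loading one; the file contents and path here are made up for the example:

    import json
    from pathlib import Path

    from lalamo.model_import.huggingface_generation_config import HFGenerationConfig

    # Write a representative generation_config.json (illustrative values).
    payload = {"bos_token_id": 1, "eos_token_id": [2, 3], "do_sample": True, "temperature": 0.7, "top_p": 0.9}
    Path("generation_config.json").write_text(json.dumps(payload))

    config = HFGenerationConfig.from_json("generation_config.json")
    print(config.temperature, config.top_p, config.top_k)  # 0.7 0.9 None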