lalamo 0.2.7__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/common.py +79 -29
- lalamo/language_model.py +106 -83
- lalamo/main.py +91 -18
- lalamo/message_processor.py +170 -0
- lalamo/model_import/common.py +159 -43
- lalamo/model_import/{configs → decoder_configs}/__init__.py +0 -1
- lalamo/model_import/{configs → decoder_configs}/common.py +11 -10
- lalamo/model_import/{configs → decoder_configs}/huggingface/common.py +9 -4
- lalamo/model_import/{configs → decoder_configs}/huggingface/gemma3.py +2 -2
- lalamo/model_import/{configs → decoder_configs}/huggingface/llama.py +2 -2
- lalamo/model_import/{configs → decoder_configs}/huggingface/mistral.py +1 -1
- lalamo/model_import/{configs → decoder_configs}/huggingface/qwen2.py +1 -1
- lalamo/model_import/{configs → decoder_configs}/huggingface/qwen3.py +1 -1
- lalamo/model_import/huggingface_generation_config.py +44 -0
- lalamo/model_import/huggingface_tokenizer_config.py +85 -0
- lalamo/model_import/loaders/common.py +2 -1
- lalamo/model_import/loaders/huggingface.py +12 -10
- lalamo/model_import/model_specs/__init__.py +3 -2
- lalamo/model_import/model_specs/common.py +31 -32
- lalamo/model_import/model_specs/deepseek.py +1 -10
- lalamo/model_import/model_specs/gemma.py +2 -25
- lalamo/model_import/model_specs/huggingface.py +2 -12
- lalamo/model_import/model_specs/llama.py +2 -58
- lalamo/model_import/model_specs/mistral.py +9 -19
- lalamo/model_import/model_specs/pleias.py +3 -13
- lalamo/model_import/model_specs/polaris.py +5 -7
- lalamo/model_import/model_specs/qwen.py +12 -111
- lalamo/model_import/model_specs/reka.py +4 -13
- lalamo/modules/__init__.py +2 -1
- lalamo/modules/attention.py +90 -10
- lalamo/modules/common.py +51 -4
- lalamo/modules/decoder.py +90 -8
- lalamo/modules/decoder_layer.py +85 -8
- lalamo/modules/embedding.py +95 -29
- lalamo/modules/kv_cache.py +3 -3
- lalamo/modules/linear.py +170 -130
- lalamo/modules/mlp.py +40 -7
- lalamo/modules/normalization.py +24 -6
- lalamo/modules/rope.py +24 -6
- lalamo/sampling.py +99 -0
- lalamo/utils.py +86 -1
- {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/METADATA +6 -6
- lalamo-0.3.1.dist-info/RECORD +58 -0
- lalamo-0.2.7.dist-info/RECORD +0 -54
- /lalamo/model_import/{configs → decoder_configs}/executorch.py +0 -0
- /lalamo/model_import/{configs → decoder_configs}/huggingface/__init__.py +0 -0
- /lalamo/model_import/{configs → decoder_configs}/huggingface/gemma2.py +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/WHEEL +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/entry_points.txt +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.2.7.dist-info → lalamo-0.3.1.dist-info}/top_level.txt +0 -0
lalamo/message_processor.py
ADDED
@@ -0,0 +1,170 @@
+import re
+from collections.abc import Iterable
+from dataclasses import dataclass
+from functools import cached_property
+from re import Pattern
+from typing import NotRequired, TypedDict
+
+from jinja2 import Template
+from tokenizers import Tokenizer
+
+__all__ = [
+    "AssistantMessage",
+    "ContentBlock",
+    "Image",
+    "Message",
+    "MessageProcessor",
+    "MessageProcessorConfig",
+    "SystemMessage",
+    "ToolSchema",
+    "UserMessage",
+]
+
+type ToolSchema = None  # WIP
+type Image = None  # WIP
+
+
+class HuggingFaceMessage(TypedDict):
+    role: str
+    content: str
+    tool_calls: NotRequired[list[dict]]
+    reasoning_content: NotRequired[str]
+
+
+class HuggingFaceRequest(TypedDict):
+    add_generation_prompt: bool
+    bos_token: str | None
+    messages: list[HuggingFaceMessage]
+    enable_thinking: NotRequired[bool]
+    tools: NotRequired[dict]
+
+
+@dataclass(frozen=True)
+class Message:
+    pass
+
+
+type ContentBlock = str | Image
+
+
+@dataclass(frozen=True)
+class UserMessage(Message):
+    content: tuple[ContentBlock, ...] | ContentBlock
+
+
+@dataclass(frozen=True)
+class SystemMessage(UserMessage):
+    content: tuple[ContentBlock, ...] | ContentBlock
+
+
+@dataclass(frozen=True)
+class AssistantMessage(Message):
+    chain_of_thought: str | None
+    response: str
+
+
+@dataclass(frozen=True)
+class MessageProcessorConfig:
+    prompt_template: str
+    output_parser_regex: str | None
+    system_role_name: str
+    user_role_name: str
+    assistant_role_name: str
+    bos_token: str | None
+
+    def init(self, tokenizer: Tokenizer) -> "MessageProcessor":
+        return MessageProcessor(
+            config=self,
+            tokenizer=tokenizer,
+        )
+
+
+@dataclass(frozen=True)
+class MessageProcessor:
+    config: MessageProcessorConfig
+    tokenizer: Tokenizer
+
+    @cached_property
+    def prompt_template(self) -> Template:
+        return Template(self.config.prompt_template)
+
+    @cached_property
+    def output_parser_regex(self) -> Pattern | None:
+        if self.config.output_parser_regex is None:
+            return None
+        return re.compile(self.config.output_parser_regex)
+
+    @property
+    def system_role_name(self) -> str:
+        return self.config.system_role_name
+
+    @property
+    def user_role_name(self) -> str:
+        return self.config.user_role_name
+
+    @property
+    def assistant_role_name(self) -> str:
+        return self.config.assistant_role_name
+
+    @property
+    def bos_token(self) -> str | None:
+        return self.config.bos_token
+
+    def message_to_dict(self, message: Message) -> HuggingFaceMessage:
+        match message:
+            case UserMessage(content=content):
+                assert isinstance(content, str)
+                return HuggingFaceMessage(role=self.user_role_name, content=content)
+            case SystemMessage(content=content):
+                assert isinstance(content, str)
+                return HuggingFaceMessage(role=self.system_role_name, content=content)
+            case AssistantMessage(chain_of_thought=chain_of_thought, response=response):
+                result = HuggingFaceMessage(role=self.assistant_role_name, content=response)
+                if chain_of_thought:
+                    result["reasoning_content"] = chain_of_thought
+                return result
+        raise ValueError(f"Unsupported message type: {type(message)}")
+
+    def request_to_dict(
+        self,
+        messages: Iterable[Message],
+        tools: Iterable[ToolSchema] | None = None,
+        enable_thinking: bool | None = None,
+    ) -> HuggingFaceRequest:
+        converted_messages = [self.message_to_dict(message) for message in messages]
+        result = HuggingFaceRequest(add_generation_prompt=True, messages=converted_messages, bos_token=self.bos_token)
+        if enable_thinking is not None:
+            result["enable_thinking"] = enable_thinking
+        if tools is not None:
+            raise NotImplementedError("Tools are not supported yet.")
+        return result
+
+    def render_request(self, messages: Iterable[Message]) -> str:
+        request_dict = self.request_to_dict(messages)
+        return self.prompt_template.render(request_dict)
+
+    def parse_response(self, response: str) -> AssistantMessage:
+        if self.output_parser_regex is None:
+            return AssistantMessage(chain_of_thought=None, response=response)
+        match = self.output_parser_regex.match(response)
+        if match is None:
+            raise ValueError(f"Invalid response format: {response}")
+        return AssistantMessage(**match.groupdict())
+
+    def tokenize(self, text: str) -> list[int]:
+        return self.tokenizer.encode(text, add_special_tokens=False).ids
+
+    def detokenize(self, tokens: list[int]) -> str:
+        return self.tokenizer.decode(tokens, skip_special_tokens=False)
+
+    def __post_init__(self) -> None:
+        if self.output_parser_regex is not None:
+            all_fields = AssistantMessage.__dataclass_fields__
+            required_fields = {k: v for k, v in all_fields.items() if v.type == v.type | None}
+            named_groups = self.output_parser_regex.groupindex
+            invalid_groups = set(named_groups) - set(all_fields)
+            if invalid_groups:
+                raise ValueError(f"Unsupported output fields: {list(invalid_groups)}")
+            for group_name in required_fields:
+                if group_name not in named_groups:
+                    raise ValueError(f"Missing required output field: {group_name}")
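The new MessageProcessor is the user-facing half of the prompt pipeline: the config carries the chat template, role names and parsing regex, and the processor renders requests and parses model output. Below is a minimal usage sketch, not part of the diff; the chat template, regex, and the throwaway whitespace tokenizer are illustrative placeholders.

```python
from tokenizers import Tokenizer, models, pre_tokenizers

from lalamo.message_processor import MessageProcessorConfig, UserMessage

# Toy in-memory tokenizer so the sketch is self-contained.
tokenizer = Tokenizer(models.WordLevel({"[UNK]": 0}, unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

config = MessageProcessorConfig(
    prompt_template=(
        "{% for message in messages %}"
        "<|{{ message.role }}|>{{ message.content }}\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}<|assistant|>{% endif %}"
    ),
    # Named groups must correspond to AssistantMessage fields.
    output_parser_regex=r"(?s)(?:<think>(?P<chain_of_thought>.*?)</think>)?(?P<response>.*)",
    system_role_name="system",
    user_role_name="user",
    assistant_role_name="assistant",
    bos_token=None,
)
processor = config.init(tokenizer)

prompt = processor.render_request([UserMessage("What is 2 + 2?")])
# "<|user|>What is 2 + 2?\n<|assistant|>"

parsed = processor.parse_response("<think>simple arithmetic</think>4")
# AssistantMessage(chain_of_thought="simple arithmetic", response="4")
```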
lalamo/model_import/common.py
CHANGED
@@ -1,21 +1,32 @@
 import importlib.metadata
+from collections import ChainMap
+from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
 from typing import NamedTuple

 import huggingface_hub
 import jax.numpy as jnp
+from jax import Array
 from jaxtyping import DTypeLike
+from tokenizers import Tokenizer

-from lalamo.
+from lalamo.language_model import GenerationConfig, LanguageModel, LanguageModelConfig
+from lalamo.message_processor import MessageProcessor, MessageProcessorConfig
 from lalamo.quantization import QuantizationMode

-from .
+from .huggingface_generation_config import HFGenerationConfig
+from .huggingface_tokenizer_config import HFTokenizerConfig
+from .model_specs import REPO_TO_MODEL, FileSpec, ModelSpec, UseCase

 __all__ = [
     "REPO_TO_MODEL",
+    "DownloadingFileEvent",
+    "FinishedDownloadingFileEvent",
+    "InitializingModelEvent",
     "ModelMetadata",
     "ModelSpec",
+    "StatusEvent",
     "import_model",
 ]

@@ -23,6 +34,27 @@ __all__ = [
 LALAMO_VERSION = importlib.metadata.version("lalamo")


+class DownloadingFileEvent(NamedTuple):
+    file: FileSpec
+
+
+class FinishedDownloadingFileEvent(NamedTuple):
+    file: FileSpec
+
+
+class InitializingModelEvent(NamedTuple):
+    pass
+
+
+class FinishedInitializingModelEvent(NamedTuple):
+    pass
+
+
+type StatusEvent = (
+    DownloadingFileEvent | FinishedDownloadingFileEvent | InitializingModelEvent | FinishedInitializingModelEvent
+)
+
+
 @dataclass(frozen=True)
 class ModelMetadata:
     toolchain_version: str
@@ -33,69 +65,154 @@ class ModelMetadata:
     quantization: QuantizationMode | None
     repo: str
     use_cases: tuple[UseCase, ...]
-    model_config:
-    tokenizer_file_names: tuple[str, ...]
+    model_config: LanguageModelConfig


-def
-
-
-
-
-
-
-
-    ]
-    return [Path(path) for path in result]
-
-
-def download_config_file(model_spec: ModelSpec, output_dir: Path | str | None = None) -> Path:
+def download_file(
+    file_spec: FileSpec,
+    model_repo: str,
+    output_dir: Path | str | None = None,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> Path:
+    if progress_callback is not None:
+        progress_callback(DownloadingFileEvent(file_spec))
     result = huggingface_hub.hf_hub_download(
-        repo_id=
+        repo_id=file_spec.repo or model_repo,
         local_dir=output_dir,
-        filename=
+        filename=file_spec.filename,
     )
+    if progress_callback is not None:
+        progress_callback(FinishedDownloadingFileEvent(file_spec))
     return Path(result)


-def
-
-
-
-
-
-
-
+def list_weight_files(model_repo: str) -> list[FileSpec]:
+    all_files = huggingface_hub.list_repo_files(model_repo)
+    return [FileSpec(filename) for filename in all_files if filename.endswith(".safetensors")]
+
+
+def download_weights(
+    model_spec: ModelSpec,
+    output_dir: Path | str | None = None,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> list[Path]:
+    return [
+        download_file(file_spec, model_spec.repo, output_dir, progress_callback)
+        for file_spec in list_weight_files(model_spec.repo)
     ]
-
+
+
+def download_config_file(
+    model_spec: ModelSpec,
+    output_dir: Path | str | None = None,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> Path:
+    return download_file(model_spec.configs.model_config, model_spec.repo, output_dir, progress_callback)


 class ImportResults(NamedTuple):
-    model:
+    model: LanguageModel
     metadata: ModelMetadata
-    tokenizer_file_paths: tuple[Path, ...]


-def
+def import_message_processor(
     model_spec: ModelSpec,
+    output_dir: Path | str | None = None,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
+) -> MessageProcessor:
+    tokenizer_file = download_file(model_spec.configs.tokenizer, model_spec.repo, output_dir, progress_callback)
+    tokenizer_config_file = download_file(
+        model_spec.configs.tokenizer_config,
+        model_spec.repo,
+        output_dir,
+        progress_callback,
+    )
+    tokenizer_config = HFTokenizerConfig.from_json(tokenizer_config_file)
+    if tokenizer_config.chat_template is None:
+        if model_spec.configs.chat_template is None:
+            raise ValueError("Missing chat template.")
+        chat_template_file = download_file(model_spec.configs.chat_template, model_spec.repo, output_dir)
+        prompt_template = chat_template_file.read_text()
+    else:
+        if model_spec.configs.chat_template is not None:
+            raise ValueError("Conflicting chat template specifications.")
+        prompt_template = tokenizer_config.chat_template
+    tokenizer = Tokenizer.from_file(str(tokenizer_file))
+    tokenizer.add_special_tokens(tokenizer_config.added_tokens())
+    message_processor_config = MessageProcessorConfig(
+        prompt_template=prompt_template,
+        output_parser_regex=model_spec.output_parser_regex,
+        system_role_name=model_spec.system_role_name,
+        user_role_name=model_spec.user_role_name,
+        assistant_role_name=model_spec.assistant_role_name,
+        bos_token=tokenizer_config.bos_token,
+    )
+    return MessageProcessor(config=message_processor_config, tokenizer=tokenizer)
+
+
+def import_model(
+    model_spec: ModelSpec | str,
     *,
     context_length: int | None = None,
     precision: DTypeLike | None = None,
     accumulation_precision: DTypeLike = jnp.float32,
+    progress_callback: Callable[[StatusEvent], None] | None = None,
 ) -> ImportResults:
-
-
+    if isinstance(model_spec, str):
+        try:
+            model_spec = REPO_TO_MODEL[model_spec]
+        except KeyError as e:
+            raise ValueError(f"Unknown model: {model_spec}") from e
+
+    foreign_decoder_config_file = download_config_file(model_spec)
+    foreign_decoder_config = model_spec.config_type.from_json(foreign_decoder_config_file)

-    tokenizer_file_paths = download_tokenizer_files(model_spec)
     if precision is None:
-        precision =
+        precision = foreign_decoder_config.default_precision
+
+    weights_paths = download_weights(model_spec, progress_callback=progress_callback)
+    weights_dict: ChainMap[str, Array] = ChainMap(
+        *[model_spec.weights_type.load(weights_path, precision) for weights_path in weights_paths],  # type: ignore
+    )

-
-
-
-
+    if progress_callback is not None:
+        progress_callback(InitializingModelEvent())
+
+    decoder = foreign_decoder_config.load_decoder(context_length, precision, accumulation_precision, weights_dict)
+
+    if progress_callback is not None:
+        progress_callback(FinishedInitializingModelEvent())
+
+    message_processor = import_message_processor(model_spec)
+
+    stop_token_ids = tuple(foreign_decoder_config.eos_token_ids)
+
+    if model_spec.configs.generation_config is not None:
+        hf_generation_config_file = download_file(model_spec.configs.generation_config, model_spec.repo)
+        hf_generation_config = HFGenerationConfig.from_json(hf_generation_config_file)
+        generation_config = GenerationConfig(
+            stop_token_ids=stop_token_ids,
+            temperature=hf_generation_config.temperature,
+            top_p=hf_generation_config.top_p,
+            top_k=hf_generation_config.top_k,
+            banned_tokens=None,
+        )
+    else:
+        generation_config = GenerationConfig(
+            stop_token_ids=stop_token_ids,
+            temperature=None,
+            top_p=None,
+            top_k=None,
+            banned_tokens=None,
+        )
+
+    language_model_config = LanguageModelConfig(
+        decoder_config=decoder.config,
+        message_processor_config=message_processor.config,
+        generation_config=generation_config,
+    )

-
+    language_model = LanguageModel(language_model_config, decoder, message_processor)
     metadata = ModelMetadata(
         toolchain_version=LALAMO_VERSION,
         vendor=model_spec.vendor,
@@ -105,7 +222,6 @@ def import_model(
         quantization=model_spec.quantization,
         repo=model_spec.repo,
         use_cases=model_spec.use_cases,
-        model_config=
-        tokenizer_file_names=tuple(p.name for p in tokenizer_file_paths),
+        model_config=language_model_config,
     )
-    return ImportResults(
+    return ImportResults(language_model, metadata)
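import_model now accepts either a ModelSpec or a plain repo id and reports progress through the new StatusEvent callbacks. A hedged driving sketch follows; it is not part of the diff, and the repo id is only an example that must correspond to an entry in REPO_TO_MODEL.

```python
from lalamo.model_import.common import (
    DownloadingFileEvent,
    FinishedDownloadingFileEvent,
    InitializingModelEvent,
    StatusEvent,
    import_model,
)


def report(event: StatusEvent) -> None:
    # Status events are plain NamedTuples, so structural matching is enough.
    match event:
        case DownloadingFileEvent(file):
            print(f"downloading {file.filename}")
        case FinishedDownloadingFileEvent(file):
            print(f"finished {file.filename}")
        case InitializingModelEvent():
            print("initializing model")
        case _:
            print("model initialized")


# Example repo id; any key of REPO_TO_MODEL works here.
model, metadata = import_model("Qwen/Qwen3-0.6B", progress_callback=report)
print(metadata.repo, metadata.quantization)
```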
lalamo/model_import/{configs → decoder_configs}/common.py
CHANGED
@@ -1,11 +1,11 @@
 import json
 from abc import abstractmethod
+from collections.abc import Mapping
 from dataclasses import dataclass
 from pathlib import Path
 from typing import ClassVar, Self

 import cattrs
-import jax
 from jaxtyping import Array, DTypeLike

 from lalamo.modules import Decoder, DecoderConfig
@@ -18,6 +18,12 @@ class ForeignConfig:
     _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
     _converter.register_structure_hook(int | list[int], lambda v, _: v)

+    eos_token_id: int | list[int]
+
+    @property
+    def eos_token_ids(self) -> list[int]:
+        return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id
+
     @property
     @abstractmethod
     def default_precision(self) -> DTypeLike: ...
@@ -29,11 +35,6 @@ class ForeignConfig:
             config = json.load(f)
         return cls._converter.structure(config, cls)

-    def to_json(self, json_path: Path | str) -> None:
-        json_path = Path(json_path)
-        with open(json_path, "w") as f:
-            json.dump(self._converter.unstructure(self), f, indent=2)
-
     def to_decoder_config(
         self,
         context_length: int | None,
@@ -46,17 +47,17 @@ class ForeignConfig:
     def _load_weights(
         cls,
         model: Decoder,
-        weights_dict:
+        weights_dict: Mapping[str, Array],
     ) -> Decoder:
         raise NotImplementedError

-    def
+    def load_decoder(
         self,
         context_length: int | None,
         activation_precision: DTypeLike,
         accumulation_precision: DTypeLike,
-        weights_dict:
+        weights_dict: Mapping[str, Array],
     ) -> Decoder:
         config = self.to_decoder_config(context_length, activation_precision, accumulation_precision)
-        model = config.
+        model = config.empty()
         return self._load_weights(model, weights_dict)
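The eos_token_id field added to ForeignConfig mirrors the HuggingFace convention of storing either a single id or a list of ids, and the eos_token_ids property normalizes both shapes to a list before import_model turns them into stop tokens. The same normalization in isolation, as a standalone sketch rather than package code:

```python
def normalize_eos_token_ids(eos_token_id: int | list[int]) -> list[int]:
    # Mirrors ForeignConfig.eos_token_ids: wrap a single id, pass lists through.
    return [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id


assert normalize_eos_token_ids(2) == [2]
assert normalize_eos_token_ids([2, 32000]) == [2, 32000]
```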
lalamo/model_import/{configs → decoder_configs}/huggingface/common.py
CHANGED
@@ -1,18 +1,19 @@
+from collections.abc import Mapping
 from dataclasses import dataclass
 from typing import Literal

 import jax.numpy as jnp
 from jaxtyping import Array, DTypeLike

-from lalamo.model_import.
+from lalamo.model_import.decoder_configs import ForeignConfig
 from lalamo.model_import.loaders import load_huggingface
 from lalamo.modules import Decoder

 __all__ = [
-    "HuggingFaceConfig",
     "AWQQuantizationConfig",
     "GPTQMetaConfig",
-    "GPTQQuantizationConfig"
+    "GPTQQuantizationConfig",
+    "HuggingFaceConfig",
 ]


@@ -59,6 +60,10 @@ class GPTQQuantizationConfig:
 class HuggingFaceConfig(ForeignConfig):
     torch_dtype: Literal["bfloat16", "float16", "float32"]

+    @property
+    def eos_token_ids(self) -> list[int]:
+        return [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id
+
     @property
     def default_precision(self) -> DTypeLike:
         return jnp.dtype(self.torch_dtype)
@@ -67,6 +72,6 @@ class HuggingFaceConfig(ForeignConfig):
     def _load_weights(
         cls,
         model: Decoder,
-        weights_dict:
+        weights_dict: Mapping[str, Array],
     ) -> Decoder:
         return load_huggingface(model, weights_dict)
lalamo/model_import/{configs → decoder_configs}/huggingface/gemma3.py
CHANGED
@@ -97,12 +97,12 @@ class HFGemma3TextConfigRaw:
         global_rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_theta,
-            max_sequence_length=self.max_position_embeddings,
+            max_sequence_length=context_length or self.max_position_embeddings,
         )
         local_rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_local_base_freq,
-            max_sequence_length=self.max_position_embeddings,
+            max_sequence_length=context_length or self.max_position_embeddings,
         )

         linear_config = FullPrecisionLinearConfig(precision=activation_precision)
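The same `context_length or self.max_position_embeddings` fallback is applied in the Llama, Mistral, Qwen2 and Qwen3 configs below: an explicitly requested context length overrides the checkpoint's advertised maximum, while None keeps the checkpoint value. In isolation, as a standalone sketch:

```python
def effective_max_sequence_length(context_length: int | None, max_position_embeddings: int) -> int:
    # `or` falls back whenever context_length is falsy, i.e. None (or 0).
    return context_length or max_position_embeddings


assert effective_max_sequence_length(None, 131072) == 131072
assert effective_max_sequence_length(8192, 131072) == 8192
```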
lalamo/model_import/{configs → decoder_configs}/huggingface/llama.py
CHANGED
@@ -85,13 +85,13 @@ class HFLlamaConfig(HuggingFaceConfig):
             rope_config = UnscaledRoPEConfig(
                 precision=activation_precision,
                 base=self.rope_theta,
-                max_sequence_length=self.max_position_embeddings,
+                max_sequence_length=context_length or self.max_position_embeddings,
             )
         else:
             rope_config = LlamaRoPEConfig(
                 precision=activation_precision,
                 base=self.rope_theta,
-                max_sequence_length=self.max_position_embeddings,
+                max_sequence_length=context_length or self.max_position_embeddings,
                 scaling_factor=self.rope_scaling.factor,
                 original_context_length=self.rope_scaling.original_max_position_embeddings,
                 low_frequency_factor=self.rope_scaling.low_freq_factor,
lalamo/model_import/{configs → decoder_configs}/huggingface/mistral.py
CHANGED
@@ -70,7 +70,7 @@ class HFMistralConfig(HuggingFaceConfig):
         rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_theta,
-            max_sequence_length=self.max_position_embeddings,
+            max_sequence_length=context_length or self.max_position_embeddings,
         )

         rmsnorm_config = RMSNormConfig(
lalamo/model_import/{configs → decoder_configs}/huggingface/qwen2.py
CHANGED
@@ -84,7 +84,7 @@ class HFQwen2Config(HuggingFaceConfig):
         rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_theta,
-            max_sequence_length=self.max_position_embeddings,
+            max_sequence_length=context_length or self.max_position_embeddings,
         )
         rmsnorm_config = RMSNormConfig(
             scale_precision=activation_precision,
lalamo/model_import/{configs → decoder_configs}/huggingface/qwen3.py
CHANGED
@@ -82,7 +82,7 @@ class HFQwen3Config(HuggingFaceConfig):
         rope_config = UnscaledRoPEConfig(
             precision=activation_precision,
             base=self.rope_theta,
-            max_sequence_length=self.max_position_embeddings,
+            max_sequence_length=context_length or self.max_position_embeddings,
         )
         rmsnorm_config = RMSNormConfig(
             scale_precision=activation_precision,
lalamo/model_import/huggingface_generation_config.py
ADDED
@@ -0,0 +1,44 @@
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import ClassVar
+
+import cattrs
+
+__all__ = ["HFGenerationConfig"]
+
+
+@dataclass(frozen=True)
+class HFGenerationConfig:
+    _converter: ClassVar[cattrs.Converter] = cattrs.Converter()
+    _converter.register_structure_hook(int | list[int], lambda v, _: v)
+    _converter.register_structure_hook(int | list[int] | None, lambda v, _: v)
+
+    # -------- identity / bookkeeping --------
+    _from_model_config: bool | None = None  # some Mistral & DeepSeek models
+    transformers_version: str | None = None  # library version that saved the file
+
+    # -------- special-token ids -------------
+    bos_token_id: int | None = None
+    eos_token_id: int | list[int] | None = None
+    pad_token_id: int | None = None
+
+    # -------- backend hints -----------------
+    cache_implementation: str | None = None  # "hybrid" for Gemma 3/2
+
+    # -------- sampling strategy -------------
+    do_sample: bool | None = None
+    temperature: float | None = None
+    top_p: float | None = None
+    top_k: int | None = None
+    repetition_penalty: float | None = None
+
+    # -------- length limits -----------------
+    max_length: int | None = None  # seen in Llama 3, Gemma 2/3
+
+    @classmethod
+    def from_json(cls, json_path: Path | str) -> "HFGenerationConfig":
+        json_path = Path(json_path)
+        with open(json_path) as f:
+            config = json.load(f)
+        return cls._converter.structure(config, cls)