xinference 0.14.2__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +48 -41
- xinference/model/audio/chattts.py +24 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +23 -1
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +49 -1
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +2 -0
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +10 -1
- xinference/model/llm/vllm/core.py +1 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/METADATA +18 -6
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/RECORD +135 -37
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,779 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import math
|
|
3
|
+
from collections import OrderedDict
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
from einops import rearrange
|
|
11
|
+
from loguru import logger
|
|
12
|
+
from torch import Tensor
|
|
13
|
+
from torch.nn import functional as F
|
|
14
|
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
15
|
+
from torch.utils.checkpoint import checkpoint
|
|
16
|
+
from transformers import AutoTokenizer
|
|
17
|
+
|
|
18
|
+
from fish_speech.conversation import SEMANTIC_TOKEN
|
|
19
|
+
from fish_speech.utils import RankedLogger
|
|
20
|
+
|
|
21
|
+
from .lora import LoraConfig, setup_lora
|
|
22
|
+
|
|
23
|
+
log = RankedLogger(__name__, rank_zero_only=True)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k``.

    A value that is already a multiple of ``k`` is returned unchanged.
    """
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
|
|
33
|
+
class BaseModelArgs:
|
|
34
|
+
model_type: str = "base"
|
|
35
|
+
|
|
36
|
+
vocab_size: int = 32000
|
|
37
|
+
n_layer: int = 32
|
|
38
|
+
n_head: int = 32
|
|
39
|
+
dim: int = 4096
|
|
40
|
+
intermediate_size: int = None
|
|
41
|
+
n_local_heads: int = -1
|
|
42
|
+
head_dim: int = 64
|
|
43
|
+
rope_base: float = 10000
|
|
44
|
+
norm_eps: float = 1e-5
|
|
45
|
+
max_seq_len: int = 2048
|
|
46
|
+
dropout: float = 0.0
|
|
47
|
+
tie_word_embeddings: bool = True
|
|
48
|
+
attention_qkv_bias: bool = False
|
|
49
|
+
|
|
50
|
+
# Codebook configs
|
|
51
|
+
codebook_size: int = 160
|
|
52
|
+
num_codebooks: int = 4
|
|
53
|
+
|
|
54
|
+
# Gradient checkpointing
|
|
55
|
+
use_gradient_checkpointing: bool = True
|
|
56
|
+
|
|
57
|
+
# Initialize the model
|
|
58
|
+
initializer_range: float = 0.02
|
|
59
|
+
|
|
60
|
+
def __post_init__(self):
|
|
61
|
+
if self.n_local_heads == -1:
|
|
62
|
+
self.n_local_heads = self.n_head
|
|
63
|
+
if self.intermediate_size is None:
|
|
64
|
+
hidden_dim = 4 * self.dim
|
|
65
|
+
n_hidden = int(2 * hidden_dim / 3)
|
|
66
|
+
self.intermediate_size = find_multiple(n_hidden, 256)
|
|
67
|
+
self.head_dim = self.dim // self.n_head
|
|
68
|
+
|
|
69
|
+
@staticmethod
|
|
70
|
+
def from_pretrained(path: str):
|
|
71
|
+
path = Path(path)
|
|
72
|
+
|
|
73
|
+
if path.is_dir():
|
|
74
|
+
path = path / "config.json"
|
|
75
|
+
|
|
76
|
+
with open(path, "r", encoding="utf-8") as f:
|
|
77
|
+
data = json.load(f)
|
|
78
|
+
|
|
79
|
+
match data["model_type"]:
|
|
80
|
+
case "naive":
|
|
81
|
+
cls = NaiveModelArgs
|
|
82
|
+
case "dual_ar":
|
|
83
|
+
cls = DualARModelArgs
|
|
84
|
+
case _:
|
|
85
|
+
raise ValueError(f"Unknown model type: {data['model_type']}")
|
|
86
|
+
|
|
87
|
+
return cls(**data)
|
|
88
|
+
|
|
89
|
+
def save(self, path: str):
|
|
90
|
+
with open(path, "w") as f:
|
|
91
|
+
json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass
class NaiveModelArgs(BaseModelArgs):
    """Config for the single-stack ``"naive"`` transformer variant."""

    model_type: str = "naive"
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@dataclass
class DualARModelArgs(BaseModelArgs):
    """Config for the dual autoregressive ``"dual_ar"`` transformer variant.

    Adds ``n_fast_layer``, the depth of the secondary ("fast") transformer
    stack that runs over codebooks.
    """

    model_type: str = "dual_ar"
    # Number of TransformerBlocks in the fast (per-codebook) stack.
    n_fast_layer: int = 4
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class KVCache(nn.Module):
    """Preallocated key/value buffers for incremental decoding.

    Both buffers are shaped ``(max_batch_size, n_heads, max_seq_len, head_dim)``,
    zero-initialised, and registered as module buffers.
    """

    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        for buf_name in ("k_cache", "v_cache"):
            self.register_buffer(buf_name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write ``k_val``/``v_val`` at ``input_pos`` along the sequence axis.

        Args:
            input_pos: index tensor of length S.
            k_val: keys shaped ``[B, H, S, D]``.
            v_val: values with the same shape as ``k_val``.

        Returns:
            The full key and value cache tensors, updated in place.
        """
        assert input_pos.shape[0] == k_val.shape[2]

        keys, values = self.k_cache, self.v_cache
        keys[:, :, input_pos] = k_val
        values[:, :, input_pos] = v_val
        return keys, values
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@dataclass
class TransformerForwardResult:
    """Decoded model outputs: vocabulary logits plus codebook logits."""

    token_logits: Tensor  # logits over the token vocabulary
    codebook_logits: Tensor  # logits over codebook entries
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@dataclass
class BaseTransformerForwardResult:
    """Raw output of the shared slow transformer, before any codebook head."""

    logits: Tensor  # vocabulary logits computed from the normed hidden states
    hidden_states: Tensor  # last-layer output, taken before the final norm
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class BaseTransformer(nn.Module):
|
|
139
|
+
def __init__(
|
|
140
|
+
self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
|
|
141
|
+
) -> None:
|
|
142
|
+
super().__init__()
|
|
143
|
+
self.config = config
|
|
144
|
+
self.tokenizer = tokenizer
|
|
145
|
+
|
|
146
|
+
self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
|
|
147
|
+
|
|
148
|
+
# Slow transformer
|
|
149
|
+
self.embeddings = nn.Embedding(
|
|
150
|
+
config.vocab_size,
|
|
151
|
+
config.dim,
|
|
152
|
+
)
|
|
153
|
+
self.codebook_embeddings = nn.Embedding(
|
|
154
|
+
config.codebook_size * config.num_codebooks,
|
|
155
|
+
config.dim,
|
|
156
|
+
)
|
|
157
|
+
self.layers = nn.ModuleList(
|
|
158
|
+
TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
|
|
159
|
+
)
|
|
160
|
+
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
|
|
161
|
+
|
|
162
|
+
if self.config.tie_word_embeddings is False:
|
|
163
|
+
self.output = nn.Linear(
|
|
164
|
+
config.dim,
|
|
165
|
+
config.vocab_size,
|
|
166
|
+
bias=False,
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
self.register_buffer(
|
|
170
|
+
"freqs_cis",
|
|
171
|
+
precompute_freqs_cis(
|
|
172
|
+
config.max_seq_len,
|
|
173
|
+
config.dim // config.n_head,
|
|
174
|
+
config.rope_base,
|
|
175
|
+
),
|
|
176
|
+
persistent=False,
|
|
177
|
+
)
|
|
178
|
+
self.register_buffer(
|
|
179
|
+
"causal_mask",
|
|
180
|
+
torch.tril(
|
|
181
|
+
torch.ones(
|
|
182
|
+
config.max_seq_len,
|
|
183
|
+
config.max_seq_len,
|
|
184
|
+
dtype=torch.bool,
|
|
185
|
+
)
|
|
186
|
+
),
|
|
187
|
+
persistent=False,
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
# For kv cache
|
|
191
|
+
self.max_batch_size = -1
|
|
192
|
+
self.max_seq_len = -1
|
|
193
|
+
|
|
194
|
+
if init_weights:
|
|
195
|
+
self.apply(self._init_weights)
|
|
196
|
+
|
|
197
|
+
def setup_caches(
|
|
198
|
+
self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
|
|
199
|
+
):
|
|
200
|
+
if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
|
|
201
|
+
return
|
|
202
|
+
|
|
203
|
+
head_dim = self.config.dim // self.config.n_head
|
|
204
|
+
max_seq_len = find_multiple(max_seq_len, 8)
|
|
205
|
+
self.max_seq_len = max_seq_len
|
|
206
|
+
self.max_batch_size = max_batch_size
|
|
207
|
+
|
|
208
|
+
for b in self.layers:
|
|
209
|
+
b.attention.kv_cache = KVCache(
|
|
210
|
+
max_batch_size,
|
|
211
|
+
max_seq_len,
|
|
212
|
+
self.config.n_local_heads,
|
|
213
|
+
head_dim,
|
|
214
|
+
dtype=dtype,
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
def embed(self, x: Tensor) -> Tensor:
|
|
218
|
+
vocab_embeds = [self.embeddings(x[:, 0])]
|
|
219
|
+
for i in range(self.config.num_codebooks):
|
|
220
|
+
emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
|
|
221
|
+
emb[x[:, 0] != self.semantic_token_id] = 0
|
|
222
|
+
vocab_embeds.append(emb)
|
|
223
|
+
|
|
224
|
+
x = torch.stack(vocab_embeds, dim=3)
|
|
225
|
+
x = x.sum(dim=3)
|
|
226
|
+
|
|
227
|
+
return x
|
|
228
|
+
|
|
229
|
+
def forward(
|
|
230
|
+
self,
|
|
231
|
+
inp: Tensor,
|
|
232
|
+
key_padding_mask: Optional[Tensor] = None,
|
|
233
|
+
) -> BaseTransformerForwardResult:
|
|
234
|
+
seq_len = inp.size(2)
|
|
235
|
+
|
|
236
|
+
# Here we want to merge the embeddings of the codebooks
|
|
237
|
+
x = self.embed(inp)
|
|
238
|
+
|
|
239
|
+
freqs_cis = self.freqs_cis[:seq_len]
|
|
240
|
+
|
|
241
|
+
# Not that the causal mask here follows the definition of scaled_dot_product_attention
|
|
242
|
+
# That is, FALSE means masked out
|
|
243
|
+
# To maintain consistency, key_padding_mask use TRUE to mask out
|
|
244
|
+
mask = None
|
|
245
|
+
if key_padding_mask is not None:
|
|
246
|
+
mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
|
|
247
|
+
mask = mask & key_padding_mask[:, None, None, :].logical_not()
|
|
248
|
+
|
|
249
|
+
for layer in self.layers:
|
|
250
|
+
if self.config.use_gradient_checkpointing and self.training:
|
|
251
|
+
x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
|
|
252
|
+
else:
|
|
253
|
+
x = layer(x, freqs_cis, mask)
|
|
254
|
+
|
|
255
|
+
# We got slow_out here
|
|
256
|
+
slow_out = self.norm(x)
|
|
257
|
+
|
|
258
|
+
if self.config.tie_word_embeddings:
|
|
259
|
+
token_logits = F.linear(slow_out, self.embeddings.weight)
|
|
260
|
+
else:
|
|
261
|
+
token_logits = self.output(slow_out)
|
|
262
|
+
|
|
263
|
+
return BaseTransformerForwardResult(
|
|
264
|
+
logits=token_logits,
|
|
265
|
+
hidden_states=x,
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
def forward_generate(
|
|
269
|
+
self,
|
|
270
|
+
x: Tensor,
|
|
271
|
+
input_pos: Optional[Tensor] = None,
|
|
272
|
+
return_all: bool = False,
|
|
273
|
+
) -> BaseTransformerForwardResult:
|
|
274
|
+
# This is used for generation, optimized for torch compile
|
|
275
|
+
assert (
|
|
276
|
+
self.max_seq_len != -1 and self.max_batch_size != -1
|
|
277
|
+
), "Please call setup_caches before forward_generate"
|
|
278
|
+
|
|
279
|
+
x = self.embed(x)
|
|
280
|
+
|
|
281
|
+
mask = self.causal_mask[
|
|
282
|
+
None, None, input_pos, : self.max_seq_len
|
|
283
|
+
] # (B, N, Q, K)
|
|
284
|
+
freqs_cis = self.freqs_cis[input_pos]
|
|
285
|
+
|
|
286
|
+
for layer in self.layers:
|
|
287
|
+
x = layer(x, freqs_cis, mask, input_pos=input_pos)
|
|
288
|
+
|
|
289
|
+
# If prefill, we only calculate the logits of last token
|
|
290
|
+
if x.size(1) > 1 and not return_all:
|
|
291
|
+
x = x[:, -1:]
|
|
292
|
+
|
|
293
|
+
# We got slow_out here
|
|
294
|
+
slow_out = self.norm(x)
|
|
295
|
+
|
|
296
|
+
if self.config.tie_word_embeddings:
|
|
297
|
+
token_logits = F.linear(slow_out, self.embeddings.weight)
|
|
298
|
+
else:
|
|
299
|
+
token_logits = self.output(slow_out)
|
|
300
|
+
|
|
301
|
+
return BaseTransformerForwardResult(
|
|
302
|
+
logits=token_logits,
|
|
303
|
+
hidden_states=x,
|
|
304
|
+
)
|
|
305
|
+
|
|
306
|
+
def _init_weights(self, module):
|
|
307
|
+
std = self.config.initializer_range
|
|
308
|
+
if isinstance(module, nn.Linear):
|
|
309
|
+
module.weight.data.normal_(mean=0.0, std=std)
|
|
310
|
+
if module.bias is not None:
|
|
311
|
+
module.bias.data.zero_()
|
|
312
|
+
elif isinstance(module, nn.Embedding):
|
|
313
|
+
module.weight.data.normal_(mean=0.0, std=std)
|
|
314
|
+
if module.padding_idx is not None:
|
|
315
|
+
module.weight.data[module.padding_idx].zero_()
|
|
316
|
+
|
|
317
|
+
@staticmethod
|
|
318
|
+
def from_pretrained(
|
|
319
|
+
path: str,
|
|
320
|
+
load_weights: bool = False,
|
|
321
|
+
max_length: int | None = None,
|
|
322
|
+
lora_config: LoraConfig | None = None,
|
|
323
|
+
rope_base: int | None = None,
|
|
324
|
+
) -> "BaseTransformer":
|
|
325
|
+
config = BaseModelArgs.from_pretrained(str(path))
|
|
326
|
+
if max_length is not None:
|
|
327
|
+
config.max_seq_len = max_length
|
|
328
|
+
log.info(f"Override max_seq_len to {max_length}")
|
|
329
|
+
|
|
330
|
+
if rope_base is not None:
|
|
331
|
+
config.rope_base = rope_base
|
|
332
|
+
log.info(f"Override rope_base to {rope_base}")
|
|
333
|
+
|
|
334
|
+
match config.model_type:
|
|
335
|
+
case "naive":
|
|
336
|
+
model_cls = NaiveTransformer
|
|
337
|
+
case "dual_ar":
|
|
338
|
+
model_cls = DualARTransformer
|
|
339
|
+
case _:
|
|
340
|
+
raise ValueError(f"Unknown model type: {config.model_type}")
|
|
341
|
+
|
|
342
|
+
tokenizer = AutoTokenizer.from_pretrained(str(path))
|
|
343
|
+
log.info(f"Loading model from {path}, config: {config}")
|
|
344
|
+
model = model_cls(config, tokenizer=tokenizer)
|
|
345
|
+
|
|
346
|
+
if lora_config is not None:
|
|
347
|
+
setup_lora(model, lora_config)
|
|
348
|
+
log.info(f"LoRA setup: {lora_config}")
|
|
349
|
+
|
|
350
|
+
if load_weights is False:
|
|
351
|
+
log.info("Randomly initialized model")
|
|
352
|
+
else:
|
|
353
|
+
|
|
354
|
+
if "int8" in str(Path(path)):
|
|
355
|
+
logger.info("Using int8 weight-only quantization!")
|
|
356
|
+
from ...tools.llama.quantize import WeightOnlyInt8QuantHandler
|
|
357
|
+
|
|
358
|
+
simple_quantizer = WeightOnlyInt8QuantHandler(model)
|
|
359
|
+
model = simple_quantizer.convert_for_runtime()
|
|
360
|
+
|
|
361
|
+
if "int4" in str(Path(path)):
|
|
362
|
+
logger.info("Using int4 quantization!")
|
|
363
|
+
path_comps = path.name.split("-")
|
|
364
|
+
assert path_comps[-2].startswith("g")
|
|
365
|
+
groupsize = int(path_comps[-2][1:])
|
|
366
|
+
from ...tools.llama.quantize import WeightOnlyInt4QuantHandler
|
|
367
|
+
|
|
368
|
+
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
|
|
369
|
+
model = simple_quantizer.convert_for_runtime()
|
|
370
|
+
|
|
371
|
+
weights = torch.load(
|
|
372
|
+
Path(path) / "model.pth", map_location="cpu", mmap=True
|
|
373
|
+
)
|
|
374
|
+
|
|
375
|
+
if "state_dict" in weights:
|
|
376
|
+
logger.warning(
|
|
377
|
+
"Using a TextToSemantic LightningModule checkpoint, "
|
|
378
|
+
"please make sure it is a full model, not a LoRA model."
|
|
379
|
+
)
|
|
380
|
+
weights = weights["state_dict"]
|
|
381
|
+
|
|
382
|
+
if next(iter(weights.keys())).startswith("model."):
|
|
383
|
+
logger.info(
|
|
384
|
+
f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
|
|
385
|
+
)
|
|
386
|
+
new_weights = OrderedDict()
|
|
387
|
+
for k, v in weights.items():
|
|
388
|
+
new_weights[k.replace("model.", "")] = v
|
|
389
|
+
weights = new_weights
|
|
390
|
+
|
|
391
|
+
# Verify the name and shape of parameters since strict=False in load_state_dict.
|
|
392
|
+
for k, v in model.named_parameters():
|
|
393
|
+
if k not in weights:
|
|
394
|
+
logger.warning(f"No weight for {k}")
|
|
395
|
+
elif v.shape != weights[k].shape:
|
|
396
|
+
logger.warning(
|
|
397
|
+
f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
|
|
398
|
+
)
|
|
399
|
+
|
|
400
|
+
err = model.load_state_dict(weights, strict=False, assign=True)
|
|
401
|
+
log.info(f"Loaded weights with error: {err}")
|
|
402
|
+
|
|
403
|
+
return model
|
|
404
|
+
|
|
405
|
+
def save_pretrained(self, path: str, drop_lora: bool = False):
|
|
406
|
+
path = Path(path)
|
|
407
|
+
path.mkdir(parents=True, exist_ok=True)
|
|
408
|
+
|
|
409
|
+
self.config.save(path / "config.json")
|
|
410
|
+
state_dict = self.state_dict()
|
|
411
|
+
|
|
412
|
+
if drop_lora:
|
|
413
|
+
for key in list(state_dict.keys()):
|
|
414
|
+
if "lora" not in key:
|
|
415
|
+
continue
|
|
416
|
+
|
|
417
|
+
state_dict.pop(key)
|
|
418
|
+
log.info(f"Drop LoRA parameter: {key}")
|
|
419
|
+
|
|
420
|
+
torch.save(state_dict, path / "model.pth")
|
|
421
|
+
self.tokenizer.save_pretrained(path)
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
class NaiveTransformer(BaseTransformer):
    """Single-stack variant that predicts all codebooks from the slow hidden state.

    One linear head emits ``num_codebooks * codebook_size`` logits per position.
    ``decode`` reshapes them to ``(..., num_codebooks, codebook_size)``.
    """

    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
        # Weight init is deferred until the codebook head below exists.
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.codebook_output = nn.Linear(
            config.dim,
            config.codebook_size * config.num_codebooks,
            bias=False,
        )

        self.apply(self._init_weights)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        """Attach codebook logits computed from ``result.hidden_states``."""
        token_logits = result.logits
        x = result.hidden_states

        # Codebook
        codebook_logits = self.codebook_output(self.codebook_norm(x))
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Full-sequence forward pass followed by the codebook head."""
        result = super().forward(
            inp=inp,
            key_padding_mask=key_padding_mask,
        )
        return self.decode(result)

    def forward_generate(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        return_all: bool = False,
    ) -> TransformerForwardResult:
        """Run a KV-cached generation step followed by the codebook head.

        ``return_all`` is forwarded to ``BaseTransformer.forward_generate`` so
        that this override keeps the same signature as the base method: when
        True, logits are returned for every fed position instead of only the
        last one.
        """
        result = super().forward_generate(x, input_pos, return_all=return_all)
        return self.decode(result)
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
class DualARTransformer(BaseTransformer):
    """Dual autoregressive variant with a slow stack and a fast stack.

    The inherited slow stack runs along the time axis. For every time step, its
    hidden state seeds a short "fast" sequence of length ``num_codebooks``,
    processed causally by ``fast_layers``, which yields per-codebook logits.
    """

    def __init__(self, config: DualARModelArgs, tokenizer: AutoTokenizer) -> None:
        # Weight init is deferred until the fast stack below exists.
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)

        # The equivalent bs is so large that sdpa doesn't work
        self.fast_layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.dim,
            config.codebook_size,
            bias=False,
        )

        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Allocate slow-layer caches (via the base class) and fast-layer caches.

        Fast-layer caches always span ``num_codebooks`` positions.
        """
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        head_dim = self.config.dim // self.config.n_head

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Run the teacher-forced forward pass through both stacks.

        Args:
            inp: ``(batch, 1 + num_codebooks, seq)`` token/codebook ids.
            key_padding_mask: optional ``(batch, seq)`` bool mask, True = padding.

        Returns:
            Token logits from the slow stack, and codebook logits shaped
            ``(batch, seq, num_codebooks, codebook_size)``.
        """
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[:fast_seq_len]

        # Fast-stack inputs for slow position t: the slow hidden state, then
        # codebooks 0..n-2 taken from time t + 1. The last codebook row is
        # dropped, the time axis is shifted left, and the final step is
        # zero-padded.
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=0)
        codebook_embeddings = self.fast_embeddings(codebooks)
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part: rows whose codebook inputs are all 0 are skipped.
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == 0).all(dim=-1)

        if torch.all(codebook_mask):
            # If all codebooks are padded, we keep first 8 to make sure the model runs
            codebook_mask[:8] = False

        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
            else:
                x = layer(x, fast_freqs_cis, fast_mask)

        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits: skipped rows get all-zero logits.
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        # Unflatten the batch and seq_len.
        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        """Run one KV-cached step of the fast (per-codebook) stack.

        ``x`` is reshaped with ``view(1, 1, -1)``, so this handles a single
        sequence (batch size 1) and a single position per call.
        ``input_pos`` indexes into the ``num_codebooks``-long fast sequence.

        Returns:
            Codebook logits for the fed position.
        """
        # Fast transformer
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer.

    Computes ``h = x + attention(attention_norm(x))`` followed by
    ``out = h + feed_forward(ffn_norm(h))``.
    """

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        """Build the sub-layers from ``config``.

        Args:
            config: model hyper-parameters (``dim``, ``norm_eps`` and whatever
                ``Attention`` / ``FeedForward`` read).
            use_sdpa: forwarded to ``Attention``; selects the fused
                ``F.scaled_dot_product_attention`` path when True.
        """
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Optional[Tensor],
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply attention and feed-forward sub-layers with residuals.

        Args:
            x: input hidden states, passed unchanged in shape through the
                residual connections.
            freqs_cis: rotary table slice, forwarded to ``Attention``.
            mask: attention mask forwarded to ``Attention``; ``None`` makes
                the fused path use ``is_causal=True``.
            input_pos: positions forwarded to ``Attention`` (used for
                KV-cache updates when a cache is attached).

        Returns:
            Output hidden states, same shape as ``x``.
        """
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
|
|
613
|
+
|
|
614
|
+
|
|
615
|
+
class Attention(nn.Module):
    """Multi-head self-attention with grouped key/value heads and RoPE.

    A single fused ``wqkv`` projection produces ``n_head`` query heads and
    ``n_local_heads`` key/value heads. The key/value heads are repeated with
    ``repeat_interleave`` to match the query heads before attention. Rotary
    embeddings are applied to queries and keys, and an optional ``kv_cache``
    (assigned from outside this class) is updated on every forward call.
    """

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        """Create projections and copy the relevant hyper-parameters.

        Args:
            config: model hyper-parameters (``dim``, ``n_head``,
                ``n_local_heads``, ``head_dim``, ``dropout``,
                ``attention_qkv_bias``).
            use_sdpa: if True use ``F.scaled_dot_product_attention``;
                otherwise use the explicit ``eq_scaled_dot_product_attention``.
        """
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        # No cache by default; forward() only uses it when something else
        # has assigned an object with an ``update(input_pos, k, v)`` method.
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Fuse legacy separate ``wq``/``wk``/``wv`` weights into ``wqkv``.

        Runs before ``load_state_dict`` consumes ``state_dict``; checkpoints
        that already store ``wqkv.weight`` are left untouched.
        """
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Optional[Tensor],
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Attend over ``x`` (and any cached keys/values).

        Args:
            x: ``(bsz, seqlen, dim)`` input.
            freqs_cis: rotary table slice for the ``seqlen`` positions.
            mask: attention mask. A bool mask marks allowed positions with
                ``True``. With ``use_sdpa`` and ``mask=None`` the flash kernel
                is forced with ``is_causal=True``.
            input_pos: positions passed to ``kv_cache.update`` when a cache
                is attached.

        Returns:
            ``(bsz, seqlen, dim)`` output after the ``wo`` projection.
        """
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # (bsz, seqlen, heads, head_dim) -> (bsz, heads, seqlen, head_dim)
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Repeat each key/value head so every query head has a partner.
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            if mask is None:
                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                    y = F.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=True,
                        # No explicit attn_mask here so flash attention can be used
                    )
            else:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=mask,
                    dropout_p=self.dropout if self.training else 0.0,
                )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        """Explicit (non-fused) scaled dot-product attention.

        Computes ``softmax(q @ k^T / sqrt(E) + bias) @ v``. The original
        authors note this is less efficient than the fused kernel but does
        not hit the CUDA error they saw there.

        Args:
            query: ``(B, H, L, E)`` queries.
            key: ``(B, H, S, E)`` keys.
            value: ``(B, H, S, Ev)`` values.
            attn_mask: optional mask broadcastable to ``(B, H, L, S)``. A
                bool mask keeps positions where it is ``True`` (``False``
                becomes ``-inf``); any other dtype is added as a bias.
            dropout_p: dropout probability on the attention weights (callers
                pass 0.0 outside training).

        Returns:
            ``(B, H, L, Ev)`` attention output.
        """
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            # Out-of-place ops: a batched mask such as (B, 1, L, S) broadcasts
            # the bias up instead of failing an in-place update on (1, 1, L, S).
            if attn_mask.dtype == torch.bool:
                attn_bias = attn_bias.masked_fill(
                    attn_mask.logical_not(), float("-inf")
                )
            else:
                # Cast first so the bias keeps the query dtype, matching the
                # previous in-place accumulation.
                attn_bias = attn_bias + attn_mask.to(attn_bias.dtype)

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

        return attn_weight @ value
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
class FeedForward(nn.Module):
    """Gated feed-forward layer computing ``w2(silu(w1(x)) * w3(x))`` (SwiGLU)."""

    def __init__(self, config: BaseModelArgs) -> None:
        """Create the three bias-free projections.

        ``w1`` and ``w3`` map ``config.dim`` to ``config.intermediate_size``;
        ``w2`` maps back to ``config.dim``.
        """
        super().__init__()
        model_dim = config.dim
        hidden_dim = config.intermediate_size
        # Registration order (w1, w3, w2) is kept so parameter order and
        # initialization order stay the same.
        self.w1 = nn.Linear(model_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(model_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, model_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``w2(silu(w1(x)) * w3(x))``; last dim maps dim -> dim."""
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
|
|
740
|
+
|
|
741
|
+
|
|
742
|
+
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    Normalization is computed in float32, cast back to the input dtype, and
    then multiplied by ``weight``. The product follows normal type promotion,
    so a bf16 input times a float32 weight yields float32.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """Args:
        dim: size of the last dimension (length of ``weight``).
        eps: added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by ``1 / sqrt(mean(x**2, last dim) + eps)``."""
        mean_square = (x * x).mean(dim=-1, keepdim=True)
        return x * (mean_square + self.eps).rsqrt()

    def forward(self, x: Tensor) -> Tensor:
        """Normalize in float32, restore ``x``'s dtype, then apply ``weight``."""
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
def precompute_freqs_cis(
    seq_len: int,
    n_elem: int,
    base: int = 10000,
    dtype: torch.dtype = torch.bfloat16,
) -> Tensor:
    """Build the rotary-embedding (RoPE) cos/sin table.

    For position ``t`` and frequency index ``i`` (``0 <= i < n_elem // 2``)
    the angle is ``t / base ** (2 * i / n_elem)``.

    Args:
        seq_len: number of positions.
        n_elem: rotary dimension; ``n_elem // 2`` frequencies are produced.
        base: frequency base.
        dtype: dtype of the returned table. The default ``torch.bfloat16``
            preserves the previous behavior. Pass ``torch.float32`` to avoid
            quantizing cos/sin to bf16's 8-bit mantissa.

    Returns:
        Tensor of shape ``(seq_len, n_elem // 2, 2)`` where ``[..., 0]`` is
        the cosine and ``[..., 1]`` the sine of each angle.
    """
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    # Unit-magnitude complex numbers e^{i*angle}; split into (real, imag).
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive channel pairs of ``x`` by precomputed RoPE angles.

    ``x`` must be 4-D; in ``Attention.forward`` it is
    ``(bsz, seqlen, heads, head_dim)``. Its last dim is split into
    ``head_dim // 2`` pairs ``(a, b)``. ``freqs_cis`` must view as
    ``(1, seqlen, 1, head_dim // 2, 2)`` holding ``(cos, sin)`` per pair.
    Each pair becomes ``(a*cos - b*sin, b*cos + a*sin)``.

    The math is done in float32 and the result is cast back to ``x``'s dtype.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rope = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)

    even, odd = pairs[..., 0], pairs[..., 1]
    cos, sin = rope[..., 0], rope[..., 1]

    rotated = torch.stack((even * cos - odd * sin, odd * cos + even * sin), -1)
    return rotated.flatten(3).type_as(x)
|