minicpmo-utils 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
stepaudio2/flashcosyvoice/modules/qwen2.py
@@ -0,0 +1,92 @@
+ # Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import torch
+ from torch import nn
+ from transformers import AutoConfig
+ 
+ from stepaudio2.flashcosyvoice.config import CosyVoice2LLMConfig
+ from stepaudio2.flashcosyvoice.modules.qwen2_components.layers import (
+     ParallelLMHead, Qwen2DecoderLayer, RMSNorm, VocabParallelEmbedding)
+ 
+ 
+ class Qwen2Model(nn.Module):
+ 
+     def __init__(
+         self,
+         config: CosyVoice2LLMConfig,
+     ):
+         super().__init__()
+         self.vocab_size = config.vocab_size
+         self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
+         self.layers = nn.ModuleList([Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ 
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+     ) -> torch.Tensor:
+         hidden_states = self.embed_tokens(input_ids)
+         residual = None
+         for layer in self.layers:
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 residual,
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+ 
+ 
+ class Qwen2ForCausalLM(nn.Module):
+     packed_modules_mapping = {
+         "q_proj": ("qkv_proj", "q"),
+         "k_proj": ("qkv_proj", "k"),
+         "v_proj": ("qkv_proj", "v"),
+         "gate_proj": ("gate_up_proj", 0),
+         "up_proj": ("gate_up_proj", 1),
+     }
+ 
+     def __init__(
+         self,
+         config: CosyVoice2LLMConfig | AutoConfig
+     ):
+         super().__init__()
+         self.model = Qwen2Model(config)
+         if hasattr(config, "speech_vocab_size"):
+             self.lm_head = ParallelLMHead(config.speech_vocab_size, config.hidden_size, bias=getattr(config, "lm_head_bias", True))
+             self.model_type = "speech_llm"
+         else:
+             self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, bias=False)
+             self.model_type = "text_llm"
+         self.tie_word_embeddings = config.tie_word_embeddings
+         if self.tie_word_embeddings:
+             if self.model_type == "speech_llm":
+                 assert config.vocab_size == config.speech_vocab_size, "vocab_size and speech_vocab_size must be the same when tie_word_embeddings is True"
+             self.lm_head.weight.data = self.model.embed_tokens.weight.data
+ 
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions)
+         return hidden_states
+ 
+     def compute_logits(
+         self,
+         hidden_states: torch.Tensor,
+     ) -> torch.Tensor:
+         logits = self.lm_head(hidden_states)
+         return logits
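
Note: the packed_modules_mapping above tells the checkpoint loader how the per-projection Hugging Face weights (q_proj/k_proj/v_proj, gate_proj/up_proj) fold into the fused qkv_proj and gate_up_proj parameters. The real loading logic lives in stepaudio2/flashcosyvoice/utils/loader.py, which is not reproduced here; the snippet below is only a hypothetical sketch of how such a mapping is typically consumed (the helpers load_packed_weights and default_weight_loader are invented for this illustration).

# Hypothetical illustration only -- not part of the package. Routes Hugging Face
# checkpoint tensors into the fused parameters declared above, using
# packed_modules_mapping plus the per-parameter weight_loader hooks.
import torch
from torch import nn


def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
    param.data.copy_(loaded_weight)


def load_packed_weights(model: nn.Module, state_dict: dict[str, torch.Tensor]) -> None:
    mapping = getattr(model, "packed_modules_mapping", {})
    for name, tensor in state_dict.items():
        for src, (dst, shard_id) in mapping.items():
            if src in name:
                # e.g. "model.layers.0.self_attn.q_proj.weight"
                #   -> fused "model.layers.0.self_attn.qkv_proj.weight", shard "q"
                param = model.get_parameter(name.replace(src, dst))
                param.weight_loader(param, tensor, shard_id)
                break
        else:
            # unfused parameters fall back to a plain copy
            param = model.get_parameter(name)
            getattr(param, "weight_loader", default_weight_loader)(param, tensor)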
stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py
@@ -0,0 +1,616 @@
+ from functools import lru_cache
+ 
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import triton
+ import triton.language as tl
+ from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+ 
+ from stepaudio2.flashcosyvoice.config import CosyVoice2LLMConfig
+ from stepaudio2.flashcosyvoice.utils.context import get_context
+ 
+ 
+ class SiluAndMul(nn.Module):
+ 
+     def __init__(self):
+         super().__init__()
+ 
+     @torch.compile
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x, y = x.chunk(2, -1)
+         return F.silu(x) * y
+ 
+ 
+ class RMSNorm(nn.Module):
+ 
+     def __init__(
+         self,
+         hidden_size: int,
+         eps: float = 1e-6,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+ 
+     @torch.compile
+     def rms_forward(
+         self,
+         x: torch.Tensor,
+     ) -> torch.Tensor:
+         orig_dtype = x.dtype
+         x = x.to(torch.float32)
+         var = x.pow(2).mean(dim=-1, keepdim=True)
+         x.mul_(torch.rsqrt(var + self.eps))
+         x = x.to(orig_dtype).mul_(self.weight)
+         return x
+ 
+     @torch.compile
+     def add_rms_forward(
+         self,
+         x: torch.Tensor,
+         residual: torch.Tensor,
+     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+         orig_dtype = x.dtype
+         x = x.to(torch.float32).add_(residual.to(torch.float32))
+         residual = x.to(orig_dtype)
+         var = x.pow(2).mean(dim=-1, keepdim=True)
+         x.mul_(torch.rsqrt(var + self.eps))
+         x = x.to(orig_dtype).mul_(self.weight)
+         return x, residual
+ 
+     def forward(
+         self,
+         x: torch.Tensor,
+         residual: torch.Tensor | None = None,
+     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+         if residual is None:
+             return self.rms_forward(x)
+         else:
+             return self.add_rms_forward(x, residual)
+ 
+ 
+ @triton.jit
+ def store_kvcache_kernel(
+     key_ptr,
+     key_stride,
+     value_ptr,
+     value_stride,
+     k_cache_ptr,
+     v_cache_ptr,
+     slot_mapping_ptr,
+     D: tl.constexpr,
+ ):
+     idx = tl.program_id(0)
+     key_offsets = idx * key_stride + tl.arange(0, D)
+     value_offsets = idx * value_stride + tl.arange(0, D)
+     key = tl.load(key_ptr + key_offsets)
+     value = tl.load(value_ptr + value_offsets)
+     slot = tl.load(slot_mapping_ptr + idx)
+     cache_offsets = slot * D + tl.arange(0, D)
+     tl.store(k_cache_ptr + cache_offsets, key)
+     tl.store(v_cache_ptr + cache_offsets, value)
+ 
+ 
+ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
+     N, num_heads, head_dim = key.shape
+     D = num_heads * head_dim
+     assert key.stride(-1) == 1 and value.stride(-1) == 1
+     assert key.stride(1) == head_dim and value.stride(1) == head_dim
+     assert k_cache.stride(1) == D and v_cache.stride(1) == D
+     assert slot_mapping.numel() == N
+     store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
+ 
+ 
+ class Attention(nn.Module):
+ 
+     def __init__(
+         self,
+         num_heads,
+         head_dim,
+         scale,
+         num_kv_heads,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = head_dim
+         self.scale = scale
+         self.num_kv_heads = num_kv_heads
+         self.k_cache = self.v_cache = torch.tensor([])
+ 
+     def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+         o: torch.Tensor
+         q = q.view(-1, self.num_heads, self.head_dim)
+         k = k.view(-1, self.num_kv_heads, self.head_dim)
+         v = v.view(-1, self.num_kv_heads, self.head_dim)
+         context = get_context()
+         k_cache, v_cache = self.k_cache, self.v_cache
+         if k_cache.numel() and v_cache.numel():
+             store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+         if context.is_prefill:
+             if context.block_tables is not None:    # prefix cache
+                 k, v = k_cache, v_cache
+             o = flash_attn_varlen_func(q, k, v,
+                                        max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
+                                        max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
+                                        softmax_scale=self.scale, causal=True, block_table=context.block_tables)
+         else:    # decode
+             o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
+                                         cache_seqlens=context.context_lens, block_table=context.block_tables,
+                                         softmax_scale=self.scale, causal=True)
+         o = o.view(-1, self.num_heads * self.head_dim)
+         return o
+ 
+ 
+ class VocabParallelEmbedding(nn.Module):
+ 
+     def __init__(
+         self,
+         num_embeddings: int,
+         embedding_dim: int,
+     ):
+         super().__init__()
+         # TODO(xcsong): support tp > 1
+         self.tp_rank = 0  # dist.get_rank()
+         self.tp_size = 1  # dist.get_world_size()
+         assert num_embeddings % self.tp_size == 0
+         self.num_embeddings = num_embeddings
+         self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
+         self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
+         self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
+         self.embedding_dim = embedding_dim
+         self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
+         self.weight.weight_loader = self.weight_loader
+ 
+     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+         param_data = param.data
+         shard_size = param_data.size(0)
+         start_idx = self.tp_rank * shard_size
+         loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+         assert param_data.size() == loaded_weight.size()
+         param_data.copy_(loaded_weight)
+ 
+     def forward(self, x: torch.Tensor):
+         if self.tp_size > 1:
+             mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
+             x = mask * (x - self.vocab_start_idx)
+         y = F.embedding(x, self.weight)
+         if self.tp_size > 1:
+             y = mask.unsqueeze(1) * y
+             dist.all_reduce(y)
+         return y
+ 
+ 
+ class ParallelLMHead(VocabParallelEmbedding):
+ 
+     def __init__(
+         self,
+         num_embeddings: int,
+         embedding_dim: int,
+         bias: bool = False,
+     ):
+         super().__init__(num_embeddings, embedding_dim)
+         if bias:
+             self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
+             self.bias.weight_loader = self.weight_loader
+         else:
+             self.register_parameter("bias", None)
+ 
+     def forward(self, x: torch.Tensor):
+         context = get_context()
+         if context.is_prefill:
+             last_indices = context.cu_seqlens_q[1:] - 1
+             x = x[last_indices].contiguous()
+         logits = F.linear(x, self.weight, self.bias)
+         if self.tp_size > 1:
+             all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
+             dist.gather(logits, all_logits, 0)
+             logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
+         return logits
+ 
+ 
+ def divide(numerator, denominator):
+     assert numerator % denominator == 0
+     return numerator // denominator
+ 
+ 
+ class LinearBase(nn.Module):
+ 
+     def __init__(
+         self,
+         input_size: int,
+         output_size: int,
+         tp_dim: int | None = None,
+     ):
+         super().__init__()
+         self.input_size = input_size
+         self.output_size = output_size
+         self.tp_dim = tp_dim
+         # TODO(xcsong): support tp > 1
+         self.tp_rank = 0  # dist.get_rank()
+         self.tp_size = 1  # dist.get_world_size()
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         raise NotImplementedError
+ 
+ 
+ class ReplicatedLinear(LinearBase):
+ 
+     def __init__(
+         self,
+         input_size: int,
+         output_size: int,
+         bias: bool = False,
+     ):
+         super().__init__(input_size, output_size)
+         self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
+         self.weight.weight_loader = self.weight_loader
+         if bias:
+             self.bias = nn.Parameter(torch.empty(self.output_size))
+             self.bias.weight_loader = self.weight_loader
+         else:
+             self.register_parameter("bias", None)
+ 
+     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+         assert param.size() == loaded_weight.size()
+         param.data.copy_(loaded_weight)
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.linear(x, self.weight, self.bias)
+ 
+ 
+ class ColumnParallelLinear(LinearBase):
+ 
+     def __init__(
+         self,
+         input_size: int,
+         output_size: int,
+         bias: bool = False,
+     ):
+         super().__init__(input_size, output_size, 0)
+         self.input_size_per_partition = input_size
+         self.output_size_per_partition = divide(output_size, self.tp_size)
+         self.output_partition_sizes = [self.output_size_per_partition]
+         if hasattr(self, "output_sizes"):
+             self.output_partition_sizes = [
+                 divide(output_size, self.tp_size)
+                 for output_size in self.output_sizes
+             ]
+ 
+         self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
+         self.weight.weight_loader = self.weight_loader
+         if bias:
+             self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
+             self.bias.weight_loader = self.weight_loader
+         else:
+             self.register_parameter("bias", None)
+ 
+     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+         param_data = param.data
+         shard_size = param_data.size(self.tp_dim)
+         start_idx = self.tp_rank * shard_size
+         loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
+         assert param_data.size() == loaded_weight.size()
+         param_data.copy_(loaded_weight)
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.linear(x, self.weight, self.bias)
+ 
+ 
+ class MergedColumnParallelLinear(ColumnParallelLinear):
+ 
+     def __init__(
+         self,
+         input_size: int,
+         output_sizes: list[int],
+         bias: bool = False,
+     ):
+         self.output_sizes = output_sizes
+         super().__init__(input_size, sum(output_sizes), bias=bias)
+ 
+     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
+         param_data = param.data
+         shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+         shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
+         param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
+         loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
+         assert param_data.size() == loaded_weight.size()
+         param_data.copy_(loaded_weight)
+ 
+ 
+ class QKVParallelLinear(ColumnParallelLinear):
+ 
+     def __init__(
+         self,
+         hidden_size: int,
+         head_size: int,
+         total_num_heads: int,
+         total_num_kv_heads: int | None = None,
+         bias: bool = False,
+     ):
+         self.hidden_size = hidden_size
+         self.head_size = head_size
+         self.total_num_heads = total_num_heads
+         if total_num_kv_heads is None:
+             total_num_kv_heads = total_num_heads
+         self.total_num_kv_heads = total_num_kv_heads
+         # TODO(xcsong): support tp > 1
+         tp_size = 1  # dist.get_world_size()
+         self.num_heads = divide(self.total_num_heads, tp_size)
+         self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
+         input_size = self.hidden_size
+         output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
+         self.output_sizes = [
+             self.num_heads * self.head_size * tp_size,  # q_proj
+             self.num_kv_heads * self.head_size * tp_size,  # k_proj
+             self.num_kv_heads * self.head_size * tp_size,  # v_proj
+         ]
+ 
+         super().__init__(input_size, output_size, bias)
+ 
+     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
+         param_data = param.data
+         assert loaded_shard_id in ["q", "k", "v"]
+         if loaded_shard_id == "q":
+             shard_size = self.num_heads * self.head_size
+             shard_offset = 0
+         elif loaded_shard_id == "k":
+             shard_size = self.num_kv_heads * self.head_size
+             shard_offset = self.num_heads * self.head_size
+         else:
+             shard_size = self.num_kv_heads * self.head_size
+             shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
+         param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
+         loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
+         assert param_data.size() == loaded_weight.size()
+         param_data.copy_(loaded_weight)
+ 
+ 
+ class RowParallelLinear(LinearBase):
+ 
+     def __init__(
+         self,
+         input_size: int,
+         output_size: int,
+         bias: bool = False,
+     ):
+         super().__init__(input_size, output_size, 1)
+         self.input_size_per_partition = divide(input_size, self.tp_size)
+         self.output_size_per_partition = output_size
+         self.output_partition_sizes = [output_size]
+ 
+         self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
+         self.weight.weight_loader = self.weight_loader
+         if bias:
+             self.bias = nn.Parameter(torch.empty(self.output_size))
+             self.bias.weight_loader = self.weight_loader
+         else:
+             self.register_parameter("bias", None)
+ 
+     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+         param_data = param.data
+         shard_size = param_data.size(self.tp_dim)
+         start_idx = self.tp_rank * shard_size
+         loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
+         assert param_data.size() == loaded_weight.size()
+         param_data.copy_(loaded_weight)
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
+         if self.tp_size > 1:
+             dist.all_reduce(y)
+         return y
+ 
+ 
+ def apply_rotary_emb(
+     x: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+ ) -> torch.Tensor:
+     cos = cos.unsqueeze(-2)
+     sin = sin.unsqueeze(-2)
+     x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
+     y1 = x1 * cos - x2 * sin
+     y2 = x2 * cos + x1 * sin
+     return torch.cat((y1, y2), dim=-1).to(x.dtype)
+ 
+ 
+ class RotaryEmbedding(nn.Module):
+ 
+     def __init__(
+         self,
+         head_size: int,
+         rotary_dim: int,
+         max_position_embeddings: int,
+         base: float,
+     ) -> None:
+         super().__init__()
+         self.head_size = head_size
+         assert rotary_dim == head_size
+         inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
+         t = torch.arange(max_position_embeddings, dtype=torch.float)
+         freqs = torch.einsum("i,j -> ij", t, inv_freq)
+         cos = freqs.cos()
+         sin = freqs.sin()
+         cache = torch.cat((cos, sin), dim=-1)
+         self.register_buffer("cos_sin_cache", cache, persistent=False)
+ 
+     @torch.compile
+     def forward(
+         self,
+         positions: torch.Tensor,
+         query: torch.Tensor,
+         key: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         positions = positions.flatten()
+         num_tokens = positions.shape[0]
+         cos_sin = self.cos_sin_cache[positions]
+         cos, sin = cos_sin.chunk(2, dim=-1)
+         query_shape = query.shape
+         query = query.view(num_tokens, -1, self.head_size)
+         query = apply_rotary_emb(query, cos, sin).view(query_shape)
+         key_shape = key.shape
+         key = key.view(num_tokens, -1, self.head_size)
+         key = apply_rotary_emb(key, cos, sin).view(key_shape)
+         return query, key
+ 
+ 
+ @lru_cache(1)
+ def get_rope(
+     head_size: int,
+     rotary_dim: int,
+     max_position: int,
+     base: float,
+     rope_scaling: dict | None = None,
+ ):
+     assert rope_scaling is None
+     rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
+     return rotary_emb
+ 
+ 
+ class Qwen2Attention(nn.Module):
+ 
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         max_position: int = 4096 * 32,
+         head_dim: int | None = None,
+         rms_norm_eps: float = 1e-06,
+         qkv_bias: bool = True,
+         rope_theta: float = 1000000.0,
+         rope_scaling: tuple | None = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         # TODO(xcsong): support tp > 1
+         tp_size = 1  # dist.get_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         assert self.total_num_kv_heads % tp_size == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = head_dim or hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+ 
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=qkv_bias,
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+         )
+ 
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position,
+             base=self.rope_theta,
+             rope_scaling=rope_scaling,
+         )
+         self.attn = Attention(self.num_heads,
+                               self.head_dim,
+                               self.scaling,
+                               num_kv_heads=self.num_kv_heads)
+ 
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+     ) -> torch.Tensor:
+         qkv = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         o = self.attn(q, k, v)
+         output = self.o_proj(o)
+         return output
+ 
+ 
+ class Qwen2MLP(nn.Module):
+ 
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size,
+             [intermediate_size] * 2,
+             bias=False,
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+         )
+         assert hidden_act == "silu"
+         self.act_fn = SiluAndMul()
+ 
+     def forward(self, x):
+         gate_up = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x = self.down_proj(x)
+         return x
+ 
+ 
+ class Qwen2DecoderLayer(nn.Module):
+ 
+     def __init__(
+         self,
+         config: CosyVoice2LLMConfig,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = Qwen2Attention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             max_position=config.max_position_embeddings,
+             rms_norm_eps=config.rms_norm_eps,
+             qkv_bias=getattr(config, "qkv_bias", True),
+             head_dim=getattr(config, "head_dim", None),
+             rope_theta=getattr(config, "rope_theta", 1000000.0),
+             rope_scaling=getattr(config, "rope_scaling", None),
+         )
+         self.mlp = Qwen2MLP(
+             hidden_size=config.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+         )
+         self.input_layernorm = RMSNorm(config.hidden_size,
+                                        eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                 eps=config.rms_norm_eps)
+ 
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         residual: torch.Tensor | None,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+         )
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
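
Note: RotaryEmbedding above precomputes cos/sin for every position, and apply_rotary_emb rotates the first and second halves of each head dimension together (the rotate-half formulation). As a standalone sanity check only, not part of the package, the snippet below verifies that this split-halves rotation equals an equivalent complex-multiplication form; helper names rotate_split_halves and rotate_complex are invented for the illustration.

# Standalone cross-check, not part of the package: the split-halves rotation in
# apply_rotary_emb is the same operation as multiplying (x1 + i*x2) by e^{i*theta}.
import torch


def rotate_split_halves(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # mirrors apply_rotary_emb: pair element j of the first half with element j
    # of the second half and rotate each pair by the corresponding angle
    x1, x2 = torch.chunk(x.float(), 2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


def rotate_complex(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    x1, x2 = torch.chunk(x.float(), 2, dim=-1)
    z = torch.complex(x1, x2) * torch.polar(torch.ones_like(freqs), freqs)
    return torch.cat((z.real, z.imag), dim=-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    head_dim = 8
    positions = torch.arange(16, dtype=torch.float)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, dtype=torch.float) / head_dim))
    freqs = torch.einsum("i,j -> ij", positions, inv_freq)  # same recipe as RotaryEmbedding
    x = torch.randn(16, head_dim)
    assert torch.allclose(rotate_split_halves(x, freqs.cos(), freqs.sin()),
                          rotate_complex(x, freqs), atol=1e-5)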