minicpmo-utils 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148) hide show
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
@@ -0,0 +1,160 @@
1
+ # Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import time
15
+ from datetime import datetime
16
+
17
+ import s3tokenizer
18
+ import torch
19
+ from tqdm import tqdm
20
+
21
+ from stepaudio2.flashcosyvoice.config import Config, SamplingParams
22
+ from stepaudio2.flashcosyvoice.engine.llm_engine import LLMEngine
23
+ from stepaudio2.flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
24
+ from stepaudio2.flashcosyvoice.modules.hifigan import HiFTGenerator
25
+
26
+
27
class CosyVoice2(torch.nn.Module):
    """Offline CosyVoice2 TTS pipeline.

    Stages: S3 speech tokenizer (prompt audio -> discrete speech tokens),
    LLM (text + prompt tokens -> generated speech tokens), flow matching
    (tokens -> mel spectrogram), HiFi-GAN (mel -> waveform).
    All sub-modules are placed on the current CUDA device at construction.
    """

    def __init__(self, config: Config = None):
        """Build and load all sub-modules.

        Args:
            config: Engine configuration; defaults to ``Config()`` when ``None``.
        """
        super().__init__()
        self.config = Config() if config is None else config

        self.audio_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda().eval()

        self.llm = LLMEngine(**self.config.__dict__)

        # Only the first process on each node draws progress bars.
        self.use_tqdm = torch.distributed.get_node_local_rank() == 0

        self.flow = CausalMaskedDiffWithXvec()
        if self.config.hf_config.fp16_flow:
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [INFO] - Casting flow to fp16")
            self.flow.half()
        self.flow.load_state_dict(torch.load(f"{self.config.model}/flow.pt", map_location="cpu", weights_only=True), strict=True)
        self.flow.cuda().eval()

        self.hift = HiFTGenerator()
        # Checkpoint stores parameters under a 'generator.' prefix; strip it before loading.
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{self.config.model}/hift.pt", map_location="cpu", weights_only=True).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.cuda().eval()

    @torch.inference_mode()
    def forward(
        self, prompt_mels_for_llm: torch.Tensor, prompt_mels_lens_for_llm: torch.Tensor,
        prompt_text_tokens_for_llm: list[list[int]], text_tokens_for_llm: list[list[int]],
        prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor,
        spk_emb_for_flow: torch.Tensor,
        sampling_params: SamplingParams | list[SamplingParams],
        batch_size_flow: int,
        only_llm: bool,
        **kwargs,  # for compatibility
    ):
        """Synthesize waveforms for a batch of requests.

        Returns:
            (results_dict, timing_stats): ``results_dict`` always holds
            'prompt_speech_tokens' and 'generated_speech_tokens'; when
            ``only_llm`` is False it also holds 'generated_wavs'.
            ``timing_stats`` maps stage names to elapsed seconds plus
            batch-size statistics.
        """
        timing_stats = {}

        # Audio tokenization
        start_time = time.time()
        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
            prompt_mels_for_llm.cuda(), prompt_mels_lens_for_llm.cuda()
        )
        timing_stats['audio_tokenization'] = time.time() - start_time

        batch_size = prompt_speech_tokens.shape[0]
        assert len(prompt_text_tokens_for_llm) == batch_size

        # Prepare LLM inputs
        start_time = time.time()
        valid_prompt_speech_tokens = []
        inputs = []
        for i in range(batch_size):
            # Trim padding: keep only the valid tokens of each prompt.
            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
            valid_prompt_speech_tokens.append(speech_tokens_i)
            # Layout: <sos_marker> prompt_text text <task_marker> prompt_speech_tokens,
            # where the two markers sit just past the speech vocabulary.
            inputs.append([self.config.hf_config.speech_vocab_size] + prompt_text_tokens_for_llm[i] + text_tokens_for_llm[i] + [self.config.hf_config.speech_vocab_size + 1] + speech_tokens_i)
        timing_stats['prepare_llm_inputs'] = time.time() - start_time

        # LLM generation
        start_time = time.time()
        llm_outputs = self.llm.generate(inputs, sampling_params, use_tqdm=self.use_tqdm)
        timing_stats['llm_generation'] = time.time() - start_time

        results_dict = {
            "prompt_speech_tokens": valid_prompt_speech_tokens,
            "generated_speech_tokens": [o['token_ids'][:-1] for o in llm_outputs],
        }
        if only_llm:
            return results_dict, timing_stats

        # Prepare Flow inputs
        start_time = time.time()
        flow_inputs = []
        flow_inputs_lens = []
        for i, o in enumerate(llm_outputs):
            generated_speech_tokens = o['token_ids'][:-1]  # ignore last eos
            # FIX: use a distinct local name; the original rebound
            # `prompt_speech_tokens`, shadowing the tokenizer output tensor above.
            prompt_tokens_i = valid_prompt_speech_tokens[i]
            flow_inputs.append(torch.tensor(prompt_tokens_i + generated_speech_tokens))
            flow_inputs_lens.append(len(prompt_tokens_i) + len(generated_speech_tokens))
        flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
        flow_inputs_lens = torch.tensor(flow_inputs_lens)
        timing_stats['prepare_flow_inputs'] = time.time() - start_time

        # Flow generation and HiFi-GAN generation (with batching)
        total_batch_size = flow_inputs.shape[0]
        generated_wavs = []
        flow_total_time = 0.0
        hifigan_total_time = 0.0

        # Process in batches according to batch_size_flow, batch_size_flow <= total_batch_size
        # NOTE(xcsong): When executing both LLM and Flow on the same GPU,
        # Flow can easily fill up the SM and memory. Therefore, batch processing is required to avoid OOM.
        num_batches = (total_batch_size + batch_size_flow - 1) // batch_size_flow
        batch_iterator = range(0, total_batch_size, batch_size_flow)
        if self.use_tqdm:
            batch_iterator = tqdm(batch_iterator, desc="Generating wavs (Flow+HiFi-GAN)", leave=False, unit="batch",
                                  total=num_batches, dynamic_ncols=True, position=self.config.rank + 1)

        for start_idx in batch_iterator:
            end_idx = min(start_idx + batch_size_flow, total_batch_size)
            batch_flow_inputs = flow_inputs[start_idx:end_idx]
            batch_flow_inputs_lens = flow_inputs_lens[start_idx:end_idx]
            batch_prompt_mels = prompt_mels_for_flow[start_idx:end_idx]
            batch_prompt_mels_lens = prompt_mels_lens_for_flow[start_idx:end_idx]
            batch_spk_emb = spk_emb_for_flow[start_idx:end_idx]

            # Flow generation for this batch
            flow_start_time = time.time()
            with torch.amp.autocast("cuda", dtype=torch.float16 if self.config.hf_config.fp16_flow else torch.float32):
                batch_generated_mels, batch_generated_mels_lens = self.flow(
                    batch_flow_inputs.cuda(), batch_flow_inputs_lens.cuda(),
                    batch_prompt_mels.cuda(), batch_prompt_mels_lens.cuda(), batch_spk_emb.cuda(),
                    streaming=False, finalize=True
                )
            flow_total_time += time.time() - flow_start_time

            # HiFi-GAN generation for this batch
            hifigan_start_time = time.time()
            batch_size_current = end_idx - start_idx
            for i in range(batch_size_current):
                # Drop the prompt portion of the mel before vocoding.
                mel = batch_generated_mels[i, :, batch_prompt_mels_lens[i].item():batch_generated_mels_lens[i].item()].unsqueeze(0)
                wav, _ = self.hift(speech_feat=mel)
                generated_wavs.append(wav)
            hifigan_total_time += time.time() - hifigan_start_time

        timing_stats['flow_generation'] = flow_total_time
        timing_stats['hifigan_generation'] = hifigan_total_time

        # Calculate total time and batch statistics
        timing_stats['model.forward_total'] = sum(timing_stats.values())
        timing_stats['batch_size'] = len(generated_wavs)
        timing_stats['batch_size_flow'] = batch_size_flow

        results_dict['generated_wavs'] = generated_wavs
        return results_dict, timing_stats
@@ -0,0 +1 @@
1
+ # TODO(xcsong): Implement CosyVoice3 when it is released
File without changes
@@ -0,0 +1,114 @@
1
+ from collections import deque
2
+
3
+ import numpy as np
4
+ import xxhash
5
+
6
+ from stepaudio2.flashcosyvoice.engine.sequence import Sequence
7
+
8
+
9
class Block:
    """One KV-cache block: its id, reference count, content hash and tokens."""

    def __init__(self, block_id):
        # A freshly constructed block is unreferenced and holds no content.
        self.block_id = block_id
        self.ref_count = 0
        self.hash = -1
        self.token_ids = []

    def update(self, hash: int, token_ids: list[int]):
        """Record the content hash and the token ids stored in this block."""
        self.hash, self.token_ids = hash, token_ids

    def reset(self):
        """Re-initialize for a fresh allocation: one holder, no content yet."""
        self.ref_count, self.hash, self.token_ids = 1, -1, []
25
+
26
+
27
class BlockManager:
    """Allocates fixed-size KV-cache blocks with prefix caching via content hashes."""

    def __init__(self, num_blocks: int, block_size: int):
        assert num_blocks > 0
        self.block_size = block_size
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        # Maps a full-block content hash to the block id currently holding it.
        self.hash_to_block_id: dict[int, int] = dict()
        self.free_block_ids: deque[int] = deque(range(num_blocks))
        self.used_block_ids: set[int] = set()

    @classmethod
    def compute_hash(cls, token_ids: list[int], prefix: int = -1):
        """Hash one full block of tokens, chained with the previous block's hash.

        Chaining makes the hash depend on the whole prefix, so equal hashes
        imply equal token prefixes (modulo collisions, re-checked by callers).
        """
        h = xxhash.xxh64()
        if prefix != -1:
            h.update(prefix.to_bytes(8, "little"))
        h.update(np.array(token_ids).tobytes())
        return h.intdigest()

    def _allocate_block(self, block_id: int) -> Block:
        """Move a free block into the used set and reset its content."""
        block = self.blocks[block_id]
        assert block.ref_count == 0
        block.reset()
        self.free_block_ids.remove(block_id)  # O(n) on deque; acceptable for moderate pools
        self.used_block_ids.add(block_id)
        return self.blocks[block_id]

    def _deallocate_block(self, block_id: int) -> None:
        # FIX: was annotated `-> Block` but never returned a value.
        assert self.blocks[block_id].ref_count == 0
        self.used_block_ids.remove(block_id)
        self.free_block_ids.append(block_id)

    def can_allocate(self, seq: Sequence) -> bool:
        """True if enough free blocks exist for a full prefill of `seq`."""
        return len(self.free_block_ids) >= seq.num_blocks

    def allocate(self, seq: Sequence):
        """Assign blocks for every token of `seq`, reusing hash-matched cached blocks."""
        assert not seq.block_table
        h = -1
        cache_miss = False
        for i in range(seq.num_blocks):
            token_ids = seq.block(i)
            # Only full blocks are cacheable; a partial tail block gets h == -1.
            h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
            block_id = self.hash_to_block_id.get(h, -1)
            if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
                # No cached block, or hash collision (content mismatch): from
                # here on every block is a miss, since the prefix diverged.
                cache_miss = True
            if cache_miss:
                block_id = self.free_block_ids[0]
                block = self._allocate_block(block_id)
            else:
                seq.num_cached_tokens += self.block_size
                if block_id in self.used_block_ids:
                    # Share the live block with other sequences.
                    block = self.blocks[block_id]
                    block.ref_count += 1
                else:
                    # The hashed block was freed but not yet overwritten; revive it.
                    block = self._allocate_block(block_id)
            if h != -1:
                block.update(h, token_ids)
                self.hash_to_block_id[h] = block_id
            seq.block_table.append(block_id)

    def deallocate(self, seq: Sequence):
        """Release all of `seq`'s blocks; a block goes free when its ref count hits zero."""
        for block_id in reversed(seq.block_table):
            block = self.blocks[block_id]
            block.ref_count -= 1
            if block.ref_count == 0:
                self._deallocate_block(block_id)
        seq.num_cached_tokens = 0
        seq.block_table.clear()

    def can_append(self, seq: Sequence) -> bool:
        """True if a decode step on `seq` can proceed (a new block may be needed)."""
        # A fresh block is required only when the sequence just spilled into a
        # new block (length % block_size == 1).
        # FIX: make the former bool-vs-int comparison explicit.
        blocks_needed = 1 if len(seq) % self.block_size == 1 else 0
        return len(self.free_block_ids) >= blocks_needed

    def may_append(self, seq: Sequence):
        """Maintain the block table after one token was appended to `seq`."""
        block_table = seq.block_table
        last_block = self.blocks[block_table[-1]]
        if len(seq) % self.block_size == 1:
            # Previous block just became full (and was hashed); open a new one.
            assert last_block.hash != -1
            block_id = self.free_block_ids[0]
            self._allocate_block(block_id)
            block_table.append(block_id)
        elif len(seq) % self.block_size == 0:
            # Last block just filled up: hash it and publish it for prefix reuse.
            assert last_block.hash == -1
            token_ids = seq.block(seq.num_blocks - 1)
            prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
            h = self.compute_hash(token_ids, prefix)
            last_block.update(h, token_ids)
            self.hash_to_block_id[h] = last_block.block_id
        else:
            # Mid-block: nothing to do; tail block stays unhashed.
            assert last_block.hash == -1
@@ -0,0 +1,125 @@
1
+ import atexit
2
+ from dataclasses import fields
3
+ from time import perf_counter
4
+
5
+ import torch.multiprocessing as mp
6
+ from tqdm.auto import tqdm
7
+ from transformers import AutoTokenizer
8
+
9
+ from stepaudio2.flashcosyvoice.config import Config, SamplingParams
10
+ from stepaudio2.flashcosyvoice.engine.model_runner import ModelRunner
11
+ from stepaudio2.flashcosyvoice.engine.scheduler import Scheduler
12
+ from stepaudio2.flashcosyvoice.engine.sequence import Sequence
13
+
14
+
15
+ class LLMEngine:
16
+
17
+ def __init__(self, model, **kwargs):
18
+ config_fields = {field.name for field in fields(Config)}
19
+ config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
20
+ config = Config(model, **config_kwargs)
21
+ self.ps = []
22
+ self.events = []
23
+ ctx = mp.get_context("spawn")
24
+ assert config.tensor_parallel_size == 1, "NOTE(xcsong): Currently only support tp=1"
25
+ for i in range(1, config.tensor_parallel_size):
26
+ event = ctx.Event()
27
+ process = ctx.Process(target=ModelRunner, args=(config, i, event))
28
+ process.start()
29
+ self.ps.append(process)
30
+ self.events.append(event)
31
+ if hasattr(config.hf_config, "speech_vocab_size"):
32
+ # NOTE: non-chat model, all these special tokens keep randomly initialized.
33
+ special_tokens = {
34
+ 'eos_token': '<|endoftext|>',
35
+ 'pad_token': '<|endoftext|>',
36
+ 'additional_special_tokens': [
37
+ '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
38
+ '[breath]', '<strong>', '</strong>', '[noise]',
39
+ '[laughter]', '[cough]', '[clucking]', '[accent]',
40
+ '[quick_breath]',
41
+ "<laughter>", "</laughter>",
42
+ "[hissing]", "[sigh]", "[vocalized-noise]",
43
+ "[lipsmack]", "[mn]"
44
+ ]
45
+ }
46
+ self.tokenizer = AutoTokenizer.from_pretrained(f"{config.model}/CosyVoice-BlankEN")
47
+ self.tokenizer.add_special_tokens(special_tokens)
48
+ self.skip_special_tokens = True
49
+ else:
50
+ self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
51
+ if hasattr(config.hf_config, "eos_token_id"):
52
+ config.eos = config.hf_config.eos_token_id
53
+ else:
54
+ config.eos = self.tokenizer.eos_token_id
55
+ self.model_runner = ModelRunner(config, config.rank, self.events)
56
+ self.scheduler = Scheduler(config)
57
+ self.config = config
58
+ atexit.register(self.exit)
59
+
60
+ def exit(self):
61
+ self.model_runner.call("exit")
62
+ del self.model_runner
63
+ for p in self.ps:
64
+ p.join()
65
+
66
+ def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
67
+ if isinstance(prompt, str):
68
+ prompt = self.tokenizer.encode(prompt)
69
+ seq = Sequence(prompt, sampling_params)
70
+ self.scheduler.add(seq)
71
+
72
+ def step(self):
73
+ seqs, is_prefill = self.scheduler.schedule()
74
+ token_ids = self.model_runner.call("run", seqs, is_prefill)
75
+ self.scheduler.postprocess(seqs, token_ids)
76
+ outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
77
+ num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
78
+ return outputs, num_tokens
79
+
80
+ def is_finished(self):
81
+ return self.scheduler.is_finished()
82
+
83
+ def generate(
84
+ self,
85
+ prompts: list[str] | list[list[int]],
86
+ sampling_params: SamplingParams | list[SamplingParams],
87
+ use_tqdm: bool = True,
88
+ ) -> list[str]:
89
+ if use_tqdm:
90
+ pbar = tqdm(total=len(prompts), desc="Generating tokens (LLM)", leave=False,
91
+ dynamic_ncols=True, position=self.config.rank + 1)
92
+ if not isinstance(sampling_params, list):
93
+ sampling_params = [sampling_params] * len(prompts)
94
+ for prompt, sp in zip(prompts, sampling_params):
95
+ self.add_request(prompt, sp)
96
+ outputs = {}
97
+ prefill_throughput = decode_throughput = instant_decode_throughput = 0.
98
+ total_decode_tokens = 0
99
+ total_decode_time = 0.
100
+ while not self.is_finished():
101
+ t = perf_counter()
102
+ output, num_tokens = self.step()
103
+ step_time = perf_counter() - t
104
+ if use_tqdm:
105
+ if num_tokens > 0:
106
+ prefill_throughput = num_tokens / step_time
107
+ else:
108
+ instant_decode_throughput = -num_tokens / step_time
109
+ total_decode_tokens += -num_tokens
110
+ total_decode_time += step_time
111
+ decode_throughput = total_decode_tokens / total_decode_time if total_decode_time > 0 else 0
112
+ pbar.set_postfix({
113
+ "Prefill": f"{int(prefill_throughput)}tok/s",
114
+ "AvgDecode": f"{int(decode_throughput)}tok/s",
115
+ "InstDecode": f"{int(instant_decode_throughput)}tok/s",
116
+ })
117
+ for seq_id, token_ids in output:
118
+ outputs[seq_id] = token_ids
119
+ if use_tqdm:
120
+ pbar.update(1)
121
+ outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
122
+ outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
123
+ if use_tqdm:
124
+ pbar.close()
125
+ return outputs