xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (170)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +28 -6
  4. xinference/core/utils.py +10 -6
  5. xinference/deploy/cmdline.py +3 -1
  6. xinference/deploy/test/test_cmdline.py +56 -0
  7. xinference/isolation.py +24 -0
  8. xinference/model/audio/core.py +10 -0
  9. xinference/model/audio/cosyvoice.py +25 -3
  10. xinference/model/audio/f5tts.py +200 -0
  11. xinference/model/audio/f5tts_mlx.py +260 -0
  12. xinference/model/audio/fish_speech.py +36 -111
  13. xinference/model/audio/model_spec.json +27 -3
  14. xinference/model/audio/model_spec_modelscope.json +18 -0
  15. xinference/model/audio/utils.py +32 -0
  16. xinference/model/embedding/core.py +203 -142
  17. xinference/model/embedding/model_spec.json +7 -0
  18. xinference/model/embedding/model_spec_modelscope.json +8 -0
  19. xinference/model/image/core.py +69 -1
  20. xinference/model/image/model_spec.json +127 -4
  21. xinference/model/image/model_spec_modelscope.json +130 -4
  22. xinference/model/image/stable_diffusion/core.py +45 -13
  23. xinference/model/llm/__init__.py +2 -2
  24. xinference/model/llm/llm_family.json +219 -53
  25. xinference/model/llm/llm_family.py +15 -36
  26. xinference/model/llm/llm_family_modelscope.json +167 -20
  27. xinference/model/llm/mlx/core.py +287 -51
  28. xinference/model/llm/sglang/core.py +1 -0
  29. xinference/model/llm/transformers/chatglm.py +9 -5
  30. xinference/model/llm/transformers/core.py +1 -0
  31. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  32. xinference/model/llm/transformers/utils.py +16 -8
  33. xinference/model/llm/utils.py +5 -1
  34. xinference/model/llm/vllm/core.py +16 -2
  35. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  36. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  37. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  38. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  39. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  40. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  41. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  42. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  43. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  44. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  45. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  46. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  47. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  48. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  49. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  50. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  51. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  52. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  53. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  54. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  55. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  56. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  57. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  58. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  59. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  60. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  61. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  62. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  63. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  64. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  65. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  66. xinference/thirdparty/f5_tts/api.py +166 -0
  67. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  68. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  69. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  70. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  71. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  72. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  73. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  74. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  75. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  76. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  77. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  78. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  79. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  80. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  81. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  82. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  83. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  84. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  85. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  86. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  87. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  88. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  89. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  90. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  91. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  92. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  93. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  94. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  95. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  96. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  97. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  98. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  99. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  100. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  101. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  102. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  103. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  104. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  105. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  106. xinference/thirdparty/f5_tts/train/README.md +77 -0
  107. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  108. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  109. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  110. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  111. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  112. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  113. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  114. xinference/thirdparty/f5_tts/train/train.py +75 -0
  115. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  116. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  117. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  118. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  119. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  120. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  121. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  122. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  123. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  124. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  125. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  126. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  127. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  128. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  129. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  130. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  131. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  132. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  133. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  134. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  135. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  136. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  137. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  138. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  139. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  140. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  141. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  142. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  143. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  144. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  145. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  146. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  147. xinference/thirdparty/matcha/utils/utils.py +2 -2
  148. xinference/web/ui/build/asset-manifest.json +3 -3
  149. xinference/web/ui/build/index.html +1 -1
  150. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  151. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  153. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
  154. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
  155. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  156. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  157. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  158. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  159. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  160. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  161. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  162. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  163. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  164. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  165. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  166. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  167. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  168. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  169. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  170. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0

xinference/thirdparty/fish_speech/tools/inference_engine/utils.py:
@@ -0,0 +1,39 @@
+import io
+import wave
+from dataclasses import dataclass
+from typing import Literal, Optional, Tuple
+
+import numpy as np
+
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
+
+
+@dataclass
+class InferenceResult:
+    code: Literal["header", "segment", "error", "final"]
+    audio: Optional[Tuple[int, np.ndarray | bytes]]
+    error: Optional[Exception]
+
+
+def normalize_text(user_input: str, use_normalization: bool) -> str:
+    """Normalize user input text if needed."""
+    if use_normalization:
+        return ChnNormedText(raw_text=user_input).normalize()
+    else:
+        return user_input
+
+
+def wav_chunk_header(
+    sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
+) -> bytes:
+    buffer = io.BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+
+    return wav_header_bytes
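
For context, not part of the diff: wav_chunk_header above writes only a WAV header into an in-memory buffer, which is useful when the total stream length is not known up front. A minimal sketch of how such a header might be consumed when streaming 16-bit PCM, assuming the wav_chunk_header function from the hunk above is in scope; the chunking loop itself is invented for illustration:

    import numpy as np

    def stream_wav(pcm_chunks, sample_rate=44100):
        # Emit the header once, then raw little-endian 16-bit PCM frames.
        yield wav_chunk_header(sample_rate=sample_rate, bit_depth=16, channels=1)
        for chunk in pcm_chunks:  # each chunk: float32 ndarray in [-1.0, 1.0]
            yield (np.clip(chunk, -1.0, 1.0) * 32767).astype("<i2").tobytes()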

xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py:
@@ -0,0 +1,57 @@
+from typing import Callable
+
+import torch
+from loguru import logger
+
+from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+
+
+class VQManager:
+
+    def __init__(self):
+        # Make Pylance happy (attribut/method not defined...)
+        self.decoder_model: FireflyArchitecture
+        self.load_audio: Callable
+
+    def decode_vq_tokens(self, codes):
+        feature_lengths = torch.tensor(
+            [codes.shape[1]], device=self.decoder_model.device
+        )
+        logger.info(f"VQ features: {codes.shape}")
+
+        if isinstance(self.decoder_model, FireflyArchitecture):
+            return self.decoder_model.decode(
+                indices=codes[None],
+                feature_lengths=feature_lengths,
+            )[0].squeeze()
+
+        raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
+
+    def encode_reference(self, reference_audio, enable_reference_audio):
+        if enable_reference_audio and reference_audio is not None:
+            # Load audios, and prepare basic info here
+            reference_audio_content = self.load_audio(
+                reference_audio, self.decoder_model.spec_transform.sample_rate
+            )
+
+            audios = torch.from_numpy(reference_audio_content).to(
+                self.decoder_model.device
+            )[None, None, :]
+            audio_lengths = torch.tensor(
+                [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
+            )
+            logger.info(
+                f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
+            )
+
+            # VQ Encoder
+            if isinstance(self.decoder_model, FireflyArchitecture):
+                prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
+                logger.info(f"Encoded prompt: {prompt_tokens.shape}")
+            else:
+                raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
+        else:
+            prompt_tokens = None
+            logger.info("No reference audio provided")
+
+        return prompt_tokens

xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py:
@@ -1,11 +1,11 @@
-# import pyrootutils
+import pyrootutils
 import torch
 import torch.nn.functional as F
 from matplotlib import pyplot as plt
 from transformers import AutoTokenizer

 # register eval resolver and root
-# pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

 from torch.utils.data import DataLoader

xinference/thirdparty/fish_speech/tools/llama/generate.py:
@@ -17,9 +17,16 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer

-from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.conversation import (
+    CODEBOOK_PAD_TOKEN_ID,
+    Conversation,
+    Message,
+    TextPart,
+    VQPart,
+)
 from fish_speech.models.text2semantic.llama import BaseModelArgs
 from fish_speech.text import clean_text, split_text
+from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 torch._inductor.config.coordinate_descent_tuning = True
@@ -145,8 +152,8 @@ def decode_one_token_ar_agent(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
     # print(x, input_pos)
@@ -190,19 +197,13 @@ def decode_one_token_ar_agent(
     codebooks.append(a)

     codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
     codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
     )

-    # for i in range(codebooks.size(1) - 1):
-    #     codebooks[:, i + 1, :] = torch.masked_fill(
-    #         codebooks[:, i + 1, :],
-    #         codebooks[:, :1, :] != semantic_id,
-    #         CODEBOOK_PAD_TOKEN_ID + i * 1024,
-    #     )
-
-    # print(codebooks)
-
     return codebooks


@@ -210,8 +211,8 @@ def decode_one_token_naive_agent(
     model: NaiveTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
@@ -236,8 +237,11 @@ def decode_one_token_naive_agent(
     )

     codebooks = torch.stack(codebooks, dim=1)
+    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
     codebooks[:, 1:, :] = torch.masked_fill(
-        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+        codebooks[:, 1:, :],
+        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
+        CODEBOOK_PAD_TOKEN_ID,
     )

     return codebooks
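
For context, not part of the diff: the recurring change in the hunks above replaces the old single-token test (codebooks[:, :1, :] != semantic_id) with a membership test over a list of semantic token ids via torch.isin, matching the per-code <|semantic:i|> tokens built further down in generate(). A minimal sketch of the masking with toy values; the ids and shapes are invented for illustration:

    import torch

    CODEBOOK_PAD_TOKEN_ID = 0           # toy value
    semantic_ids = [101, 102, 103]      # toy stand-ins for <|semantic:i|> token ids

    codebooks = torch.tensor([[101, 7, 102],    # row 0: main token stream
                              [11, 12, 13]])    # row 1: one codebook stream
    mask = ~torch.isin(codebooks[:1, :], torch.tensor(semantic_ids))
    codebooks[1:, :] = codebooks[1:, :].masked_fill(mask, CODEBOOK_PAD_TOKEN_ID)
    # row 1 becomes [11, 0, 13]: positions whose main token is not a semantic id are padded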
@@ -247,8 +251,8 @@ def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
+    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
-    semantic_id: int = 0,
     **sampling_kwargs,
 ) -> torch.Tensor:
     x = model.forward_generate(x, input_pos)
@@ -261,21 +265,32 @@ def decode_one_token_ar(
     codebooks = [
         sample(
             x.logits,
-            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            previous_tokens=(
+                previous_tokens[0] if previous_tokens is not None else None
+            ),  # Disable repetition penalty for the token codebook
             **sampling_kwargs_main,
         )[0]
     ]

-    x = x.hidden_states
+    hidden_states = x.hidden_states

     # Cleanup the cache
     for layer in model.fast_layers:
         layer.attention.kv_cache.k_cache.fill_(0)
         layer.attention.kv_cache.v_cache.fill_(0)

-    for codebook_idx in range(model.config.num_codebooks):
-        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
-        logits = model.forward_generate_fast(x, input_pos)
+    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
+    model.forward_generate_fast(hidden_states, input_pos)
+    a = codebooks[0] - model.tokenizer.semantic_begin_id
+    a[a < 0] = 0
+    hidden_states = model.fast_embeddings(a)
+    codebooks.append(a)
+
+    for codebook_idx in range(1, model.config.num_codebooks):
+        input_pos = torch.tensor(
+            [codebook_idx], device=hidden_states.device, dtype=torch.long
+        )
+        logits = model.forward_generate_fast(hidden_states, input_pos)
         a = sample(
             logits,
             previous_tokens=(
@@ -285,14 +300,16 @@ def decode_one_token_ar(
             ),
             **sampling_kwargs,
         )[0]
-        x = model.fast_embeddings(a)
+        hidden_states = model.fast_embeddings(a)
         codebooks.append(a)

     codebooks = torch.stack(codebooks, dim=0)
-    codebooks[1:, :] = torch.masked_fill(
-        codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
-    )
+    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
+    # codebooks[1:, :] = torch.masked_fill(
+    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
+    # )

+    # print(codebooks)
     return codebooks

@@ -337,9 +354,8 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
-    im_end_id: int = 4,
+    semantic_ids: list,
     decode_one_token=decode_one_token_naive,
-    semantic_id: int = 0,
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
@@ -368,7 +384,7 @@ def decode_n_tokens(
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
-                semantic_id=semantic_id,
+                semantic_ids=semantic_ids,
                 **sampling_kwargs,
             )

@@ -378,7 +394,7 @@ def decode_n_tokens(
             model.config.num_codebooks + 1, -1
         )

-        if cur_token[0, 0, -1] == im_end_id:
+        if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
             break

     return previous_tokens[:, : i + 1]
@@ -391,7 +407,6 @@ def generate(
     model: NaiveTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
-    im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
     **sampling_kwargs,
 ) -> torch.Tensor:
@@ -401,7 +416,10 @@ def generate(

     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
-    semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+    semantic_ids = [
+        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
+    ]

     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
@@ -435,7 +453,7 @@ def generate(
         model,
         prompt.view(1, codebook_dim, -1),
         input_pos,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
     seq[:, T : T + 1] = next_token
@@ -446,9 +464,8 @@ def generate(
         next_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
-        im_end_id=im_end_id,
         decode_one_token=decode_one_token,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
     # x = torch.cat(generated_tokens, dim=1)
@@ -463,8 +480,8 @@ def decode_n_tokens_agent(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
+    semantic_ids: list,
     im_end_id: int = 4,
-    semantic_id: int = 32003,
     decode_one_token=decode_one_token_naive_agent,
     early_stop_threshold: float = 0.6,
     **sampling_kwargs,
@@ -495,7 +512,7 @@ def decode_n_tokens_agent(
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
-                semantic_id=semantic_id,
+                semantic_ids=semantic_ids,
                 **sampling_kwargs,
             )

@@ -529,8 +546,8 @@ def generate_agent(
     model: BaseTransformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    semantic_ids: list,
     im_end_id: int = 4,
-    semantic_id: int = 32003,
     decode_one_token=decode_one_token_naive_agent,
     num_samples: int = 1,
     early_stop_threshold: float = 0.6,
@@ -574,7 +591,7 @@ def generate_agent(
         model,
         prompt,
         input_pos,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         **sampling_kwargs,
     ).view(num_samples, codebook_dim, -1)
     yield next_token.cpu()
@@ -587,7 +604,7 @@ def generate_agent(
         input_pos,
         max_new_tokens - 1,
         im_end_id=im_end_id,
-        semantic_id=semantic_id,
+        semantic_ids=semantic_ids,
         decode_one_token=decode_one_token,
         early_stop_threshold=early_stop_threshold,
         **sampling_kwargs,
@@ -602,65 +619,63 @@ def encode_tokens(
     num_codebooks=4,
 ):
     string = clean_text(string)
-    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"

-    new_tokens = tokenizer.encode(
-        string,
-        add_special_tokens=False,
-        max_length=10**6,
-        truncation=False,
+    messages = []
+    messages.append(
+        Message(
+            role="user",
+            parts=[TextPart(text=string)],
+            cal_loss=False,
+        )
     )
-    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)

-    # Codebooks
-    zeros = (
-        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
-        * CODEBOOK_PAD_TOKEN_ID
-    )
-    prompt = torch.cat((tokens, zeros), dim=0)
+    if prompt_tokens is not None:
+        if prompt_tokens.ndim == 3:
+            assert (
+                prompt_tokens.shape[0] == 1
+            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
+            prompt_tokens = prompt_tokens[0]

-    if prompt_tokens is None:
-        return prompt
+        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"

-    # Get prompt tokens
-    if prompt_tokens.ndim == 3:
-        assert (
-            prompt_tokens.shape[0] == 1
-        ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
-        prompt_tokens = prompt_tokens[0]
+        if prompt_tokens.shape[0] > num_codebooks:
+            logger.warning(
+                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+            )
+            prompt_tokens = prompt_tokens[:num_codebooks]

-    assert prompt_tokens.ndim == 2
-    data = prompt_tokens + 1
+        vq_part = VQPart(codes=prompt_tokens.to(device))

-    if prompt_tokens.shape[0] > num_codebooks:
-        logger.warning(
-            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>"), vq_part],
+                cal_loss=False,
+            )
+        )
+    else:
+        messages.append(
+            Message(
+                role="assistant",
+                parts=[TextPart(text="<|voice|>")],
+                cal_loss=False,
+                add_im_end=False,
+            )
         )
-        data = data[:num_codebooks]
-
-    # Add pad token for each codebook
-    data = torch.cat(
-        (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
-        dim=1,
-    )

-    # Since 1.0, we use <|semantic|>
-    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
-    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
-    main_token_ids = (
-        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
+    conversation = Conversation(messages=messages)
+    # conversation.visualize(tokenizer)
+    encoded = conversation.encode_for_inference(
+        tokenizer=tokenizer,
+        num_codebooks=num_codebooks,
     )
-    main_token_ids[0, -1] = end_token_id
-
-    data = torch.cat((main_token_ids, data), dim=0)
-    prompt = torch.cat((prompt, data), dim=1)

-    return prompt
+    return encoded.to(device)


 def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
     model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
-        checkpoint_path, load_weights=True
+        checkpoint_path, load_weights=True, is_agent=is_agent
     )

     model = model.to(device=device, dtype=precision)
@@ -729,11 +744,26 @@ def generate_long(

     model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
     tokenizer = model.tokenizer
-    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+    im_end_id = tokenizer.get_token_id("<|im_end|>")

     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
-    encoded_prompts = []
+    encoded_prompts = [
+        Conversation(
+            messages=[
+                Message(
+                    role="system",
+                    parts=[TextPart(text="Speak out the provided text.")],
+                    cal_loss=False,
+                )
+            ]
+        )
+        .encode_for_inference(
+            tokenizer=tokenizer,
+            num_codebooks=model.config.num_codebooks,
+        )
+        .to(device)
+    ]

     if use_prompt:
         for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
@@ -812,7 +842,6 @@ def generate_long(
                 model=model,
                 prompt=cat_encoded,
                 max_new_tokens=max_new_tokens,
-                im_end_id=im_end_id,
                 decode_one_token=decode_one_token,
                 temperature=temperature,
                 top_p=top_p,
@@ -842,12 +871,11 @@ def generate_long(
             )

             # Put the generated tokens
-            # since there is <im_end> and <eos> tokens, we remove last 2 tokens
-            codes = y[1:, prompt_length:-1].clone()
-            codes = codes - 1
+            # since there is <im_end>, we remove last token
+            codes = y[1:, prompt_length + 1 :].clone()
             assert (codes >= 0).all(), f"Negative code found"

-            decoded = y[:, prompt_length:-1].clone()
+            decoded = y[:, prompt_length:].clone()
             # But for global encoding, we should keep the <im_end> token

             global_encoded.append(decoded)

xinference/thirdparty/fish_speech/tools/run_webui.py:
@@ -0,0 +1,104 @@
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
+import pyrootutils
+import torch
+from loguru import logger
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+from tools.inference_engine import TTSInferenceEngine
+from tools.llama.generate import launch_thread_safe_queue
+from tools.schema import ServeTTSRequest
+from tools.vqgan.inference import load_model as load_decoder_model
+from tools.webui import build_app
+from tools.webui.inference import get_inference_wrapper
+
+# Make einx happy
+os.environ["EINX_FILTER_TRACEBACK"] = "false"
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--llama-checkpoint-path",
+        type=Path,
+        default="checkpoints/fish-speech-1.5",
+    )
+    parser.add_argument(
+        "--decoder-checkpoint-path",
+        type=Path,
+        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+    )
+    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--half", action="store_true")
+    parser.add_argument("--compile", action="store_true")
+    parser.add_argument("--max-gradio-length", type=int, default=0)
+    parser.add_argument("--theme", type=str, default="light")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    args.precision = torch.half if args.half else torch.bfloat16
+
+    # Check if MPS or CUDA is available
+    if torch.backends.mps.is_available():
+        args.device = "mps"
+        logger.info("mps is available, running on mps.")
+    elif not torch.cuda.is_available():
+        logger.info("CUDA is not available, running on CPU.")
+        args.device = "cpu"
+
+    logger.info("Loading Llama model...")
+    llama_queue = launch_thread_safe_queue(
+        checkpoint_path=args.llama_checkpoint_path,
+        device=args.device,
+        precision=args.precision,
+        compile=args.compile,
+    )
+
+    logger.info("Loading VQ-GAN model...")
+    decoder_model = load_decoder_model(
+        config_name=args.decoder_config_name,
+        checkpoint_path=args.decoder_checkpoint_path,
+        device=args.device,
+    )
+
+    logger.info("Decoder model loaded, warming up...")
+
+    # Create the inference engine
+    inference_engine = TTSInferenceEngine(
+        llama_queue=llama_queue,
+        decoder_model=decoder_model,
+        compile=args.compile,
+        precision=args.precision,
+    )
+
+    # Dry run to check if the model is loaded correctly and avoid the first-time latency
+    list(
+        inference_engine.inference(
+            ServeTTSRequest(
+                text="Hello world.",
+                references=[],
+                reference_id=None,
+                max_new_tokens=1024,
+                chunk_length=200,
+                top_p=0.7,
+                repetition_penalty=1.5,
+                temperature=0.7,
+                format="wav",
+            )
+        )
+    )
+
+    logger.info("Warming up done, launching the web UI...")
+
+    # Get the inference function with the immutable arguments
+    inference_fct = get_inference_wrapper(inference_engine)
+
+    app = build_app(inference_fct, args.theme)
+    app.launch(show_api=True)