xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (191)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,496 @@
1
+ import random
2
+ from dataclasses import dataclass
3
+ from itertools import chain
4
+ from pathlib import Path
5
+ from random import Random
6
+ from typing import Optional, Union
7
+
8
+ import numpy as np
9
+ import pyarrow.parquet as pq
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from datasets.download.streaming_download_manager import xopen
13
+ from huggingface_hub import HfApi
14
+ from lightning import LightningDataModule
15
+ from torch.distributed import get_rank, get_world_size, is_initialized
16
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info
17
+ from transformers import AutoTokenizer
18
+
19
+ from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
20
+ from fish_speech.datasets.protos.text_data_pb2 import SampledData
21
+ from fish_speech.datasets.protos.text_data_stream import read_pb_stream
22
+ from fish_speech.text.clean import clean_text
23
+ from fish_speech.utils import RankedLogger
24
+ from fish_speech.utils.braceexpand import braceexpand
25
+
26
+ log = RankedLogger(__name__, rank_zero_only=True)
27
+
28
+
29
def split_by_rank_worker(files):
    """Return the slice of ``files`` that belongs to this process and worker.

    The list is sharded first across distributed ranks (when
    ``torch.distributed`` is initialized) and then across DataLoader workers
    (when called inside a worker process). If there are fewer files than
    rank/worker slots, the list is repeated so every slot gets at least one.

    Args:
        files: sequence of files to shard. Must support ``len``, ``*`` and
            extended slicing (e.g. a list).

    Returns:
        The sub-sequence owned by the current rank/worker. An empty input is
        returned unchanged.
    """
    # Nothing to shard; this also avoids dividing by len(files) == 0 below.
    if len(files) == 0:
        return files

    distributed = is_initialized()
    worker_info = get_worker_info()

    # We need to know the total number of consumers (ranks x workers)
    # to split the data properly.
    total_devices = get_world_size() if distributed else 1
    if worker_info is not None:
        total_devices *= worker_info.num_workers

    if len(files) < total_devices:
        # Repeat the files N times to match the number of devices
        files = files * (total_devices // len(files) + 1)

    # DDP: every rank takes a strided slice
    if distributed:
        files = files[get_rank() :: get_world_size()]

    # Split by worker: every worker takes a strided slice of the rank's share
    if worker_info is not None:
        files = files[worker_info.id :: worker_info.num_workers]

    return files
54
+
55
+
56
class AutoTextSemanticInstructionDataset(IterableDataset):
    """
    Auto Augment Dataset by Speaker

    1. Random concatenate multiple sentences from the same speaker to form a longer sentence
    2. Automatically normalize the text

    For interactive mode, we use the following format (multiple sequences):
    <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>

    For non-interactive mode, we use the following format (one long sequence):
    <s> [INST] text [/INST] ... </s>

    NOTE(review): the format lines above look historical. ``pack_sentences``
    below builds prompts with ``<|im_start|>`` / ``<|im_end|>`` markers, not
    ``[INST]`` tags.

    The dataset is an endless stream: ``__iter__`` never terminates and each
    item is produced by ``augment``.
    """

    def __init__(
        self,
        proto_files: list[str],
        seed: int = 42,
        interactive_prob: float = 0.5,
        max_length: int = 1024,
        tokenizer: AutoTokenizer = None,
        use_speaker: bool | float = True,
        causal: bool = True,
        num_codebooks: Optional[int] = None,
        skip_text_prob: float = 0.0,
    ):
        """
        Args:
            proto_files: proto buf files if using local data
            seed: random seed
            interactive_prob: probability to use interactive mode
            max_length: max length of the text
            tokenizer: tokenizer
            use_speaker: include speaker information in the prompt
            causal: use causal sampling when using local data, disable will lead to random sampling
            num_codebooks: number of codebooks, if None, it will be automatically detected
            skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
        """

        super().__init__()

        assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"

        self.seed = seed
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.interactive_prob = interactive_prob
        self.use_speaker = use_speaker
        self.proto_files = proto_files
        self.causal = causal
        self.num_codebooks = num_codebooks
        self.skip_text_prob = skip_text_prob

        # The tokenizer is dereferenced right here, so although the signature
        # defaults it to None, a real tokenizer must be supplied.
        self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
        # Loaded lazily by init_mock_data_server() on the first sample_data() call.
        self.groups = None

    def init_mock_data_server(self):
        """Load this rank/worker's share of proto groups into ``self.groups``.

        Does nothing if the groups are already loaded. Also fills
        ``self.group_weights`` with the sentence count of every group.

        Raises:
            ValueError: if an expanded entry of ``proto_files`` is neither a
                file nor a directory.
        """
        if self.groups is not None:
            return

        # Expand the proto files
        expanded_proto_files = []
        for filename in self.proto_files:
            for i in braceexpand(filename):
                i = Path(i)
                if i.is_file():
                    expanded_proto_files.append(i)
                elif i.is_dir():
                    # Directories are searched recursively for both extensions.
                    expanded_proto_files.extend(i.rglob("*.proto"))
                    expanded_proto_files.extend(i.rglob("*.protos"))
                else:
                    raise ValueError(f"{i} is not a file or directory")

        # Sort first so the seeded shuffle starts from the same order everywhere.
        expanded_proto_files = sorted(expanded_proto_files)
        Random(self.seed).shuffle(expanded_proto_files)

        self.groups = []
        shard_proto_files = split_by_rank_worker(expanded_proto_files)
        log.info(
            f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
        )

        count = 0
        for filename in shard_proto_files:
            with open(filename, "rb") as f:
                for text_data in read_pb_stream(f):
                    self.groups.append(text_data)
                    count += 1

        log.info(f"Read total {count} groups of data")

        # Shuffle the lines
        Random(self.seed).shuffle(self.groups)
        # Sampling weight of a group = how many sentences it holds.
        self.group_weights = [len(i.sentences) for i in self.groups]

    def __iter__(self):
        """Yield augmented samples forever.

        NOTE(review): ``augment`` can return None (empty sample list), and
        that None is yielded as-is.
        """
        while True:
            yield self.augment()

    def tokenize_sentence(self, sentence: str):
        """Clean ``sentence`` and return ``(cleaned_sentence, token_count)``."""
        sentence = clean_text(sentence)
        tokens = self.tokenizer.encode(
            f"{sentence}",
            max_length=10**6,
            add_special_tokens=False,
            truncation=False,
        )
        return sentence, len(tokens)

    def sample_data(self):
        """Pick one group and return a ``SampledData`` with some of its sentences.

        The group is drawn with probability proportional to its sentence
        count. In causal mode a contiguous window of sentences is taken;
        otherwise sentences are drawn with ``random.choices``.
        """
        if self.groups is None:
            self.init_mock_data_server()

        # Shuffle unique lines, estimate that each sample is at least 20 tokens
        num_samples = self.max_length // 20

        # choice group based on their number of samples
        group = random.choices(self.groups, weights=self.group_weights, k=1)[0]

        if self.causal:
            # Sample in order
            if num_samples >= len(group.sentences):
                samples = group.sentences
            else:
                # random.randint includes both ends, so the last full window is reachable.
                begin = random.randint(0, len(group.sentences) - num_samples)
                samples = group.sentences[begin : begin + num_samples]
        else:
            # random.choices draws with replacement, so repeats are possible.
            samples = random.choices(
                group.sentences, k=min(num_samples, len(group.sentences))
            )

        return SampledData(
            source=group.source,
            name=group.name,
            samples=samples,
        )

    def augment(self):
        """Build one training sample as ``{"tokens": ..., "labels": ...}``.

        Interactive mode packs every sentence as its own prompt/response
        segment; non-interactive mode packs all chosen sentences into a single
        segment. Segments are concatenated along dim 1.

        Returns:
            The sample dict, or None when the sampled group has no sentences.
        """
        final_text, final_semantic = [], []
        response = self.sample_data()
        if len(response.samples) == 0:
            # Invalid group
            return None

        samples = list(response.samples)
        # NOTE(review): idx is incremented in the loop below but never read.
        idx = 0
        use_interactive = random.random() < self.interactive_prob

        if use_interactive is False:
            # Random sample based on speaker using a truncated normal distribution
            # The token budget is drawn from N(max_length // 2, max_length // 4)
            # truncated to [10, max_length], then reduced by 4.
            a = torch.tensor([0], dtype=torch.float32)
            torch.nn.init.trunc_normal_(
                a,
                mean=self.max_length // 2,
                std=self.max_length // 4,
                a=10,
                b=self.max_length,
            )
            remaining_tokens = a.long().item() - 4
        else:
            remaining_tokens = self.max_length

        # Use speaker
        # A float use_speaker is treated as a probability; a bool is used directly.
        if isinstance(self.use_speaker, float):
            use_speaker = random.random() < self.use_speaker
        else:
            use_speaker = self.use_speaker

        all_tokens, all_labels = [], []
        while remaining_tokens > 0 and len(samples) > 0:
            sentence = samples.pop(0)

            text = random.choice(sentence.texts)
            text, length = self.tokenize_sentence(text)
            # The budget is charged the text tokens plus the length of the first codebook.
            remaining_tokens -= length + len(sentence.semantics[0].values)

            if use_interactive is False:
                final_text.append(text)
                final_semantic.append(sentence.semantics)
            else:
                # For interactive mode, we only apply speaker for the first sentence
                # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
                # NOTE(review): the call below passes the speaker for every
                # sentence, not just the first - confirm which is intended.
                tokens, labels = self.pack_sentences(
                    sentences=[text],
                    semantics=[sentence.semantics],
                    speaker=response.name if use_speaker else None,
                    skip_text=random.random() < self.skip_text_prob,
                )

                all_tokens.append(tokens)
                all_labels.append(labels)

            idx += 1

        if use_interactive is False:
            tokens, labels = self.pack_sentences(
                final_text,
                semantics=final_semantic,
                speaker=response.name if use_speaker else None,
            )
            all_tokens.append(tokens)
            all_labels.append(labels)

        tokens = torch.cat(all_tokens, dim=1)
        labels = torch.cat(all_labels, dim=1)

        # Verify that the length is correct
        assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"

        data = {"tokens": tokens, "labels": labels}

        return data

    def pack_sentences(
        self,
        sentences: list[str],
        semantics: list,
        speaker: Optional[str] = None,
        skip_text: bool = False,
    ):
        """Pack text and semantic codes into aligned ``(tokens, labels)`` tensors.

        Row 0 holds the text-side token ids (prompt, one ``<|semantic|>``
        placeholder per semantic frame, then ``<|im_end|>``); rows 1.. hold
        one codebook each.

        Args:
            sentences: cleaned sentences, joined with single spaces.
            semantics: one entry per sentence; each entry is indexed as a
                sequence of codebooks exposing ``.values``.
            speaker: role name written after the user turn; "assistant" when None.
            skip_text: replace the text with ``<|skip_text|>`` and mask every label.

        Returns:
            ``(tokens, labels)``. When ``skip_text`` is False the two are
            shifted by one position against each other (next-token targets);
            when True they are returned unshifted with all labels set to -100.
        """
        if speaker is None:
            speaker = "assistant"

        cated_sentences = " ".join(sentences)
        if skip_text:
            cated_sentences = "<|skip_text|>"

        final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
        final_text = final_text + f"<|im_start|>{speaker}\n"

        encoded = self.tokenizer.encode(
            final_text,
            add_special_tokens=False,
            truncation=False,
            max_length=10**6,
        )
        # Length is measured on the first codebook of every segment.
        semantic_length = sum([len(i[0].values) for i in semantics])
        prompt_length = len(encoded)
        num_codebooks = (
            len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
        )

        # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
        # In this code the closing marker appended is <|im_end|>.
        tokens = (
            encoded
            + [self.semantic_token_id] * semantic_length
            + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
        )

        # Codebook bos/padding: 0, eos: 1
        # Codebook rows are padded over the prompt; every code is stored as value + 1.
        codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
        for segment in semantics:
            for book_idx, book in zip(range(num_codebooks), segment):
                for j in book.values:
                    codes[book_idx].append(int(j) + 1)

        # One trailing pad per codebook, aligned with the closing <|im_end|> in row 0.
        for book in codes:
            book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)

        tokens = [tokens] + codes

        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = tokens.clone()

        if skip_text:
            # If text is not provided, the sentence is used for condition only, all labels are -100
            torch.fill_(labels, -100)
            return tokens, labels

        # Mask out the <s> tokens for semantic, predict semantic tokens only
        # Since we don't mask out the input tokens, the language modeling still works
        labels[1:, :prompt_length] = -100

        # Shift by one: inputs drop the last column, targets drop the first.
        tokens = tokens[:, :-1]
        labels = labels[:, 1:]

        # Verify the padding is correct, and the last token is eos
        assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
        assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()

        return tokens, labels
337
+
338
+
339
@dataclass
class TextDataCollator:
    """Collate dataset samples into a padded batch.

    Every sample is a dict with 2-D ``tokens`` and ``labels`` tensors whose
    second dimension is the sequence length. Samples may also carry
    ``negative_tokens`` / ``negative_labels``; those are then appended to the
    batch after all positive samples.
    """

    # Supplies ``eos_token_id``, used to pad the first (text) row.
    tokenizer: AutoTokenizer
    # Upper bound on the padded sequence length; longer samples are truncated.
    max_length: int = 1024

    def __call__(self, examples):
        """Build a batch from ``examples`` (a list of sample dicts).

        Returns:
            The dict produced by ``batchify``.
        """
        # ``examples`` is a list of dicts, so the key must be looked up on a
        # sample. Testing ``"negative_tokens" in examples`` compared the string
        # against the dicts themselves and could never match.
        if examples and "negative_tokens" in examples[0]:
            positive_examples = [
                {"tokens": i["tokens"], "labels": i["labels"]} for i in examples
            ]
            negative_examples = [
                {"tokens": i["negative_tokens"], "labels": i["negative_labels"]}
                for i in examples
            ]

            # Positives first, negatives second, in the same batch.
            examples = positive_examples + negative_examples

        return self.batchify(examples)

    def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
        """Truncate/pad every sample to a common length and stack them.

        Args:
            examples: list of sample dicts.
            tokens_key: key of the token tensor in each sample.
            labels_key: key of the label tensor in each sample.

        Returns:
            Dict with ``inputs`` and ``labels`` (stacked along a new dim 0) and
            ``attention_masks``, a bool tensor that is True on padded positions.
        """
        tokens, attention_masks, labels = [], [], []

        # Pad to the longest sample in the batch, capped at max_length.
        max_tokens_length = 0
        for example in examples:
            max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
        max_tokens_length = min(max_tokens_length, self.max_length)

        for example in examples:
            _tokens = example[tokens_key][:, :max_tokens_length]
            _labels = example[labels_key][:, :max_tokens_length]
            # Start fully masked, then unmask the positions that hold real tokens.
            _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
            tokens_length = _tokens.size(1)
            _attention_mask[:tokens_length] = False

            assert tokens_length == _labels.size(
                1
            ), f"{tokens_length} != {_labels.size(1)}"

            if tokens_length < max_tokens_length:
                # Row 0 is padded with EOS; the remaining (codebook) rows are
                # overwritten with the codebook pad id.
                _tokens = F.pad(
                    _tokens,
                    (0, max_tokens_length - tokens_length),
                    value=self.tokenizer.eos_token_id,
                )
                _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
                # -100 marks label positions that carry no target.
                _labels = F.pad(
                    _labels, (0, max_tokens_length - _labels.size(1)), value=-100
                )

            tokens.append(_tokens)
            attention_masks.append(_attention_mask)
            labels.append(_labels)

        tokens = torch.stack(tokens, dim=0)
        attention_masks = torch.stack(attention_masks, dim=0)
        labels = torch.stack(labels, dim=0)

        return {
            "inputs": tokens,
            "attention_masks": attention_masks,
            "labels": labels,
        }
411
+
412
+
413
class InterleaveDataset(IterableDataset):
    """Endless stream that mixes several datasets at random.

    For every item, one source dataset is picked according to
    ``probabilities`` using a NumPy generator seeded with ``seed``. A source
    that runs out is restarted from its beginning.
    """

    def __init__(
        self,
        datasets: list[IterableDataset],
        probabilities: list[float],
        seed: int = 42,
    ):
        super().__init__()

        self.datasets, self.probabilities, self.seed = datasets, probabilities, seed

    def __iter__(self):
        generator = np.random.default_rng(self.seed)
        streams = list(map(iter, self.datasets))

        while True:
            # Draw the index of the source that provides the next item.
            chosen = generator.choice(len(self.datasets), p=self.probabilities)

            try:
                item = next(streams[chosen])
            except StopIteration:
                # Source exhausted: reopen it and take its first item.
                streams[chosen] = iter(self.datasets[chosen])
                item = next(streams[chosen])

            yield item
441
+
442
+
443
class SemanticDataModule(LightningDataModule):
    """Lightning data module wrapping a train and a validation dataset.

    Both loaders use ``TextDataCollator(tokenizer, max_length)`` as their
    collate function.
    """

    def __init__(
        self,
        train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
        val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
        batch_size: int = 32,
        tokenizer: AutoTokenizer = None,
        max_length: int = 1024,
        num_workers: int = 4,
    ):
        """
        Args:
            train_dataset: dataset served by ``train_dataloader``.
            val_dataset: dataset served by ``val_dataloader``.
            batch_size: batch size of both loaders.
            tokenizer: tokenizer handed to the collator.
            max_length: maximum padded sequence length handed to the collator.
            num_workers: DataLoader worker processes; 0 loads in the main process.
        """
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_workers = num_workers

    def _make_loader(self, dataset):
        """Create a DataLoader over ``dataset`` with the module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
            num_workers=self.num_workers,
            # DataLoader raises ValueError for persistent_workers=True with
            # num_workers=0, so only keep workers alive when there are any.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset)
479
+
480
+
481
if __name__ == "__main__":
    # Manual smoke test: build the dataset from local protos and print the
    # decoded text row of the first generated sample.
    from tqdm import tqdm

    demo_tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
    ds = AutoTextSemanticInstructionDataset(
        ["data/protos"],
        tokenizer=demo_tokenizer,
        use_speaker=False,
        interactive_prob=1.0,
        skip_text_prob=0.5,
    )

    # The dataset is endless, so only the first sample is taken.
    first_sample = next(iter(ds))
    print(ds.tokenizer.decode(first_sample["tokens"][0], skip_special_tokens=False))
    # To inspect the labels as well, replace the -100 entries of
    # first_sample["labels"][0] with 0 before decoding them.
@@ -0,0 +1,147 @@
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import torch
8
+ from lightning import LightningDataModule
9
+ from torch.utils.data import DataLoader, Dataset
10
+
11
+ from fish_speech.utils import RankedLogger
12
+
13
+ logger = RankedLogger(__name__, rank_zero_only=False)
14
+
15
+
16
class VQGANDataset(Dataset):
    """Audio dataset driven by a file list.

    The file list is a UTF-8 text file with one audio path per line, relative
    to the directory that contains the list; blank lines are ignored.
    """

    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        """
        Args:
            filelist: path of the file list.
            sample_rate: sample rate passed to ``librosa.load``.
            hop_length: samples per frame, used to size the random crop.
            slice_frames: crop length in frames; None disables cropping.
        """
        super().__init__()

        list_path = Path(filelist)
        base_dir = list_path.parent

        entries = []
        for raw_line in list_path.read_text(encoding="utf-8").splitlines():
            relative = raw_line.strip()
            if relative:
                entries.append(base_dir / relative)

        self.files = entries
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.slice_frames = slice_frames

    def __len__(self):
        return len(self.files)

    def get_item(self, idx):
        """Load item ``idx`` as ``{"audio": tensor}``; None for empty audio."""
        path = self.files[idx]

        waveform, _ = librosa.load(path, sr=self.sample_rate, mono=True)

        # Randomly crop clips that are longer than the configured window.
        if self.slice_frames is not None:
            window = self.slice_frames * self.hop_length
            if waveform.shape[0] > window:
                offset = np.random.randint(0, waveform.shape[0] - window)
                waveform = waveform[offset : offset + window]

        if len(waveform) == 0:
            return None

        # Scale down only when the peak exceeds 1.0.
        peak = np.abs(waveform).max()
        if peak > 1.0:
            waveform = waveform / peak

        return {"audio": torch.from_numpy(waveform)}

    def __getitem__(self, idx):
        """Like ``get_item``, but failures are logged and turned into None."""
        try:
            return self.get_item(idx)
        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None
76
+
77
+
78
@dataclass
class VQGANCollator:
    """Collate ``{"audio": 1-D tensor}`` items into one right-padded batch."""

    def __call__(self, batch):
        """Drop None items, pad every audio to the longest one and stack them.

        Returns:
            Dict with ``audios`` (the stacked, padded audio) and
            ``audio_lengths`` (the original length of every audio).
        """
        waveforms = [item["audio"] for item in batch if item is not None]

        audio_lengths = torch.tensor([len(wave) for wave in waveforms])
        longest = int(audio_lengths.max())

        padded = [
            torch.nn.functional.pad(wave, (0, longest - len(wave)))
            for wave in waveforms
        ]

        return {
            "audios": torch.stack(padded),
            "audio_lengths": audio_lengths,
        }
97
+
98
+
99
class VQGANDataModule(LightningDataModule):
    """Lightning data module serving VQGAN train/validation audio batches.

    Both loaders collate with ``VQGANCollator``; only the training loader
    shuffles.
    """

    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        """
        Args:
            train_dataset: dataset served by ``train_dataloader``.
            val_dataset: dataset served by ``val_dataloader``.
            batch_size: training batch size.
            num_workers: DataLoader worker processes; 0 loads in the main process.
            val_batch_size: validation batch size; falls back to ``batch_size``
                when None or 0.
        """
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=True,
            # DataLoader raises ValueError for persistent_workers=True with
            # num_workers=0, so only keep workers alive when there are any.
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            # Same guard as in train_dataloader.
            persistent_workers=self.num_workers > 0,
        )
134
+
135
+
136
if __name__ == "__main__":
    # Manual smoke test: load one batch from a local file list and print it.
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        # VQGANCollator only returns "audios" and "audio_lengths"; the former
        # prints of "features" / "feature_lengths" raised KeyError.
        print(batch["audios"].shape)
        print(batch["audio_lengths"])
        break
@@ -0,0 +1,3 @@
1
+ from .core import i18n
2
+
3
+ __all__ = ["i18n"]
@@ -0,0 +1,40 @@
1
+ import json
2
+ import locale
3
+ from pathlib import Path
4
+
5
+ I18N_FILE_PATH = Path(__file__).parent / "locale"
6
+ DEFAULT_LANGUAGE = "en_US"
7
+
8
+
9
def load_language_list(language):
    """Return the parsed content of ``<language>.json`` from the locale directory."""
    locale_file = I18N_FILE_PATH / f"{language}.json"
    return json.loads(locale_file.read_text(encoding="utf-8"))
14
+
15
+
16
class I18nAuto:
    """Translate UI strings using the JSON table of the active language.

    The language is read from a ``.locale`` file in the current working
    directory when that file exists, otherwise from the system locale. If no
    JSON table exists for the resulting language, ``DEFAULT_LANGUAGE`` is used.
    """

    def __init__(self):
        i18n_file = Path(".locale")

        if i18n_file.exists():
            with open(i18n_file, "r", encoding="utf-8") as f:
                language = f.read().strip()
        else:
            language = self._detect_system_language()

        if not (I18N_FILE_PATH / f"{language}.json").exists():
            language = DEFAULT_LANGUAGE

        self.language = language
        self.language_map = load_language_list(language)

    @staticmethod
    def _detect_system_language():
        """Return the system language code (e.g. ``"en_US"``) or None if unknown."""
        # getlocale can't identify the system's language ((None, None)), so
        # getdefaultlocale is tried first. It is deprecated and may disappear
        # (AttributeError), and it raises ValueError for locale names it cannot
        # parse. Neither may break the import of this module.
        try:
            return locale.getdefaultlocale()[0]
        except (AttributeError, ValueError):
            pass

        try:
            return locale.getlocale()[0]
        except ValueError:
            return None

    def __call__(self, key):
        """Return the translation of ``key``, or ``key`` itself if there is none."""
        return self.language_map.get(key, key)

    def __repr__(self):
        return "Use Language: " + self.language
38
+
39
+
40
+ i18n = I18nAuto()