xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of xinference has been flagged as potentially problematic; see the registry page for details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +2 -1
- xinference/core/model.py +8 -4
- xinference/core/supervisor.py +2 -3
- xinference/core/worker.py +7 -5
- xinference/deploy/cmdline.py +2 -0
- xinference/deploy/local.py +5 -0
- xinference/deploy/test/test_cmdline.py +1 -1
- xinference/deploy/worker.py +6 -0
- xinference/model/audio/cosyvoice.py +0 -1
- xinference/model/audio/model_spec.json +44 -20
- xinference/model/core.py +3 -0
- xinference/model/embedding/flag/core.py +5 -0
- xinference/model/embedding/llama_cpp/core.py +22 -19
- xinference/model/embedding/sentence_transformers/core.py +18 -4
- xinference/model/embedding/vllm/core.py +36 -9
- xinference/model/image/cache_manager.py +56 -0
- xinference/model/image/core.py +9 -0
- xinference/model/image/model_spec.json +178 -1
- xinference/model/image/stable_diffusion/core.py +155 -23
- xinference/model/llm/cache_manager.py +17 -3
- xinference/model/llm/harmony.py +245 -0
- xinference/model/llm/llama_cpp/core.py +41 -40
- xinference/model/llm/llm_family.json +688 -11
- xinference/model/llm/llm_family.py +1 -1
- xinference/model/llm/sglang/core.py +108 -5
- xinference/model/llm/transformers/core.py +20 -18
- xinference/model/llm/transformers/gemma3.py +1 -1
- xinference/model/llm/transformers/gpt_oss.py +91 -0
- xinference/model/llm/transformers/multimodal/core.py +1 -1
- xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
- xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
- xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
- xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
- xinference/model/llm/transformers/utils.py +1 -33
- xinference/model/llm/utils.py +61 -7
- xinference/model/llm/vllm/core.py +44 -8
- xinference/model/rerank/__init__.py +66 -23
- xinference/model/rerank/cache_manager.py +35 -0
- xinference/model/rerank/core.py +87 -339
- xinference/model/rerank/custom.py +33 -8
- xinference/model/rerank/model_spec.json +251 -212
- xinference/model/rerank/rerank_family.py +137 -0
- xinference/model/rerank/sentence_transformers/__init__.py +13 -0
- xinference/model/rerank/sentence_transformers/core.py +337 -0
- xinference/model/rerank/vllm/__init__.py +13 -0
- xinference/model/rerank/vllm/core.py +156 -0
- xinference/model/utils.py +108 -0
- xinference/model/video/model_spec.json +95 -1
- xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
- xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
- xinference/thirdparty/cosyvoice/bin/train.py +23 -3
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
- xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
- xinference/thirdparty/cosyvoice/cli/model.py +53 -75
- xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
- xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
- xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
- xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
- xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
- xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
- xinference/thirdparty/cosyvoice/utils/common.py +20 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
- xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
- xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
- xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
- xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
- xinference/types.py +2 -0
- xinference/ui/gradio/chat_interface.py +2 -0
- xinference/ui/gradio/media_interface.py +353 -7
- xinference/ui/web/ui/build/asset-manifest.json +3 -3
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
- xinference/ui/web/ui/src/locales/en.json +2 -0
- xinference/ui/web/ui/src/locales/ja.json +2 -0
- xinference/ui/web/ui/src/locales/ko.json +2 -0
- xinference/ui/web/ui/src/locales/zh.json +2 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
- xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
- xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
- /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0
xinference/model/video/model_spec.json

@@ -224,7 +224,7 @@
     },
     "virtualenv": {
       "packages": [
-        "
+        "diffusers==0.35.1",
         "ftfy",
         "imageio-ffmpeg",
         "imageio",
@@ -241,5 +241,99 @@
         "model_revision": "master"
       }
     }
+  },
+  {
+    "version": 2,
+    "model_name": "Wan2.2-A14B",
+    "model_family": "Wan",
+    "model_ability": [
+      "text2video"
+    ],
+    "default_model_config": {
+      "torch_dtype": "bfloat16"
+    },
+    "default_generate_config": {},
+    "virtualenv": {
+      "packages": [
+        "diffusers==0.35.1",
+        "ftfy",
+        "imageio-ffmpeg",
+        "imageio",
+        "#system_numpy#"
+      ]
+    },
+    "model_src": {
+      "huggingface": {
+        "model_id": "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+        "model_revision": "5be7df9619b54f4e2667b2755bc6a756675b5cd7"
+      },
+      "modelscope": {
+        "model_id": "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+        "model_revision": "master"
+      }
+    }
+  },
+  {
+    "version": 2,
+    "model_name": "Wan2.2-i2v-A14B",
+    "model_family": "Wan",
+    "model_ability": [
+      "image2video"
+    ],
+    "default_model_config": {
+      "torch_dtype": "bfloat16"
+    },
+    "default_generate_config": {},
+    "virtualenv": {
+      "packages": [
+        "diffusers==0.35.1",
+        "ftfy",
+        "imageio-ffmpeg",
+        "imageio",
+        "#system_numpy#"
+      ]
+    },
+    "model_src": {
+      "huggingface": {
+        "model_id": "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
+        "model_revision": "596658fd9ca6b7b71d5057529bbf319ecbc61d74"
+      },
+      "modelscope": {
+        "model_id": "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
+        "model_revision": "master"
+      }
+    }
+  },
+  {
+    "version": 2,
+    "model_name": "Wan2.2-ti2v-5B",
+    "model_family": "Wan",
+    "model_ability": [
+      "text2video",
+      "image2video"
+    ],
+    "default_model_config": {
+      "torch_dtype": "bfloat16"
+    },
+    "default_generate_config": {},
+    "virtualenv": {
+      "packages": [
+        "diffusers==0.35.1",
+        "ftfy",
+        "imageio-ffmpeg",
+        "imageio",
+        "#system_numpy#"
+      ]
+    },
+    "model_src": {
+      "huggingface": {
+        "model_id": "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
+        "model_revision": "b8fff7315c768468a5333511427288870b2e9635"
+      },
+      "modelscope": {
+        "model_id": "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
+        "model_revision": "master"
+      }
+    }
   }
 ]
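The three new Wan2.2 entries make the Wan text-to-video and image-to-video checkpoints launchable by their spec names. As a rough illustration only (not taken from this release's documentation), launching one of them through the standard xinference Python client could look like the sketch below; the endpoint URL is a placeholder and the text_to_video call on the returned handle is an assumption based on the existing video-model API.

from xinference.client import Client

# Placeholder endpoint; point this at a running xinference server.
client = Client("http://localhost:9997")

# Launch one of the Wan2.2 entries added in this release by its spec name.
model_uid = client.launch_model(
    model_name="Wan2.2-ti2v-5B",
    model_type="video",
)

# Assumed: video model handles expose text_to_video(), as for earlier video models.
model = client.get_model(model_uid)
result = model.text_to_video(prompt="a red panda riding a bicycle through a bamboo forest")
print(result)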
xinference/thirdparty/cosyvoice/bin/export_jit.py

@@ -61,8 +61,7 @@ def main():
         model = CosyVoice(args.model_dir)
     except Exception:
         try:
-
-            model = CosyVoice2(args.model_dir, use_flow_cache=True)
+            model = CosyVoice2(args.model_dir)
         except Exception:
             raise TypeError('no valid model_type!')
 
@@ -93,9 +92,9 @@ def main():
     else:
         # 3. export flow encoder
         flow_encoder = model.model.flow.encoder
-        script = get_optimized_script(flow_encoder
+        script = get_optimized_script(flow_encoder)
         script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
-        script = get_optimized_script(flow_encoder.half()
+        script = get_optimized_script(flow_encoder.half())
         script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
         logging.info('successfully export flow_encoder')
 
xinference/thirdparty/cosyvoice/bin/export_onnx.py

@@ -62,135 +62,58 @@ def main():
         model = CosyVoice(args.model_dir)
     except Exception:
         try:
-
-            model = CosyVoice2(args.model_dir, use_flow_cache=True)
+            model = CosyVoice2(args.model_dir)
         except Exception:
             raise TypeError('no valid model_type!')
 
 [old lines 70-117: 48 removed lines whose content was not captured in this diff view]
-            logging.info('successfully export estimator')
-    else:
-        # 1. export flow decoder estimator
-        estimator = model.model.flow.decoder.estimator
-        estimator.forward = estimator.forward_chunk
-        estimator.eval()
-
-        device = model.model.device
-        batch_size, seq_len = 2, 256
-        out_channels = model.model.flow.decoder.estimator.out_channels
-        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
-        cache = model.model.init_flow_cache()['decoder_cache']
-        cache.pop('offset')
-        cache = {k: v[0] for k, v in cache.items()}
-        torch.onnx.export(
-            estimator,
-            (x, mask, mu, t, spks, cond,
-             cache['down_blocks_conv_cache'],
-             cache['down_blocks_kv_cache'],
-             cache['mid_blocks_conv_cache'],
-             cache['mid_blocks_kv_cache'],
-             cache['up_blocks_conv_cache'],
-             cache['up_blocks_kv_cache'],
-             cache['final_blocks_conv_cache']),
-            '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
-            export_params=True,
-            opset_version=18,
-            do_constant_folding=True,
-            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
-                         'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
-            output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out',
-                          'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
-            dynamic_axes={
-                'x': {2: 'seq_len'},
-                'mask': {2: 'seq_len'},
-                'mu': {2: 'seq_len'},
-                'cond': {2: 'seq_len'},
-                'down_blocks_kv_cache': {3: 'cache_in_len'},
-                'mid_blocks_kv_cache': {3: 'cache_in_len'},
-                'up_blocks_kv_cache': {3: 'cache_in_len'},
-                'estimator_out': {2: 'seq_len'},
-                'down_blocks_kv_cache_out': {3: 'cache_out_len'},
-                'mid_blocks_kv_cache_out': {3: 'cache_out_len'},
-                'up_blocks_kv_cache_out': {3: 'cache_out_len'},
-            }
-        )
-
-        # 2. test computation consistency
-        option = onnxruntime.SessionOptions()
-        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-        option.intra_op_num_threads = 1
-        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
-        estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
-                                                      sess_options=option, providers=providers)
-
-        for iter in tqdm(range(10)):
-            x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
-            cache = model.model.init_flow_cache()['decoder_cache']
-            cache.pop('offset')
-            cache = {k: v[0] for k, v in cache.items()}
-            output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()})
-            ort_inputs = {
-                'x': x.cpu().numpy(),
-                'mask': mask.cpu().numpy(),
-                'mu': mu.cpu().numpy(),
-                't': t.cpu().numpy(),
-                'spks': spks.cpu().numpy(),
-                'cond': cond.cpu().numpy(),
-            }
-            output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}})
-            if iter == 0:
-                # NOTE why can not pass first iteration check?
-                continue
-            for i, j in zip(output_pytorch, output_onnx):
-                torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4)
-        logging.info('successfully export estimator')
+    # 1. export flow decoder estimator
+    estimator = model.model.flow.decoder.estimator
+    estimator.eval()
+
+    device = model.model.device
+    batch_size, seq_len = 2, 256
+    out_channels = model.model.flow.decoder.estimator.out_channels
+    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+    torch.onnx.export(
+        estimator,
+        (x, mask, mu, t, spks, cond),
+        '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+        export_params=True,
+        opset_version=18,
+        do_constant_folding=True,
+        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+        output_names=['estimator_out'],
+        dynamic_axes={
+            'x': {2: 'seq_len'},
+            'mask': {2: 'seq_len'},
+            'mu': {2: 'seq_len'},
+            'cond': {2: 'seq_len'},
+            'estimator_out': {2: 'seq_len'},
+        }
+    )
+
+    # 2. test computation consistency
+    option = onnxruntime.SessionOptions()
+    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    option.intra_op_num_threads = 1
+    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                  sess_options=option, providers=providers)
+
+    for _ in tqdm(range(10)):
+        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+        output_pytorch = estimator(x, mask, mu, t, spks, cond)
+        ort_inputs = {
+            'x': x.cpu().numpy(),
+            'mask': mask.cpu().numpy(),
+            'mu': mu.cpu().numpy(),
+            't': t.cpu().numpy(),
+            'spks': spks.cpu().numpy(),
+            'cond': cond.cpu().numpy()
+        }
+        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+        torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+    logging.info('successfully export estimator')
 
 
 if __name__ == "__main__":
xinference/thirdparty/cosyvoice/bin/train.py

@@ -27,6 +27,7 @@ from hyperpyyaml import load_hyperpyyaml
 
 from torch.distributed.elastic.multiprocessing.errors import record
 
+from cosyvoice.utils.losses import DPOLoss
 from cosyvoice.utils.executor import Executor
 from cosyvoice.utils.train_utils import (
     init_distributed,
@@ -43,6 +44,7 @@ def get_args():
                         choices=['torch_ddp', 'deepspeed'],
                         help='Engine for paralleled training')
     parser.add_argument('--model', required=True, help='model which will be trained')
+    parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
     parser.add_argument('--config', required=True, help='config file')
     parser.add_argument('--train_data', required=True, help='train data file')
     parser.add_argument('--cv_data', required=True, help='cv data file')
@@ -73,6 +75,10 @@ def get_args():
                         action='store_true',
                         default=False,
                         help='Use automatic mixed precision training')
+    parser.add_argument('--dpo',
+                        action='store_true',
+                        default=False,
+                        help='Use Direct Preference Optimization')
     parser.add_argument('--deepspeed.save_states',
                         dest='save_states',
                         default='model_only',
@@ -113,7 +119,7 @@ def main():
 
     # Get dataset & dataloader
     train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
-        init_dataset_and_dataloader(args, configs, gan)
+        init_dataset_and_dataloader(args, configs, gan, args.dpo)
 
     # Do some sanity checks and save config to arsg.model_dir
     configs = check_modify_and_save_config(args, configs)
@@ -122,6 +128,8 @@ def main():
     writer = init_summarywriter(args)
 
     # load checkpoint
+    if args.dpo is True:
+        configs[args.model].forward = configs[args.model].forward_dpo
     model = configs[args.model]
     start_step, start_epoch = 0, -1
     if args.checkpoint is not None:
@@ -150,13 +158,25 @@ def main():
     info_dict['epoch'] = start_epoch
     save_model(model, 'init', info_dict)
 
+    # DPO related
+    if args.dpo is True:
+        ref_model = deepcopy(configs[args.model])
+        state_dict = torch.load(args.ref_model, map_location='cpu')
+        ref_model.load_state_dict(state_dict, strict=False)
+        dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
+        # NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
+        ref_model = wrap_cuda_model(args, ref_model)
+    else:
+        ref_model, dpo_loss = None, None
+
     # Get executor
-    executor = Executor(gan=gan)
+    executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
     executor.step = start_step
 
     # Init scaler, used for pytorch amp mixed precision training
     scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
     print('start step {} start epoch {}'.format(start_step, start_epoch))
+
     # Start training loop
     for epoch in range(start_epoch + 1, info_dict['max_epoch']):
         executor.epoch = epoch
@@ -167,7 +187,7 @@ def main():
             executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
                                         writer, info_dict, scaler, group_join)
         else:
-            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
+            executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
         dist.destroy_process_group(group_join)
 
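The DPO wiring above depends on the new cosyvoice/utils/losses.py (+37 lines), which is not shown in this diff. For orientation only, a textbook DPO loss taking the same constructor-style arguments (beta, label_smoothing, ipo) typically looks like the sketch below; this is the standard formulation, not the actual contents of losses.py.

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps, reference_rejected_logps,
             beta=0.01, label_smoothing=0.0, ipo=False):
    # Log-ratio of chosen vs. rejected sequences under the policy and the frozen reference model.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios

    if ipo:
        # IPO variant: squared distance from the 1/(2*beta) margin.
        losses = (logits - 1.0 / (2.0 * beta)) ** 2
    else:
        # Conservative DPO with optional label smoothing.
        losses = (
            -F.logsigmoid(beta * logits) * (1.0 - label_smoothing)
            - F.logsigmoid(-beta * logits) * label_smoothing
        )

    # Implicit rewards are handy for logging; detached so they carry no gradient.
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return losses.mean(), chosen_rewards, rejected_rewards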
xinference/thirdparty/cosyvoice/cli/cosyvoice.py

@@ -26,7 +26,7 @@ from cosyvoice.utils.class_utils import get_model_type
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -59,6 +59,7 @@ class CosyVoice:
         if load_trt:
             self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                trt_concurrent,
                                 self.fp16)
         del configs
 
@@ -140,7 +141,7 @@ class CosyVoice:
 
 class CosyVoice2(CosyVoice):
 
-    def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False,
+    def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
         self.instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         self.fp16 = fp16
@@ -162,15 +163,18 @@ class CosyVoice2(CosyVoice):
         if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
             load_jit, load_trt, fp16 = False, False, False
             logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
-        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16
+        self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
         self.model.load('{}/llm.pt'.format(model_dir),
-                        '{}/flow.pt'.format(model_dir)
+                        '{}/flow.pt'.format(model_dir),
                         '{}/hift.pt'.format(model_dir))
+        if load_vllm:
+            self.model.load_vllm('{}/vllm'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
         if load_trt:
             self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
                                 '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+                                trt_concurrent,
                                 self.fp16)
         del configs
 
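For the updated CosyVoice2 constructor, a hypothetical call exercising the new load_vllm and trt_concurrent parameters is sketched below; the model directory is a placeholder, and the comments on what each flag triggers are inferred from the load_trt/load_vllm calls visible in this hunk rather than from project documentation.

from cosyvoice.cli.cosyvoice import CosyVoice2

# Placeholder path; substitute a real CosyVoice2 model directory.
model = CosyVoice2(
    'pretrained_models/CosyVoice2-0.5B',
    load_jit=False,
    load_trt=True,       # loads {model_dir}/flow.decoder.estimator.*.mygpu.plan
    load_vllm=True,      # new: loads the engine from {model_dir}/vllm via load_vllm()
    fp16=True,
    trt_concurrent=2,    # new: forwarded to load_trt (presumably concurrent TRT contexts)
)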
xinference/thirdparty/cosyvoice/cli/frontend.py

@@ -28,9 +28,9 @@ try:
     import ttsfrd
     use_ttsfrd = True
 except ImportError:
-    print("failed to import ttsfrd, use
-    from
-    from
+    print("failed to import ttsfrd, use wetext instead")
+    from wetext import Normalizer as ZhNormalizer
+    from wetext import Normalizer as EnNormalizer
     use_ttsfrd = False
 from cosyvoice.utils.file_utils import logging
 from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
@@ -68,7 +68,7 @@ class CosyVoiceFrontEnd:
                 'failed to initialize ttsfrd resource'
             self.frd.set_lang_type('pinyinvg')
         else:
-            self.zh_tn_model = ZhNormalizer(remove_erhua=False
+            self.zh_tn_model = ZhNormalizer(remove_erhua=False)
             self.en_tn_model = EnNormalizer()
             self.inflect_parser = inflect.engine()
 