xinference 1.8.1rc1__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (108)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +2 -1
  3. xinference/core/model.py +8 -4
  4. xinference/core/supervisor.py +2 -3
  5. xinference/core/worker.py +7 -5
  6. xinference/deploy/cmdline.py +2 -0
  7. xinference/deploy/local.py +5 -0
  8. xinference/deploy/test/test_cmdline.py +1 -1
  9. xinference/deploy/worker.py +6 -0
  10. xinference/model/audio/cosyvoice.py +0 -1
  11. xinference/model/audio/model_spec.json +44 -20
  12. xinference/model/core.py +3 -0
  13. xinference/model/embedding/flag/core.py +5 -0
  14. xinference/model/embedding/llama_cpp/core.py +22 -19
  15. xinference/model/embedding/sentence_transformers/core.py +18 -4
  16. xinference/model/embedding/vllm/core.py +36 -9
  17. xinference/model/image/cache_manager.py +56 -0
  18. xinference/model/image/core.py +9 -0
  19. xinference/model/image/model_spec.json +178 -1
  20. xinference/model/image/stable_diffusion/core.py +155 -23
  21. xinference/model/llm/cache_manager.py +17 -3
  22. xinference/model/llm/harmony.py +245 -0
  23. xinference/model/llm/llama_cpp/core.py +41 -40
  24. xinference/model/llm/llm_family.json +688 -11
  25. xinference/model/llm/llm_family.py +1 -1
  26. xinference/model/llm/sglang/core.py +108 -5
  27. xinference/model/llm/transformers/core.py +20 -18
  28. xinference/model/llm/transformers/gemma3.py +1 -1
  29. xinference/model/llm/transformers/gpt_oss.py +91 -0
  30. xinference/model/llm/transformers/multimodal/core.py +1 -1
  31. xinference/model/llm/transformers/multimodal/gemma3.py +1 -1
  32. xinference/model/llm/transformers/multimodal/glm4_1v.py +2 -2
  33. xinference/model/llm/transformers/multimodal/ovis2.py +1 -1
  34. xinference/model/llm/transformers/multimodal/qwen-omni.py +7 -8
  35. xinference/model/llm/transformers/multimodal/qwen2_vl.py +9 -6
  36. xinference/model/llm/transformers/utils.py +1 -33
  37. xinference/model/llm/utils.py +61 -7
  38. xinference/model/llm/vllm/core.py +44 -8
  39. xinference/model/rerank/__init__.py +66 -23
  40. xinference/model/rerank/cache_manager.py +35 -0
  41. xinference/model/rerank/core.py +87 -339
  42. xinference/model/rerank/custom.py +33 -8
  43. xinference/model/rerank/model_spec.json +251 -212
  44. xinference/model/rerank/rerank_family.py +137 -0
  45. xinference/model/rerank/sentence_transformers/__init__.py +13 -0
  46. xinference/model/rerank/sentence_transformers/core.py +337 -0
  47. xinference/model/rerank/vllm/__init__.py +13 -0
  48. xinference/model/rerank/vllm/core.py +156 -0
  49. xinference/model/utils.py +108 -0
  50. xinference/model/video/model_spec.json +95 -1
  51. xinference/thirdparty/cosyvoice/bin/export_jit.py +3 -4
  52. xinference/thirdparty/cosyvoice/bin/export_onnx.py +49 -126
  53. xinference/thirdparty/cosyvoice/bin/{inference.py → inference_deprecated.py} +1 -0
  54. xinference/thirdparty/cosyvoice/bin/train.py +23 -3
  55. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +8 -4
  56. xinference/thirdparty/cosyvoice/cli/frontend.py +4 -4
  57. xinference/thirdparty/cosyvoice/cli/model.py +53 -75
  58. xinference/thirdparty/cosyvoice/dataset/dataset.py +5 -18
  59. xinference/thirdparty/cosyvoice/dataset/processor.py +24 -25
  60. xinference/thirdparty/cosyvoice/flow/decoder.py +24 -433
  61. xinference/thirdparty/cosyvoice/flow/flow.py +6 -14
  62. xinference/thirdparty/cosyvoice/flow/flow_matching.py +33 -145
  63. xinference/thirdparty/cosyvoice/hifigan/generator.py +169 -1
  64. xinference/thirdparty/cosyvoice/llm/llm.py +108 -17
  65. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +14 -115
  66. xinference/thirdparty/cosyvoice/utils/common.py +20 -0
  67. xinference/thirdparty/cosyvoice/utils/executor.py +8 -4
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +45 -1
  69. xinference/thirdparty/cosyvoice/utils/losses.py +37 -0
  70. xinference/thirdparty/cosyvoice/utils/mask.py +35 -1
  71. xinference/thirdparty/cosyvoice/utils/train_utils.py +24 -6
  72. xinference/thirdparty/cosyvoice/vllm/cosyvoice2.py +103 -0
  73. xinference/types.py +2 -0
  74. xinference/ui/gradio/chat_interface.py +2 -0
  75. xinference/ui/gradio/media_interface.py +353 -7
  76. xinference/ui/web/ui/build/asset-manifest.json +3 -3
  77. xinference/ui/web/ui/build/index.html +1 -1
  78. xinference/ui/web/ui/build/static/js/main.1086c759.js +3 -0
  79. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +1 -0
  80. xinference/ui/web/ui/node_modules/.cache/babel-loader/28012da921a51f1082549956d3ae82acd769a754b22afda9acddd98a4daf9ea4.json +1 -0
  81. xinference/ui/web/ui/node_modules/.cache/babel-loader/3c5758bd12fa334294b1de0ff6b1a4bac8d963c45472eab9dc3e530d82aa6b3f.json +1 -0
  82. xinference/ui/web/ui/node_modules/.cache/babel-loader/475936ebe725eca62a6f52ce182c06a19b2cef4df9545a05ed0591ee0c539d43.json +1 -0
  83. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +1 -0
  84. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +1 -0
  85. xinference/ui/web/ui/node_modules/.cache/babel-loader/aee5aaba26f2b1e816a3ea9efa68bad8b95695a3d80adcfd8dd57a7bb17ac71a.json +1 -0
  86. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +1 -0
  87. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +1 -0
  88. xinference/ui/web/ui/src/locales/en.json +2 -0
  89. xinference/ui/web/ui/src/locales/ja.json +2 -0
  90. xinference/ui/web/ui/src/locales/ko.json +2 -0
  91. xinference/ui/web/ui/src/locales/zh.json +2 -0
  92. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/METADATA +15 -10
  93. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/RECORD +98 -89
  94. xinference/ui/web/ui/build/static/js/main.b969199a.js +0 -3
  95. xinference/ui/web/ui/build/static/js/main.b969199a.js.map +0 -1
  96. xinference/ui/web/ui/node_modules/.cache/babel-loader/1409a96b9f9f9f5de99a89ab0f738f6da62b449521b0a8d3e4efcf7f5c23534d.json +0 -1
  97. xinference/ui/web/ui/node_modules/.cache/babel-loader/3d2a89f0eccc1f90fc5036c9a1d587c2120e6a6b128aae31d1db7d6bad52722b.json +0 -1
  98. xinference/ui/web/ui/node_modules/.cache/babel-loader/43b889c3a8e2634092ade463d52481c7c5581c72ded8f23bc5f012ea0ef8cea5.json +0 -1
  99. xinference/ui/web/ui/node_modules/.cache/babel-loader/5d47532fb42128280d87f57c8a0b02bc1930f7ef764aa7e90579247df18bba83.json +0 -1
  100. xinference/ui/web/ui/node_modules/.cache/babel-loader/830882bb275468a969614824a9ab8983f874b4581f2eb625e9c66426cdc65e5b.json +0 -1
  101. xinference/ui/web/ui/node_modules/.cache/babel-loader/8e5cb82c2ff3299c6a44563fe6b1c5515c9750613c51bb63abee0b1d70fc5019.json +0 -1
  102. xinference/ui/web/ui/node_modules/.cache/babel-loader/9df08abcb5a7c1e48a4eb25c5d5f5d7253ea6854a4397e6d74d1fd75a14acda1.json +0 -1
  103. xinference/ui/web/ui/node_modules/.cache/babel-loader/b99034986a06445701accc7a4914bb9320947435e8d4e15793392ca4f679316c.json +0 -1
  104. /xinference/ui/web/ui/build/static/js/{main.b969199a.js.LICENSE.txt → main.1086c759.js.LICENSE.txt} +0 -0
  105. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/WHEEL +0 -0
  106. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/entry_points.txt +0 -0
  107. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/licenses/LICENSE +0 -0
  108. {xinference-1.8.1rc1.dist-info → xinference-1.9.1.dist-info}/top_level.txt +0 -0

xinference/model/video/model_spec.json
@@ -224,7 +224,7 @@
  },
  "virtualenv": {
  "packages": [
- "git+https://github.com/huggingface/diffusers",
+ "diffusers==0.35.1",
  "ftfy",
  "imageio-ffmpeg",
  "imageio",
@@ -241,5 +241,99 @@
  "model_revision": "master"
  }
  }
+ },
+ {
+ "version": 2,
+ "model_name": "Wan2.2-A14B",
+ "model_family": "Wan",
+ "model_ability": [
+ "text2video"
+ ],
+ "default_model_config": {
+ "torch_dtype": "bfloat16"
+ },
+ "default_generate_config": {},
+ "virtualenv": {
+ "packages": [
+ "diffusers==0.35.1",
+ "ftfy",
+ "imageio-ffmpeg",
+ "imageio",
+ "#system_numpy#"
+ ]
+ },
+ "model_src": {
+ "huggingface": {
+ "model_id": "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+ "model_revision": "5be7df9619b54f4e2667b2755bc6a756675b5cd7"
+ },
+ "modelscope": {
+ "model_id": "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+ "model_revision": "master"
+ }
+ }
+ },
+ {
+ "version": 2,
+ "model_name": "Wan2.2-i2v-A14B",
+ "model_family": "Wan",
+ "model_ability": [
+ "image2video"
+ ],
+ "default_model_config": {
+ "torch_dtype": "bfloat16"
+ },
+ "default_generate_config": {},
+ "virtualenv": {
+ "packages": [
+ "diffusers==0.35.1",
+ "ftfy",
+ "imageio-ffmpeg",
+ "imageio",
+ "#system_numpy#"
+ ]
+ },
+ "model_src": {
+ "huggingface": {
+ "model_id": "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
+ "model_revision": "596658fd9ca6b7b71d5057529bbf319ecbc61d74"
+ },
+ "modelscope": {
+ "model_id": "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
+ "model_revision": "master"
+ }
+ }
+ },
+ {
+ "version": 2,
+ "model_name": "Wan2.2-ti2v-5B",
+ "model_family": "Wan",
+ "model_ability": [
+ "text2video",
+ "image2video"
+ ],
+ "default_model_config": {
+ "torch_dtype": "bfloat16"
+ },
+ "default_generate_config": {},
+ "virtualenv": {
+ "packages": [
+ "diffusers==0.35.1",
+ "ftfy",
+ "imageio-ffmpeg",
+ "imageio",
+ "#system_numpy#"
+ ]
+ },
+ "model_src": {
+ "huggingface": {
+ "model_id": "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
+ "model_revision": "b8fff7315c768468a5333511427288870b2e9635"
+ },
+ "modelscope": {
+ "model_id": "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
+ "model_revision": "master"
+ }
+ }
  }
  ]
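
The three new Wan2.2 entries above register Wan-AI's text-to-video and image-to-video Diffusers checkpoints as built-in video models. For orientation, a minimal sketch of launching one of them through the Python client, assuming a locally running Xinference server at the default endpoint; the prompt and the text_to_video call are illustrative, not taken from this diff:

from xinference.client import Client

# Assumes `xinference-local` (or a cluster supervisor) is already serving;
# the endpoint below is a placeholder.
client = Client("http://localhost:9997")

# Launch one of the newly registered Wan2.2 video models by name.
model_uid = client.launch_model(model_name="Wan2.2-ti2v-5B", model_type="video")
model = client.get_model(model_uid)

# "text2video" is the ability declared in the spec above; the exact client
# method name and its parameters are assumptions, not part of this diff.
video = model.text_to_video(prompt="a red fox running through fresh snow")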

xinference/thirdparty/cosyvoice/bin/export_jit.py
@@ -61,8 +61,7 @@ def main():
  model = CosyVoice(args.model_dir)
  except Exception:
  try:
- # NOTE set use_flow_cache=True when export jit for cache inference
- model = CosyVoice2(args.model_dir, use_flow_cache=True)
+ model = CosyVoice2(args.model_dir)
  except Exception:
  raise TypeError('no valid model_type!')

@@ -93,9 +92,9 @@ def main():
  else:
  # 3. export flow encoder
  flow_encoder = model.model.flow.encoder
- script = get_optimized_script(flow_encoder, ['forward_chunk'])
+ script = get_optimized_script(flow_encoder)
  script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
- script = get_optimized_script(flow_encoder.half(), ['forward_chunk'])
+ script = get_optimized_script(flow_encoder.half())
  script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
  logging.info('successfully export flow_encoder')
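
The removed ['forward_chunk'] argument previously asked the scripting helper to keep an extra entry point for cached streaming inference. get_optimized_script itself is not shown in this diff; a rough sketch of what such a helper typically does with TorchScript, under the assumption that it wraps torch.jit.script and optimize_for_inference:

import torch

def get_optimized_script(module, preserved_methods=None):
    # Sketch only: script the module, then freeze/optimize it for inference.
    # `preserved_methods` mirrors the extra-method list dropped in the diff
    # above; the real cosyvoice helper may be implemented differently.
    script = torch.jit.script(module.eval())
    return torch.jit.optimize_for_inference(script, other_methods=preserved_methods or [])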

xinference/thirdparty/cosyvoice/bin/export_onnx.py
@@ -62,135 +62,58 @@ def main():
  model = CosyVoice(args.model_dir)
  except Exception:
  try:
- # NOTE set use_flow_cache=True when export jit for cache inference
- model = CosyVoice2(args.model_dir, use_flow_cache=True)
+ model = CosyVoice2(args.model_dir)
  except Exception:
  raise TypeError('no valid model_type!')

- if not isinstance(model, CosyVoice2):
- # 1. export flow decoder estimator
- estimator = model.model.flow.decoder.estimator
- estimator.eval()
-
- device = model.model.device
- batch_size, seq_len = 2, 256
- out_channels = model.model.flow.decoder.estimator.out_channels
- x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
- torch.onnx.export(
- estimator,
- (x, mask, mu, t, spks, cond),
- '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
- export_params=True,
- opset_version=18,
- do_constant_folding=True,
- input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
- output_names=['estimator_out'],
- dynamic_axes={
- 'x': {2: 'seq_len'},
- 'mask': {2: 'seq_len'},
- 'mu': {2: 'seq_len'},
- 'cond': {2: 'seq_len'},
- 'estimator_out': {2: 'seq_len'},
- }
- )
-
- # 2. test computation consistency
- option = onnxruntime.SessionOptions()
- option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
- option.intra_op_num_threads = 1
- providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
- estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
- sess_options=option, providers=providers)
-
- for _ in tqdm(range(10)):
- x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
- output_pytorch = estimator(x, mask, mu, t, spks, cond)
- ort_inputs = {
- 'x': x.cpu().numpy(),
- 'mask': mask.cpu().numpy(),
- 'mu': mu.cpu().numpy(),
- 't': t.cpu().numpy(),
- 'spks': spks.cpu().numpy(),
- 'cond': cond.cpu().numpy()
- }
- output_onnx = estimator_onnx.run(None, ort_inputs)[0]
- torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
- logging.info('successfully export estimator')
- else:
- # 1. export flow decoder estimator
- estimator = model.model.flow.decoder.estimator
- estimator.forward = estimator.forward_chunk
- estimator.eval()
-
- device = model.model.device
- batch_size, seq_len = 2, 256
- out_channels = model.model.flow.decoder.estimator.out_channels
- x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
- cache = model.model.init_flow_cache()['decoder_cache']
- cache.pop('offset')
- cache = {k: v[0] for k, v in cache.items()}
- torch.onnx.export(
- estimator,
- (x, mask, mu, t, spks, cond,
- cache['down_blocks_conv_cache'],
- cache['down_blocks_kv_cache'],
- cache['mid_blocks_conv_cache'],
- cache['mid_blocks_kv_cache'],
- cache['up_blocks_conv_cache'],
- cache['up_blocks_kv_cache'],
- cache['final_blocks_conv_cache']),
- '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
- export_params=True,
- opset_version=18,
- do_constant_folding=True,
- input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
- 'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
- output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out',
- 'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
- dynamic_axes={
- 'x': {2: 'seq_len'},
- 'mask': {2: 'seq_len'},
- 'mu': {2: 'seq_len'},
- 'cond': {2: 'seq_len'},
- 'down_blocks_kv_cache': {3: 'cache_in_len'},
- 'mid_blocks_kv_cache': {3: 'cache_in_len'},
- 'up_blocks_kv_cache': {3: 'cache_in_len'},
- 'estimator_out': {2: 'seq_len'},
- 'down_blocks_kv_cache_out': {3: 'cache_out_len'},
- 'mid_blocks_kv_cache_out': {3: 'cache_out_len'},
- 'up_blocks_kv_cache_out': {3: 'cache_out_len'},
- }
- )
-
- # 2. test computation consistency
- option = onnxruntime.SessionOptions()
- option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
- option.intra_op_num_threads = 1
- providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
- estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
- sess_options=option, providers=providers)
-
- for iter in tqdm(range(10)):
- x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
- cache = model.model.init_flow_cache()['decoder_cache']
- cache.pop('offset')
- cache = {k: v[0] for k, v in cache.items()}
- output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()})
- ort_inputs = {
- 'x': x.cpu().numpy(),
- 'mask': mask.cpu().numpy(),
- 'mu': mu.cpu().numpy(),
- 't': t.cpu().numpy(),
- 'spks': spks.cpu().numpy(),
- 'cond': cond.cpu().numpy(),
- }
- output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}})
- if iter == 0:
- # NOTE why can not pass first iteration check?
- continue
- for i, j in zip(output_pytorch, output_onnx):
- torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4)
- logging.info('successfully export estimator')
+ # 1. export flow decoder estimator
+ estimator = model.model.flow.decoder.estimator
+ estimator.eval()
+
+ device = model.model.device
+ batch_size, seq_len = 2, 256
+ out_channels = model.model.flow.decoder.estimator.out_channels
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+ torch.onnx.export(
+ estimator,
+ (x, mask, mu, t, spks, cond),
+ '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+ export_params=True,
+ opset_version=18,
+ do_constant_folding=True,
+ input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+ output_names=['estimator_out'],
+ dynamic_axes={
+ 'x': {2: 'seq_len'},
+ 'mask': {2: 'seq_len'},
+ 'mu': {2: 'seq_len'},
+ 'cond': {2: 'seq_len'},
+ 'estimator_out': {2: 'seq_len'},
+ }
+ )
+
+ # 2. test computation consistency
+ option = onnxruntime.SessionOptions()
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ option.intra_op_num_threads = 1
+ providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+ estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+ sess_options=option, providers=providers)
+
+ for _ in tqdm(range(10)):
+ x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+ output_pytorch = estimator(x, mask, mu, t, spks, cond)
+ ort_inputs = {
+ 'x': x.cpu().numpy(),
+ 'mask': mask.cpu().numpy(),
+ 'mu': mu.cpu().numpy(),
+ 't': t.cpu().numpy(),
+ 'spks': spks.cpu().numpy(),
+ 'cond': cond.cpu().numpy()
+ }
+ output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+ torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+ logging.info('successfully export estimator')


  if __name__ == "__main__":

xinference/thirdparty/cosyvoice/bin/inference_deprecated.py (renamed from inference.py)
@@ -122,4 +122,5 @@ def main():


  if __name__ == '__main__':
+ logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!')
  main()

xinference/thirdparty/cosyvoice/bin/train.py
@@ -27,6 +27,7 @@ from hyperpyyaml import load_hyperpyyaml

  from torch.distributed.elastic.multiprocessing.errors import record

+ from cosyvoice.utils.losses import DPOLoss
  from cosyvoice.utils.executor import Executor
  from cosyvoice.utils.train_utils import (
  init_distributed,
@@ -43,6 +44,7 @@ def get_args():
  choices=['torch_ddp', 'deepspeed'],
  help='Engine for paralleled training')
  parser.add_argument('--model', required=True, help='model which will be trained')
+ parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
  parser.add_argument('--config', required=True, help='config file')
  parser.add_argument('--train_data', required=True, help='train data file')
  parser.add_argument('--cv_data', required=True, help='cv data file')
@@ -73,6 +75,10 @@ def get_args():
  action='store_true',
  default=False,
  help='Use automatic mixed precision training')
+ parser.add_argument('--dpo',
+ action='store_true',
+ default=False,
+ help='Use Direct Preference Optimization')
  parser.add_argument('--deepspeed.save_states',
  dest='save_states',
  default='model_only',
@@ -113,7 +119,7 @@ def main():

  # Get dataset & dataloader
  train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
- init_dataset_and_dataloader(args, configs, gan)
+ init_dataset_and_dataloader(args, configs, gan, args.dpo)

  # Do some sanity checks and save config to arsg.model_dir
  configs = check_modify_and_save_config(args, configs)
@@ -122,6 +128,8 @@ def main():
  writer = init_summarywriter(args)

  # load checkpoint
+ if args.dpo is True:
+ configs[args.model].forward = configs[args.model].forward_dpo
  model = configs[args.model]
  start_step, start_epoch = 0, -1
  if args.checkpoint is not None:
@@ -150,13 +158,25 @@ def main():
  info_dict['epoch'] = start_epoch
  save_model(model, 'init', info_dict)

+ # DPO related
+ if args.dpo is True:
+ ref_model = deepcopy(configs[args.model])
+ state_dict = torch.load(args.ref_model, map_location='cpu')
+ ref_model.load_state_dict(state_dict, strict=False)
+ dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
+ # NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
+ ref_model = wrap_cuda_model(args, ref_model)
+ else:
+ ref_model, dpo_loss = None, None
+
  # Get executor
- executor = Executor(gan=gan)
+ executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
  executor.step = start_step

  # Init scaler, used for pytorch amp mixed precision training
  scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
  print('start step {} start epoch {}'.format(start_step, start_epoch))
+
  # Start training loop
  for epoch in range(start_epoch + 1, info_dict['max_epoch']):
  executor.epoch = epoch
@@ -167,7 +187,7 @@ def main():
  executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
  writer, info_dict, scaler, group_join)
  else:
- executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
+ executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
  dist.destroy_process_group(group_join)
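
The new --dpo path loads a frozen reference model and hands a DPOLoss plus that reference model to the executor. The loss itself lives in cosyvoice/utils/losses.py (added in this release, +37 lines, not shown above); as a reference point, a generic Direct Preference Optimization loss with the same constructor signature looks roughly like this (a sketch, not the shipped implementation):

import torch
import torch.nn.functional as F

class DPOLoss(torch.nn.Module):
    # Sketch of a conventional DPO objective; the constructor matches the call
    # DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False) used in train.py above.
    def __init__(self, beta=0.01, label_smoothing=0.0, ipo=False):
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.ipo = ipo

    def forward(self, policy_chosen_logps, policy_rejected_logps,
                ref_chosen_logps, ref_rejected_logps):
        # Log-ratios of the trained policy vs. the frozen reference model.
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = ref_chosen_logps - ref_rejected_logps
        logits = pi_logratios - ref_logratios
        if self.ipo:
            # IPO variant: squared distance to the 1/(2*beta) margin.
            losses = (logits - 1 / (2 * self.beta)) ** 2
        else:
            # Standard DPO with optional conservative label smoothing.
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        return losses.mean()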

xinference/thirdparty/cosyvoice/cli/cosyvoice.py
@@ -26,7 +26,7 @@ from cosyvoice.utils.class_utils import get_model_type

  class CosyVoice:

- def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
+ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
  self.instruct = True if '-Instruct' in model_dir else False
  self.model_dir = model_dir
  self.fp16 = fp16
@@ -59,6 +59,7 @@ class CosyVoice:
  if load_trt:
  self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
  '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+ trt_concurrent,
  self.fp16)
  del configs

@@ -140,7 +141,7 @@

  class CosyVoice2(CosyVoice):

- def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_flow_cache=False):
+ def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
  self.instruct = True if '-Instruct' in model_dir else False
  self.model_dir = model_dir
  self.fp16 = fp16
@@ -162,15 +163,18 @@ class CosyVoice2(CosyVoice):
  if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
  load_jit, load_trt, fp16 = False, False, False
  logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
- self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16, use_flow_cache)
+ self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
  self.model.load('{}/llm.pt'.format(model_dir),
- '{}/flow.pt'.format(model_dir) if use_flow_cache is False else '{}/flow.cache.pt'.format(model_dir),
+ '{}/flow.pt'.format(model_dir),
  '{}/hift.pt'.format(model_dir))
+ if load_vllm:
+ self.model.load_vllm('{}/vllm'.format(model_dir))
  if load_jit:
  self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
  if load_trt:
  self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
  '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
+ trt_concurrent,
  self.fp16)
  del configs
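
The CosyVoice2 constructor drops use_flow_cache and gains load_vllm and trt_concurrent: load_vllm loads a vLLM export from '<model_dir>/vllm', and trt_concurrent is forwarded to load_trt together with the TensorRT plan and ONNX paths. A hedged usage sketch with a placeholder model directory:

from cosyvoice.cli.cosyvoice import CosyVoice2

# Placeholder path: point this at a downloaded CosyVoice2 model directory
# containing llm.pt / flow.pt / hift.pt (and a vllm/ export if load_vllm=True).
model_dir = 'pretrained_models/CosyVoice2-0.5B'

cosyvoice = CosyVoice2(
    model_dir,
    load_jit=False,
    load_trt=True,      # expects flow.decoder.estimator.*.mygpu.plan in model_dir
    load_vllm=True,     # new flag: loads the LLM stage from '<model_dir>/vllm'
    fp16=True,
    trt_concurrent=2,   # new flag: passed through to load_trt (default 1)
)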

xinference/thirdparty/cosyvoice/cli/frontend.py
@@ -28,9 +28,9 @@ try:
  import ttsfrd
  use_ttsfrd = True
  except ImportError:
- print("failed to import ttsfrd, use WeTextProcessing instead")
- from tn.chinese.normalizer import Normalizer as ZhNormalizer
- from tn.english.normalizer import Normalizer as EnNormalizer
+ print("failed to import ttsfrd, use wetext instead")
+ from wetext import Normalizer as ZhNormalizer
+ from wetext import Normalizer as EnNormalizer
  use_ttsfrd = False
  from cosyvoice.utils.file_utils import logging
  from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
@@ -68,7 +68,7 @@ class CosyVoiceFrontEnd:
  'failed to initialize ttsfrd resource'
  self.frd.set_lang_type('pinyinvg')
  else:
- self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, overwrite_cache=True)
+ self.zh_tn_model = ZhNormalizer(remove_erhua=False)
  self.en_tn_model = EnNormalizer()
  self.inflect_parser = inflect.engine()
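
The frontend now falls back to the lighter wetext package instead of WeTextProcessing when ttsfrd is unavailable, with both the Chinese and English normalizers coming from the same wetext Normalizer class. A small sketch of that fallback path in isolation; the constructor calls are taken from the diff, while the normalize() call assumes wetext keeps the WeTextProcessing-style interface:

from wetext import Normalizer as ZhNormalizer
from wetext import Normalizer as EnNormalizer

zh_tn_model = ZhNormalizer(remove_erhua=False)  # constructor call as in the diff
en_tn_model = EnNormalizer()

# Assumption: wetext exposes a normalize() method like WeTextProcessing's normalizers.
print(zh_tn_model.normalize("今天气温零下3℃"))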