xinference 1.5.1__py3-none-any.whl → 1.6.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (96)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +97 -8
  3. xinference/client/restful/restful_client.py +51 -11
  4. xinference/core/media_interface.py +758 -0
  5. xinference/core/model.py +49 -9
  6. xinference/core/worker.py +31 -37
  7. xinference/deploy/utils.py +0 -3
  8. xinference/model/audio/__init__.py +16 -27
  9. xinference/model/audio/core.py +1 -0
  10. xinference/model/audio/cosyvoice.py +4 -2
  11. xinference/model/audio/model_spec.json +20 -3
  12. xinference/model/audio/model_spec_modelscope.json +18 -1
  13. xinference/model/embedding/__init__.py +16 -24
  14. xinference/model/image/__init__.py +15 -25
  15. xinference/model/llm/__init__.py +37 -110
  16. xinference/model/llm/core.py +15 -6
  17. xinference/model/llm/llama_cpp/core.py +25 -353
  18. xinference/model/llm/llm_family.json +613 -89
  19. xinference/model/llm/llm_family.py +9 -1
  20. xinference/model/llm/llm_family_modelscope.json +540 -90
  21. xinference/model/llm/mlx/core.py +6 -3
  22. xinference/model/llm/reasoning_parser.py +281 -5
  23. xinference/model/llm/sglang/core.py +16 -3
  24. xinference/model/llm/transformers/chatglm.py +2 -2
  25. xinference/model/llm/transformers/cogagent.py +1 -1
  26. xinference/model/llm/transformers/cogvlm2.py +1 -1
  27. xinference/model/llm/transformers/core.py +9 -3
  28. xinference/model/llm/transformers/glm4v.py +1 -1
  29. xinference/model/llm/transformers/minicpmv26.py +1 -1
  30. xinference/model/llm/transformers/qwen-omni.py +6 -0
  31. xinference/model/llm/transformers/qwen_vl.py +1 -1
  32. xinference/model/llm/utils.py +68 -45
  33. xinference/model/llm/vllm/core.py +38 -18
  34. xinference/model/llm/vllm/xavier/test/test_xavier.py +1 -10
  35. xinference/model/rerank/__init__.py +13 -24
  36. xinference/model/video/__init__.py +15 -25
  37. xinference/model/video/core.py +3 -3
  38. xinference/model/video/diffusers.py +133 -16
  39. xinference/model/video/model_spec.json +54 -0
  40. xinference/model/video/model_spec_modelscope.json +56 -0
  41. xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
  42. xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
  43. xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
  44. xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
  45. xinference/thirdparty/cosyvoice/bin/train.py +7 -2
  46. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
  47. xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
  48. xinference/thirdparty/cosyvoice/cli/model.py +140 -155
  49. xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
  50. xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
  51. xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
  52. xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
  53. xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
  54. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
  55. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
  56. xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
  57. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
  58. xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
  59. xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
  60. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +1 -1
  63. xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
  64. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
  65. xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
  66. xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
  67. xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
  68. xinference/types.py +0 -71
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
  72. xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
  75. xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
  76. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
  77. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
  79. xinference/web/ui/src/locales/en.json +6 -4
  80. xinference/web/ui/src/locales/zh.json +6 -4
  81. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/METADATA +59 -39
  82. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/RECORD +87 -87
  83. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/WHEEL +1 -1
  84. xinference/core/image_interface.py +0 -377
  85. xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
  86. xinference/web/ui/build/static/js/main.91e77b5c.js +0 -3
  87. xinference/web/ui/build/static/js/main.91e77b5c.js.map +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/5e6edb0fb87e3798f142e9abf8dd2dc46bab33a60d31dff525797c0c99887097.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/6087820be1bd5c02c42dff797e7df365448ef35ab26dd5d6bd33e967e05cbfd4.json +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
  93. /xinference/web/ui/build/static/js/{main.91e77b5c.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
  94. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/entry_points.txt +0 -0
  95. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/licenses/LICENSE +0 -0
  96. {xinference-1.5.1.dist-info → xinference-1.6.0.post1.dist-info}/top_level.txt +0 -0
--- a/xinference/model/video/diffusers.py
+++ b/xinference/model/video/diffusers.py
@@ -14,12 +14,13 @@
 
 import base64
 import logging
+import operator
 import os
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-from typing import TYPE_CHECKING, List, Union
+from functools import partial, reduce
+from typing import TYPE_CHECKING, Any, List, Optional, Union
 
 import numpy as np
 import PIL.Image
@@ -29,6 +30,7 @@ from ...device_utils import gpu_count, move_model_to_available_device
 from ...types import Video, VideoList
 
 if TYPE_CHECKING:
+    from ....core.progress_tracker import Progressor
     from .core import VideoModelFamilyV1
 
 
@@ -53,7 +55,7 @@ def export_to_video_imageio(
     return output_video_path
 
 
-class DiffUsersVideoModel:
+class DiffusersVideoModel:
     def __init__(
         self,
         model_uid: str,
@@ -111,11 +113,27 @@ class DiffUsersVideoModel:
                 self._model_path, transformer=transformer, **kwargs
             )
         elif self.model_spec.model_family == "Wan":
-            from diffusers import WanPipeline
+            from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, WanPipeline
+            from transformers import CLIPVisionModel
 
-            pipeline = self._model = WanPipeline.from_pretrained(
-                self._model_path, **kwargs
-            )
+            if "text2video" in self.model_spec.model_ability:
+                pipeline = self._model = WanPipeline.from_pretrained(
+                    self._model_path, **kwargs
+                )
+            else:
+                assert "image2video" in self.model_spec.model_ability
+
+                image_encoder = CLIPVisionModel.from_pretrained(
+                    self._model_path,
+                    subfolder="image_encoder",
+                    torch_dtype=torch.float32,
+                )
+                vae = AutoencoderKLWan.from_pretrained(
+                    self._model_path, subfolder="vae", torch_dtype=torch.float32
+                )
+                pipeline = self._model = WanImageToVideoPipeline.from_pretrained(
+                    self._model_path, vae=vae, image_encoder=image_encoder, **kwargs
+                )
         else:
             raise Exception(
                 f"Unsupported model family: {self._model_spec.model_family}"
@@ -130,6 +148,11 @@ class DiffUsersVideoModel:
             pipeline.transformer = torch.compile(
                 pipeline.transformer, mode="max-autotune", fullgraph=True
             )
+        if kwargs.get("layerwise_cast", False):
+            compute_dtype = pipeline.transformer.dtype
+            pipeline.transformer.enable_layerwise_casting(
+                storage_dtype=torch.float8_e4m3fn, compute_dtype=compute_dtype
+            )
         if kwargs.get("cpu_offload", False):
             logger.debug("CPU offloading model")
             pipeline.enable_model_cpu_offload()
@@ -145,6 +168,33 @@ class DiffUsersVideoModel:
             except AttributeError:
                 # model does support tiling
                 pass
+        elif kwargs.get("group_offload", False):
+            from diffusers.hooks.group_offloading import apply_group_offloading
+
+            onload_device = torch.device("cuda")
+            offload_device = torch.device("cpu")
+
+            apply_group_offloading(
+                pipeline.text_encoder,
+                onload_device=onload_device,
+                offload_device=offload_device,
+                offload_type="block_level",
+                num_blocks_per_group=4,
+            )
+            group_offload_kwargs = {}
+            if kwargs.get("use_stream", False):
+                group_offload_kwargs["offload_type"] = "block_level"
+                group_offload_kwargs["num_blocks_per_group"] = 4
+            else:
+                group_offload_kwargs["offload_type"] = "leaf_level"
+                group_offload_kwargs["use_stream"] = True
+            pipeline.transformer.enable_group_offload(
+                onload_device=onload_device,
+                offload_device=offload_device,
+                **group_offload_kwargs,
+            )
+            # Since we've offloaded the larger models already, we can move the rest of the model components to GPU
+            pipeline = move_model_to_available_device(pipeline)
         elif not kwargs.get("device_map"):
            logger.debug("Loading model to available device")
            if gpu_count() > 1:
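The new load-time switches above (layerwise_cast, cpu_offload, group_offload, use_stream) are read straight from the kwargs supplied when the video model is loaded. A minimal sketch of setting them from the Python client follows; the pass-through of extra launch kwargs into DiffusersVideoModel.load() is an assumption here, not something this diff shows:

    from xinference.client import Client

    client = Client("http://localhost:9997")
    # Hypothetical launch: extra keyword arguments are assumed to reach the
    # diffusers loader, enabling the offloading branches added above.
    model_uid = client.launch_model(
        model_name="Wan2.1-i2v-14B-480p",
        model_type="video",
        group_offload=True,
        use_stream=True,
    )
    video_model = client.get_model(model_uid)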
@@ -154,6 +204,26 @@ class DiffUsersVideoModel:
         # Recommended if your computer has < 64 GB of RAM
         pipeline.enable_attention_slicing()
 
+    @staticmethod
+    def _process_progressor(kwargs: dict):
+        import diffusers
+
+        progressor: Progressor = kwargs.pop("progressor", None)
+
+        def report_status_callback(
+            pipe: diffusers.DiffusionPipeline,
+            step: int,
+            timestep: int,
+            callback_kwargs: dict,
+        ):
+            num_steps = pipe.num_timesteps
+            progressor.set_progress((step + 1) / num_steps)
+
+            return callback_kwargs
+
+        if progressor and progressor.request_id:
+            kwargs["callback_on_step_end"] = report_status_callback
+
     def text_to_video(
         self,
         prompt: str,
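The _process_progressor helper relies on diffusers' callback_on_step_end hook: the pipeline calls the function after every denoising step with the current step index, so progress can be reported as a fraction of num_timesteps. A standalone sketch of the same pattern, with a placeholder model id; any diffusers pipeline that accepts callback_on_step_end behaves this way:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch.float16)  # placeholder id

    def on_step_end(pipe, step, timestep, callback_kwargs):
        # Same arithmetic as report_status_callback above: completed steps over total steps.
        print(f"progress: {(step + 1) / pipe.num_timesteps:.0%}")
        return callback_kwargs

    result = pipe(prompt="a red panda", callback_on_step_end=on_step_end)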
@@ -162,15 +232,6 @@
         response_format: str = "b64_json",
         **kwargs,
     ) -> VideoList:
-        import gc
-
-        from diffusers.utils import export_to_video
-
-        # cv2 bug will cause the video cannot be normally displayed
-        # thus we use the imageio one
-        # from diffusers.utils import export_to_video
-        from ...device_utils import empty_cache
-
         assert self._model is not None
         assert callable(self._model)
         generate_kwargs = self._model_spec.default_generate_config.copy()
@@ -181,11 +242,67 @@
             "diffusers text_to_video args: %s",
             generate_kwargs,
         )
+        self._process_progressor(generate_kwargs)
         output = self._model(
             prompt=prompt,
             num_inference_steps=num_inference_steps,
             **generate_kwargs,
         )
+        return self._output_to_video(output, fps, response_format)
+
+    def image_to_video(
+        self,
+        image: PIL.Image,
+        prompt: str,
+        n: int = 1,
+        num_inference_steps: Optional[int] = None,
+        response_format: str = "b64_json",
+        **kwargs,
+    ):
+        assert self._model is not None
+        assert callable(self._model)
+        generate_kwargs = self._model_spec.default_generate_config.copy()
+        generate_kwargs.update(kwargs)
+        generate_kwargs["num_videos_per_prompt"] = n
+        if num_inference_steps:
+            generate_kwargs["num_inference_steps"] = num_inference_steps
+        fps = generate_kwargs.pop("fps", 10)
+
+        # process image
+        max_area = generate_kwargs.pop("max_area")
+        if isinstance(max_area, str):
+            max_area = [int(v) for v in max_area.split("*")]
+        max_area = reduce(operator.mul, max_area, 1)
+        image = self._process_image(image, max_area)
+
+        height, width = image.height, image.width
+        generate_kwargs.pop("width", None)
+        generate_kwargs.pop("height", None)
+        self._process_progressor(generate_kwargs)
+        output = self._model(
+            image=image, prompt=prompt, height=height, width=width, **generate_kwargs
+        )
+        return self._output_to_video(output, fps, response_format)
+
+    def _process_image(self, image: PIL.Image, max_area: int) -> PIL.Image:
+        assert self._model is not None
+        aspect_ratio = image.height / image.width
+        mod_value = (
+            self._model.vae_scale_factor_spatial
+            * self._model.transformer.config.patch_size[1]
+        )
+        height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+        width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+        return image.resize((width, height))
+
+    def _output_to_video(self, output: Any, fps: int, response_format: str):
+        import gc
+
+        # cv2 bug will cause the video cannot be normally displayed
+        # thus we use the imageio one
+        from diffusers.utils import export_to_video
+
+        from ...device_utils import empty_cache
 
         # clean cache
         gc.collect()
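The resizing rule in _process_image keeps the input aspect ratio, caps the total pixel count at max_area, and snaps both sides down to a multiple of mod_value (the VAE spatial scale factor times the transformer patch size) so the frame lines up with the model's latent grid. A self-contained sketch of the same arithmetic; the mod_value of 16 used below is an illustrative assumption:

    import numpy as np

    def fit_to_area(height: int, width: int, max_area: int, mod_value: int):
        # Mirrors _process_image: preserve aspect ratio, cap total pixels,
        # then round both sides down to a multiple of mod_value.
        aspect_ratio = height / width
        new_h = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
        new_w = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
        return int(new_h), int(new_w)

    print(fit_to_area(1080, 1920, 480 * 832, 16))  # -> (464, 832)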
--- a/xinference/model/video/model_spec.json
+++ b/xinference/model/video/model_spec.json
@@ -91,5 +91,59 @@
                 "numpy==1.26.4"
             ]
         }
+    },
+    {
+        "model_name": "Wan2.1-i2v-14B-480p",
+        "model_family": "Wan",
+        "model_id": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
+        "model_revision": "b184e23a8a16b20f108f727c902e769e873ffc73",
+        "model_ability": [
+            "image2video"
+        ],
+        "default_model_config": {
+            "torch_dtype": "bfloat16"
+        },
+        "default_generate_config": {
+            "max_area": [
+                480,
+                832
+            ]
+        },
+        "virtualenv": {
+            "packages": [
+                "diffusers>=0.33.0",
+                "ftfy",
+                "imageio-ffmpeg",
+                "imageio",
+                "numpy==1.26.4"
+            ]
+        }
+    },
+    {
+        "model_name": "Wan2.1-i2v-14B-720p",
+        "model_family": "Wan",
+        "model_id": "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
+        "model_revision": "eb849f76dfa246545b65774a9e25943ee69b3fa3",
+        "model_ability": [
+            "image2video"
+        ],
+        "default_model_config": {
+            "torch_dtype": "bfloat16"
+        },
+        "default_generate_config": {
+            "max_area": [
+                720,
+                1280
+            ]
+        },
+        "virtualenv": {
+            "packages": [
+                "diffusers>=0.33.0",
+                "ftfy",
+                "imageio-ffmpeg",
+                "imageio",
+                "numpy==1.26.4"
+            ]
+        }
     }
 ]
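In these specs, max_area is stored as a pair like [480, 832] whose product bounds the output resolution; image_to_video above also accepts a string such as "480*832" and multiplies the factors with reduce(operator.mul, ...). A small sketch of that resolution step, mirroring the handling in image_to_video:

    import operator
    from functools import reduce

    def resolve_max_area(max_area):
        # Spec default is a list like [480, 832]; a request may pass "480*832" instead.
        if isinstance(max_area, str):
            max_area = [int(v) for v in max_area.split("*")]
        return reduce(operator.mul, max_area, 1)

    assert resolve_max_area([480, 832]) == resolve_max_area("480*832") == 399360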
--- a/xinference/model/video/model_spec_modelscope.json
+++ b/xinference/model/video/model_spec_modelscope.json
@@ -96,5 +96,61 @@
                 "numpy==1.26.4"
             ]
         }
+    },
+    {
+        "model_name": "Wan2.1-i2v-14B-480p",
+        "model_family": "Wan",
+        "model_hub": "modelscope",
+        "model_id": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
+        "model_revision": "master",
+        "model_ability": [
+            "image2video"
+        ],
+        "default_model_config": {
+            "torch_dtype": "bfloat16"
+        },
+        "default_generate_config": {
+            "max_area": [
+                480,
+                832
+            ]
+        },
+        "virtualenv": {
+            "packages": [
+                "diffusers>=0.33.0",
+                "ftfy",
+                "imageio-ffmpeg",
+                "imageio",
+                "numpy==1.26.4"
+            ]
+        }
+    },
+    {
+        "model_name": "Wan2.1-i2v-14B-720p",
+        "model_family": "Wan",
+        "model_hub": "modelscope",
+        "model_id": "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
+        "model_revision": "master",
+        "model_ability": [
+            "image2video"
+        ],
+        "default_model_config": {
+            "torch_dtype": "bfloat16"
+        },
+        "default_generate_config": {
+            "max_area": [
+                720,
+                1280
+            ]
+        },
+        "virtualenv": {
+            "packages": [
+                "diffusers>=0.33.0",
+                "ftfy",
+                "imageio-ffmpeg",
+                "imageio",
+                "numpy==1.26.4"
+            ]
+        }
     }
 ]
--- a/xinference/thirdparty/cosyvoice/bin/average_model.py
+++ b/xinference/thirdparty/cosyvoice/bin/average_model.py
@@ -75,10 +75,11 @@ def main():
         print('Processing {}'.format(path))
         states = torch.load(path, map_location=torch.device('cpu'))
         for k in states.keys():
-            if k not in avg.keys():
-                avg[k] = states[k].clone()
-            else:
-                avg[k] += states[k]
+            if k not in ['step', 'epoch']:
+                if k not in avg.keys():
+                    avg[k] = states[k].clone()
+                else:
+                    avg[k] += states[k]
     # average
     for k in avg.keys():
         if avg[k] is not None:
--- a/xinference/thirdparty/cosyvoice/bin/export_jit.py
+++ b/xinference/thirdparty/cosyvoice/bin/export_jit.py
@@ -23,7 +23,8 @@ import torch
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import logging
 
 
 def get_args():
@@ -37,6 +38,16 @@
     return args
 
 
+def get_optimized_script(model, preserved_attrs=[]):
+    script = torch.jit.script(model)
+    if preserved_attrs != []:
+        script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
+    else:
+        script = torch.jit.freeze(script)
+    script = torch.jit.optimize_for_inference(script)
+    return script
+
+
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
@@ -46,28 +57,47 @@
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_set_profiling_executor(False)
 
-    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
+    try:
+        model = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            # NOTE set use_flow_cache=True when export jit for cache inference
+            model = CosyVoice2(args.model_dir, use_flow_cache=True)
+        except Exception:
+            raise TypeError('no valid model_type!')
 
-    # 1. export llm text_encoder
-    llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
-    script = torch.jit.script(llm_text_encoder)
-    script = torch.jit.freeze(script)
-    script = torch.jit.optimize_for_inference(script)
-    script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+    if not isinstance(model, CosyVoice2):
+        # 1. export llm text_encoder
+        llm_text_encoder = model.model.llm.text_encoder
+        script = get_optimized_script(llm_text_encoder)
+        script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(llm_text_encoder.half())
+        script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export llm_text_encoder')
 
-    # 2. export llm llm
-    llm_llm = cosyvoice.model.llm.llm.half()
-    script = torch.jit.script(llm_llm)
-    script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
-    script = torch.jit.optimize_for_inference(script)
-    script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+        # 2. export llm llm
+        llm_llm = model.model.llm.llm
+        script = get_optimized_script(llm_llm, ['forward_chunk'])
+        script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
+        script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export llm_llm')
 
-    # 3. export flow encoder
-    flow_encoder = cosyvoice.model.flow.encoder
-    script = torch.jit.script(flow_encoder)
-    script = torch.jit.freeze(script)
-    script = torch.jit.optimize_for_inference(script)
-    script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+        # 3. export flow encoder
+        flow_encoder = model.model.flow.encoder
+        script = get_optimized_script(flow_encoder)
+        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(flow_encoder.half())
+        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export flow_encoder')
+    else:
+        # 3. export flow encoder
+        flow_encoder = model.model.flow.encoder
+        script = get_optimized_script(flow_encoder, ['forward_chunk'])
+        script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+        script = get_optimized_script(flow_encoder.half(), ['forward_chunk'])
+        script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
+        logging.info('successfully export flow_encoder')
 
 
 if __name__ == '__main__':
--- a/xinference/thirdparty/cosyvoice/bin/export_onnx.py
+++ b/xinference/thirdparty/cosyvoice/bin/export_onnx.py
@@ -27,7 +27,8 @@ from tqdm import tqdm
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
+from cosyvoice.utils.file_utils import logging
 
 
 def get_dummy_input(batch_size, seq_len, out_channels, device):
@@ -51,61 +52,145 @@
     return args
 
 
+@torch.no_grad()
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
                         format='%(asctime)s %(levelname)s %(message)s')
 
-    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
-
-    # 1. export flow decoder estimator
-    estimator = cosyvoice.model.flow.decoder.estimator
-
-    device = cosyvoice.model.device
-    batch_size, seq_len = 1, 256
-    out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
-    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
-    torch.onnx.export(
-        estimator,
-        (x, mask, mu, t, spks, cond),
-        '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
-        export_params=True,
-        opset_version=18,
-        do_constant_folding=True,
-        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
-        output_names=['estimator_out'],
-        dynamic_axes={
-            'x': {0: 'batch_size', 2: 'seq_len'},
-            'mask': {0: 'batch_size', 2: 'seq_len'},
-            'mu': {0: 'batch_size', 2: 'seq_len'},
-            'cond': {0: 'batch_size', 2: 'seq_len'},
-            't': {0: 'batch_size'},
-            'spks': {0: 'batch_size'},
-            'estimator_out': {0: 'batch_size', 2: 'seq_len'},
-        }
-    )
-
-    # 2. test computation consistency
-    option = onnxruntime.SessionOptions()
-    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-    option.intra_op_num_threads = 1
-    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
-    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
-                                                  sess_options=option, providers=providers)
-
-    for _ in tqdm(range(10)):
-        x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
-        output_pytorch = estimator(x, mask, mu, t, spks, cond)
-        ort_inputs = {
-            'x': x.cpu().numpy(),
-            'mask': mask.cpu().numpy(),
-            'mu': mu.cpu().numpy(),
-            't': t.cpu().numpy(),
-            'spks': spks.cpu().numpy(),
-            'cond': cond.cpu().numpy()
-        }
-        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
-        torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+    try:
+        model = CosyVoice(args.model_dir)
+    except Exception:
+        try:
+            # NOTE set use_flow_cache=True when export jit for cache inference
+            model = CosyVoice2(args.model_dir, use_flow_cache=True)
+        except Exception:
+            raise TypeError('no valid model_type!')
+
+    if not isinstance(model, CosyVoice2):
+        # 1. export flow decoder estimator
+        estimator = model.model.flow.decoder.estimator
+        estimator.eval()
+
+        device = model.model.device
+        batch_size, seq_len = 2, 256
+        out_channels = model.model.flow.decoder.estimator.out_channels
+        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+        torch.onnx.export(
+            estimator,
+            (x, mask, mu, t, spks, cond),
+            '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+            export_params=True,
+            opset_version=18,
+            do_constant_folding=True,
+            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+            output_names=['estimator_out'],
+            dynamic_axes={
+                'x': {2: 'seq_len'},
+                'mask': {2: 'seq_len'},
+                'mu': {2: 'seq_len'},
+                'cond': {2: 'seq_len'},
+                'estimator_out': {2: 'seq_len'},
+            }
+        )
+
+        # 2. test computation consistency
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                      sess_options=option, providers=providers)
+
+        for _ in tqdm(range(10)):
+            x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+            output_pytorch = estimator(x, mask, mu, t, spks, cond)
+            ort_inputs = {
+                'x': x.cpu().numpy(),
+                'mask': mask.cpu().numpy(),
+                'mu': mu.cpu().numpy(),
+                't': t.cpu().numpy(),
+                'spks': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy()
+            }
+            output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+            torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+        logging.info('successfully export estimator')
+    else:
+        # 1. export flow decoder estimator
+        estimator = model.model.flow.decoder.estimator
+        estimator.forward = estimator.forward_chunk
+        estimator.eval()
+
+        device = model.model.device
+        batch_size, seq_len = 2, 256
+        out_channels = model.model.flow.decoder.estimator.out_channels
+        x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+        cache = model.model.init_flow_cache()['decoder_cache']
+        cache.pop('offset')
+        cache = {k: v[0] for k, v in cache.items()}
+        torch.onnx.export(
+            estimator,
+            (x, mask, mu, t, spks, cond,
+             cache['down_blocks_conv_cache'],
+             cache['down_blocks_kv_cache'],
+             cache['mid_blocks_conv_cache'],
+             cache['mid_blocks_kv_cache'],
+             cache['up_blocks_conv_cache'],
+             cache['up_blocks_kv_cache'],
+             cache['final_blocks_conv_cache']),
+            '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+            export_params=True,
+            opset_version=18,
+            do_constant_folding=True,
+            input_names=['x', 'mask', 'mu', 't', 'spks', 'cond', 'down_blocks_conv_cache', 'down_blocks_kv_cache', 'mid_blocks_conv_cache', 'mid_blocks_kv_cache',
+                         'up_blocks_conv_cache', 'up_blocks_kv_cache', 'final_blocks_conv_cache'],
+            output_names=['estimator_out', 'down_blocks_conv_cache_out', 'down_blocks_kv_cache_out', 'mid_blocks_conv_cache_out', 'mid_blocks_kv_cache_out',
+                          'up_blocks_conv_cache_out', 'up_blocks_kv_cache_out', 'final_blocks_conv_cache_out'],
+            dynamic_axes={
+                'x': {2: 'seq_len'},
+                'mask': {2: 'seq_len'},
+                'mu': {2: 'seq_len'},
+                'cond': {2: 'seq_len'},
+                'down_blocks_kv_cache': {3: 'cache_in_len'},
+                'mid_blocks_kv_cache': {3: 'cache_in_len'},
+                'up_blocks_kv_cache': {3: 'cache_in_len'},
+                'estimator_out': {2: 'seq_len'},
+                'down_blocks_kv_cache_out': {3: 'cache_out_len'},
+                'mid_blocks_kv_cache_out': {3: 'cache_out_len'},
+                'up_blocks_kv_cache_out': {3: 'cache_out_len'},
+            }
+        )
+
+        # 2. test computation consistency
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                      sess_options=option, providers=providers)
+
+        for iter in tqdm(range(10)):
+            x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
+            cache = model.model.init_flow_cache()['decoder_cache']
+            cache.pop('offset')
+            cache = {k: v[0] for k, v in cache.items()}
+            output_pytorch = estimator(x, mask, mu, t, spks, cond, **{k: v.clone() for k, v in cache.items()})
+            ort_inputs = {
+                'x': x.cpu().numpy(),
+                'mask': mask.cpu().numpy(),
+                'mu': mu.cpu().numpy(),
+                't': t.cpu().numpy(),
+                'spks': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy(),
+            }
+            output_onnx = estimator_onnx.run(None, {**ort_inputs, **{k: v.clone().cpu().numpy() for k, v in cache.items()}})
+            if iter == 0:
+                # NOTE why can not pass first iteration check?
+                continue
+            for i, j in zip(output_pytorch, output_onnx):
+                torch.testing.assert_allclose(i, torch.from_numpy(j).to(device), rtol=1e-2, atol=1e-4)
+        logging.info('successfully export estimator')
 
 
 if __name__ == "__main__":
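For reference, the exported flow.decoder.estimator.fp32.onnx can be exercised directly with onnxruntime. A minimal sketch for the non-streaming (CosyVoice) export; the tensor shapes are assumptions based on get_dummy_input (batch, mel channels, frames) and are not pinned down by this diff:

    import numpy as np
    import onnxruntime

    sess = onnxruntime.InferenceSession('flow.decoder.estimator.fp32.onnx',
                                        providers=['CPUExecutionProvider'])
    print([i.name for i in sess.get_inputs()])  # ['x', 'mask', 'mu', 't', 'spks', 'cond']

    batch, channels, seq_len = 2, 80, 256  # channels assumed to match the estimator's out_channels
    feeds = {
        'x': np.random.rand(batch, channels, seq_len).astype(np.float32),
        'mask': np.ones((batch, 1, seq_len), dtype=np.float32),
        'mu': np.random.rand(batch, channels, seq_len).astype(np.float32),
        't': np.random.rand(batch).astype(np.float32),
        'spks': np.random.rand(batch, channels).astype(np.float32),
        'cond': np.random.rand(batch, channels, seq_len).astype(np.float32),
    }
    estimator_out = sess.run(None, feeds)[0]
    print(estimator_out.shape)  # expected (batch, channels, seq_len)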