xinference 1.5.0.post2__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (137)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +107 -11
  3. xinference/client/restful/restful_client.py +51 -11
  4. xinference/constants.py +5 -1
  5. xinference/core/media_interface.py +758 -0
  6. xinference/core/model.py +49 -9
  7. xinference/core/supervisor.py +1 -1
  8. xinference/core/utils.py +1 -1
  9. xinference/core/worker.py +33 -39
  10. xinference/deploy/cmdline.py +17 -0
  11. xinference/deploy/utils.py +0 -3
  12. xinference/model/audio/__init__.py +16 -27
  13. xinference/model/audio/core.py +2 -1
  14. xinference/model/audio/cosyvoice.py +4 -2
  15. xinference/model/audio/model_spec.json +63 -46
  16. xinference/model/audio/model_spec_modelscope.json +31 -14
  17. xinference/model/embedding/__init__.py +16 -24
  18. xinference/model/image/__init__.py +15 -25
  19. xinference/model/llm/__init__.py +40 -115
  20. xinference/model/llm/core.py +29 -6
  21. xinference/model/llm/llama_cpp/core.py +30 -347
  22. xinference/model/llm/llm_family.json +1674 -2203
  23. xinference/model/llm/llm_family.py +71 -7
  24. xinference/model/llm/llm_family_csghub.json +0 -32
  25. xinference/model/llm/llm_family_modelscope.json +1838 -2016
  26. xinference/model/llm/llm_family_openmind_hub.json +19 -325
  27. xinference/model/llm/lmdeploy/core.py +7 -2
  28. xinference/model/llm/mlx/core.py +23 -7
  29. xinference/model/llm/reasoning_parser.py +281 -5
  30. xinference/model/llm/sglang/core.py +39 -11
  31. xinference/model/llm/transformers/chatglm.py +9 -2
  32. xinference/model/llm/transformers/cogagent.py +10 -12
  33. xinference/model/llm/transformers/cogvlm2.py +6 -3
  34. xinference/model/llm/transformers/cogvlm2_video.py +3 -6
  35. xinference/model/llm/transformers/core.py +58 -60
  36. xinference/model/llm/transformers/deepseek_v2.py +4 -2
  37. xinference/model/llm/transformers/deepseek_vl.py +10 -4
  38. xinference/model/llm/transformers/deepseek_vl2.py +9 -4
  39. xinference/model/llm/transformers/gemma3.py +4 -5
  40. xinference/model/llm/transformers/glm4v.py +3 -21
  41. xinference/model/llm/transformers/glm_edge_v.py +3 -20
  42. xinference/model/llm/transformers/intern_vl.py +3 -6
  43. xinference/model/llm/transformers/internlm2.py +1 -1
  44. xinference/model/llm/transformers/minicpmv25.py +4 -2
  45. xinference/model/llm/transformers/minicpmv26.py +5 -3
  46. xinference/model/llm/transformers/omnilmm.py +1 -1
  47. xinference/model/llm/transformers/opt.py +1 -1
  48. xinference/model/llm/transformers/ovis2.py +302 -0
  49. xinference/model/llm/transformers/qwen-omni.py +8 -1
  50. xinference/model/llm/transformers/qwen2_audio.py +3 -1
  51. xinference/model/llm/transformers/qwen2_vl.py +5 -1
  52. xinference/model/llm/transformers/qwen_vl.py +5 -2
  53. xinference/model/llm/utils.py +96 -45
  54. xinference/model/llm/vllm/core.py +108 -24
  55. xinference/model/llm/vllm/distributed_executor.py +8 -7
  56. xinference/model/llm/vllm/xavier/allocator.py +1 -1
  57. xinference/model/llm/vllm/xavier/block_manager.py +1 -1
  58. xinference/model/llm/vllm/xavier/block_tracker.py +3 -3
  59. xinference/model/llm/vllm/xavier/executor.py +1 -1
  60. xinference/model/llm/vllm/xavier/test/test_xavier.py +2 -11
  61. xinference/model/rerank/__init__.py +13 -24
  62. xinference/model/video/__init__.py +15 -25
  63. xinference/model/video/core.py +3 -3
  64. xinference/model/video/diffusers.py +157 -13
  65. xinference/model/video/model_spec.json +100 -0
  66. xinference/model/video/model_spec_modelscope.json +104 -0
  67. xinference/thirdparty/cosyvoice/bin/average_model.py +5 -4
  68. xinference/thirdparty/cosyvoice/bin/export_jit.py +50 -20
  69. xinference/thirdparty/cosyvoice/bin/export_onnx.py +136 -51
  70. xinference/thirdparty/cosyvoice/bin/inference.py +15 -5
  71. xinference/thirdparty/cosyvoice/bin/train.py +7 -2
  72. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +72 -52
  73. xinference/thirdparty/cosyvoice/cli/frontend.py +58 -58
  74. xinference/thirdparty/cosyvoice/cli/model.py +140 -155
  75. xinference/thirdparty/cosyvoice/dataset/processor.py +9 -5
  76. xinference/thirdparty/cosyvoice/flow/decoder.py +656 -54
  77. xinference/thirdparty/cosyvoice/flow/flow.py +69 -11
  78. xinference/thirdparty/cosyvoice/flow/flow_matching.py +167 -63
  79. xinference/thirdparty/cosyvoice/flow/length_regulator.py +1 -0
  80. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +91 -1
  81. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +4 -1
  82. xinference/thirdparty/cosyvoice/hifigan/generator.py +4 -1
  83. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +2 -2
  84. xinference/thirdparty/cosyvoice/llm/llm.py +198 -18
  85. xinference/thirdparty/cosyvoice/transformer/embedding.py +12 -4
  86. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +124 -21
  87. xinference/thirdparty/cosyvoice/utils/class_utils.py +13 -0
  88. xinference/thirdparty/cosyvoice/utils/common.py +1 -1
  89. xinference/thirdparty/cosyvoice/utils/file_utils.py +40 -2
  90. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +7 -0
  91. xinference/thirdparty/cosyvoice/utils/mask.py +4 -0
  92. xinference/thirdparty/cosyvoice/utils/train_utils.py +5 -1
  93. xinference/thirdparty/matcha/hifigan/xutils.py +3 -3
  94. xinference/types.py +2 -71
  95. xinference/web/ui/build/asset-manifest.json +6 -6
  96. xinference/web/ui/build/index.html +1 -1
  97. xinference/web/ui/build/static/css/{main.0f6523be.css → main.337afe76.css} +2 -2
  98. xinference/web/ui/build/static/css/main.337afe76.css.map +1 -0
  99. xinference/web/ui/build/static/js/main.ae579a97.js +3 -0
  100. xinference/web/ui/build/static/js/main.ae579a97.js.map +1 -0
  101. xinference/web/ui/node_modules/.cache/babel-loader/0196a4b09e3264614e54360d5f832c46b31d964ec58296765ebff191ace6adbf.json +1 -0
  102. xinference/web/ui/node_modules/.cache/babel-loader/12e02ee790dbf57ead09a241a93bb5f893393aa36628ca741d44390e836a103f.json +1 -0
  103. xinference/web/ui/node_modules/.cache/babel-loader/18fa271456b31cded36c05c4c71c6b2b1cf4e4128c1e32f0e45d8b9f21764397.json +1 -0
  104. xinference/web/ui/node_modules/.cache/babel-loader/2fdc61dcb6a9d1fbcb44be592d0e87d8c3f21297a7327559ef5345665f8343f7.json +1 -0
  105. xinference/web/ui/node_modules/.cache/babel-loader/3d596a3e8dd6430d7ce81d164e32c31f8d47cfa5f725c328a298754d78563e14.json +1 -0
  106. xinference/web/ui/node_modules/.cache/babel-loader/5c08e2cd07809ed3e41486b16652253404cbb63a3ff8d0366ee50f57e2413cea.json +1 -0
  107. xinference/web/ui/node_modules/.cache/babel-loader/6798e126f3bc5f95a4c16a9c2ad52ffe77970c62406d83e20604dfda7ffd2247.json +1 -0
  108. xinference/web/ui/node_modules/.cache/babel-loader/8472e58a31720892d534f3febda31f746b25ec4aa60787eef34217b074e67965.json +1 -0
  109. xinference/web/ui/node_modules/.cache/babel-loader/b617f7d21a95045fc57b26a9373551740f1978a826134cbf705c3a1bf8714a93.json +1 -0
  110. xinference/web/ui/node_modules/.cache/babel-loader/c1506cb142151366074975f30fa1ff9cd6e5e978b62a4b074dfc16fe08d70d75.json +1 -0
  111. xinference/web/ui/node_modules/.cache/babel-loader/c5c7c2cd1b863ce41adff2c4737bba06eef3a1acf28288cb83d992060f6b8923.json +1 -0
  112. xinference/web/ui/src/locales/en.json +7 -4
  113. xinference/web/ui/src/locales/zh.json +7 -4
  114. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/METADATA +56 -36
  115. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/RECORD +120 -121
  116. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/WHEEL +1 -1
  117. xinference/core/image_interface.py +0 -377
  118. xinference/model/llm/transformers/compression.py +0 -258
  119. xinference/model/llm/transformers/yi_vl.py +0 -239
  120. xinference/thirdparty/cosyvoice/bin/export_trt.sh +0 -9
  121. xinference/web/ui/build/static/css/main.0f6523be.css.map +0 -1
  122. xinference/web/ui/build/static/js/main.4b67a723.js +0 -3
  123. xinference/web/ui/build/static/js/main.4b67a723.js.map +0 -1
  124. xinference/web/ui/node_modules/.cache/babel-loader/0f0adb2283a8f469d097a7a0ebb754624fa52414c83b83696c41f2e6a737ceda.json +0 -1
  125. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +0 -1
  126. xinference/web/ui/node_modules/.cache/babel-loader/8157db83995c671eb57abc316c337f867d1dc63fb83520bb4ff351fee57dcce2.json +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/8f9af2979e45d4648f0cfae108363e58ee421c29a9d4e7329b6f06d9adfd4133.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/9c8b1a86e7c65b2b2599a205e30920652d6c2105f926508ef5bcf29a3ef4ce76.json +0 -1
  129. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +0 -1
  130. xinference/web/ui/node_modules/.cache/babel-loader/e4ba658c6b3b0490910acdae0c535a892257efb61539a24adf8038fc653bd22f.json +0 -1
  131. xinference/web/ui/node_modules/.cache/babel-loader/efe7cd132c27a8f9fd5352a394c491fd5fb0da0348cf9fcbd923164a32365eab.json +0 -1
  132. xinference/web/ui/node_modules/.cache/babel-loader/f04f666b77b44d7be3e16034d6b0074de2ba9c254f1fae15222b3148608fa8b3.json +0 -1
  133. xinference/web/ui/node_modules/.cache/babel-loader/f199e8173f6409a5802ed44acb95f218388131136504b2e9132129e150c92f9a.json +0 -1
  134. /xinference/web/ui/build/static/js/{main.4b67a723.js.LICENSE.txt → main.ae579a97.js.LICENSE.txt} +0 -0
  135. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/entry_points.txt +0 -0
  136. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/licenses/LICENSE +0 -0
  137. {xinference-1.5.0.post2.dist-info → xinference-1.6.0.dist-info}/top_level.txt +0 -0
xinference/core/image_interface.py (deleted)
@@ -1,377 +0,0 @@
- # Copyright 2022-2023 XProbe Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import base64
- import io
- import logging
- import os
- import threading
- import time
- import uuid
- from typing import Dict, List, Optional, Union
-
- import gradio as gr
- import PIL.Image
- from gradio import Markdown
-
- from ..client.restful.restful_client import RESTfulImageModelHandle
-
- logger = logging.getLogger(__name__)
-
-
- class ImageInterface:
-     def __init__(
-         self,
-         endpoint: str,
-         model_uid: str,
-         model_family: str,
-         model_name: str,
-         model_id: str,
-         model_revision: str,
-         model_ability: List[str],
-         controlnet: Union[None, List[Dict[str, Union[str, None]]]],
-         access_token: Optional[str],
-     ):
-         self.endpoint = endpoint
-         self.model_uid = model_uid
-         self.model_family = model_family
-         self.model_name = model_name
-         self.model_id = model_id
-         self.model_revision = model_revision
-         self.model_ability = model_ability
-         self.controlnet = controlnet
-         self.access_token = (
-             access_token.replace("Bearer ", "") if access_token is not None else None
-         )
-
-     def build(self) -> gr.Blocks:
-         assert "stable_diffusion" in self.model_family
-
-         interface = self.build_main_interface()
-         interface.queue()
-         # Gradio initiates the queue during a startup event, but since the app has already been
-         # started, that event will not run, so manually invoke the startup events.
-         # See: https://github.com/gradio-app/gradio/issues/5228
-         try:
-             interface.run_startup_events()
-         except AttributeError:
-             # compatibility
-             interface.startup_events()
-         favicon_path = os.path.join(
-             os.path.dirname(os.path.abspath(__file__)),
-             os.path.pardir,
-             "web",
-             "ui",
-             "public",
-             "favicon.svg",
-         )
-         interface.favicon_path = favicon_path
-         return interface
-
-     def text2image_interface(self) -> "gr.Blocks":
-         from ..model.image.stable_diffusion.core import SAMPLING_METHODS
-
-         def text_generate_image(
-             prompt: str,
-             n: int,
-             size_width: int,
-             size_height: int,
-             guidance_scale: int,
-             num_inference_steps: int,
-             negative_prompt: Optional[str] = None,
-             sampler_name: Optional[str] = None,
-             progress=gr.Progress(),
-         ) -> PIL.Image.Image:
-             from ..client import RESTfulClient
-
-             client = RESTfulClient(self.endpoint)
-             client._set_token(self.access_token)
-             model = client.get_model(self.model_uid)
-             assert isinstance(model, RESTfulImageModelHandle)
-
-             size = f"{int(size_width)}*{int(size_height)}"
-             guidance_scale = None if guidance_scale == -1 else guidance_scale  # type: ignore
-             num_inference_steps = (
-                 None if num_inference_steps == -1 else num_inference_steps  # type: ignore
-             )
-             sampler_name = None if sampler_name == "default" else sampler_name
-
-             response = None
-             exc = None
-             request_id = str(uuid.uuid4())
-
-             def run_in_thread():
-                 nonlocal exc, response
-                 try:
-                     response = model.text_to_image(
-                         request_id=request_id,
-                         prompt=prompt,
-                         n=n,
-                         size=size,
-                         num_inference_steps=num_inference_steps,
-                         guidance_scale=guidance_scale,
-                         negative_prompt=negative_prompt,
-                         sampler_name=sampler_name,
-                         response_format="b64_json",
-                     )
-                 except Exception as e:
-                     exc = e
-
-             t = threading.Thread(target=run_in_thread)
-             t.start()
-             while t.is_alive():
-                 try:
-                     cur_progress = client.get_progress(request_id)["progress"]
-                 except (KeyError, RuntimeError):
-                     cur_progress = 0.0
-
-                 progress(cur_progress, desc="Generating images")
-                 time.sleep(1)
-
-             if exc:
-                 raise exc
-
-             images = []
-             for image_dict in response["data"]:  # type: ignore
-                 assert image_dict["b64_json"] is not None
-                 image_data = base64.b64decode(image_dict["b64_json"])
-                 image = PIL.Image.open(io.BytesIO(image_data))
-                 images.append(image)
-
-             return images
-
-         with gr.Blocks() as text2image_vl_interface:
-             with gr.Column():
-                 with gr.Row():
-                     with gr.Column(scale=10):
-                         prompt = gr.Textbox(
-                             label="Prompt",
-                             show_label=True,
-                             placeholder="Enter prompt here...",
-                         )
-                         negative_prompt = gr.Textbox(
-                             label="Negative prompt",
-                             show_label=True,
-                             placeholder="Enter negative prompt here...",
-                         )
-                     with gr.Column(scale=1):
-                         generate_button = gr.Button("Generate")
-
-                 with gr.Row():
-                     n = gr.Number(label="Number of Images", value=1)
-                     size_width = gr.Number(label="Width", value=1024)
-                     size_height = gr.Number(label="Height", value=1024)
-                 with gr.Row():
-                     guidance_scale = gr.Number(label="Guidance scale", value=-1)
-                     num_inference_steps = gr.Number(
-                         label="Inference Step Number", value=-1
-                     )
-                     sampler_name = gr.Dropdown(
-                         choices=SAMPLING_METHODS,
-                         value="default",
-                         label="Sampling method",
-                     )
-
-                 with gr.Column():
-                     image_output = gr.Gallery()
-
-             generate_button.click(
-                 text_generate_image,
-                 inputs=[
-                     prompt,
-                     n,
-                     size_width,
-                     size_height,
-                     guidance_scale,
-                     num_inference_steps,
-                     negative_prompt,
-                     sampler_name,
-                 ],
-                 outputs=image_output,
-             )
-
-         return text2image_vl_interface
-
-     def image2image_interface(self) -> "gr.Blocks":
-         from ..model.image.stable_diffusion.core import SAMPLING_METHODS
-
-         def image_generate_image(
-             prompt: str,
-             negative_prompt: str,
-             image: PIL.Image.Image,
-             n: int,
-             size_width: int,
-             size_height: int,
-             num_inference_steps: int,
-             padding_image_to_multiple: int,
-             sampler_name: Optional[str] = None,
-             progress=gr.Progress(),
-         ) -> PIL.Image.Image:
-             from ..client import RESTfulClient
-
-             client = RESTfulClient(self.endpoint)
-             client._set_token(self.access_token)
-             model = client.get_model(self.model_uid)
-             assert isinstance(model, RESTfulImageModelHandle)
-
-             if size_width > 0 and size_height > 0:
-                 size = f"{int(size_width)}*{int(size_height)}"
-             else:
-                 size = None
-             num_inference_steps = (
-                 None if num_inference_steps == -1 else num_inference_steps  # type: ignore
-             )
-             padding_image_to_multiple = None if padding_image_to_multiple == -1 else padding_image_to_multiple  # type: ignore
-             sampler_name = None if sampler_name == "default" else sampler_name
-
-             bio = io.BytesIO()
-             image.save(bio, format="png")
-
-             response = None
-             exc = None
-             request_id = str(uuid.uuid4())
-
-             def run_in_thread():
-                 nonlocal exc, response
-                 try:
-                     response = model.image_to_image(
-                         request_id=request_id,
-                         prompt=prompt,
-                         negative_prompt=negative_prompt,
-                         n=n,
-                         image=bio.getvalue(),
-                         size=size,
-                         response_format="b64_json",
-                         num_inference_steps=num_inference_steps,
-                         padding_image_to_multiple=padding_image_to_multiple,
-                         sampler_name=sampler_name,
-                     )
-                 except Exception as e:
-                     exc = e
-
-             t = threading.Thread(target=run_in_thread)
-             t.start()
-             while t.is_alive():
-                 try:
-                     cur_progress = client.get_progress(request_id)["progress"]
-                 except (KeyError, RuntimeError):
-                     cur_progress = 0.0
-
-                 progress(cur_progress, desc="Generating images")
-                 time.sleep(1)
-
-             if exc:
-                 raise exc
-
-             images = []
-             for image_dict in response["data"]:  # type: ignore
-                 assert image_dict["b64_json"] is not None
-                 image_data = base64.b64decode(image_dict["b64_json"])
-                 image = PIL.Image.open(io.BytesIO(image_data))
-                 images.append(image)
-
-             return images
-
-         with gr.Blocks() as image2image_inteface:
-             with gr.Column():
-                 with gr.Row():
-                     with gr.Column(scale=10):
-                         prompt = gr.Textbox(
-                             label="Prompt",
-                             show_label=True,
-                             placeholder="Enter prompt here...",
-                         )
-                         negative_prompt = gr.Textbox(
-                             label="Negative Prompt",
-                             show_label=True,
-                             placeholder="Enter negative prompt here...",
-                         )
-                     with gr.Column(scale=1):
-                         generate_button = gr.Button("Generate")
-
-                 with gr.Row():
-                     n = gr.Number(label="Number of image", value=1)
-                     size_width = gr.Number(label="Width", value=-1)
-                     size_height = gr.Number(label="Height", value=-1)
-
-                 with gr.Row():
-                     num_inference_steps = gr.Number(
-                         label="Inference Step Number", value=-1
-                     )
-                     padding_image_to_multiple = gr.Number(
-                         label="Padding image to multiple", value=-1
-                     )
-                     sampler_name = gr.Dropdown(
-                         choices=SAMPLING_METHODS,
-                         value="default",
-                         label="Sampling method",
-                     )
-
-                 with gr.Row():
-                     with gr.Column(scale=1):
-                         uploaded_image = gr.Image(type="pil", label="Upload Image")
-                     with gr.Column(scale=1):
-                         output_gallery = gr.Gallery()
-
-             generate_button.click(
-                 image_generate_image,
-                 inputs=[
-                     prompt,
-                     negative_prompt,
-                     uploaded_image,
-                     n,
-                     size_width,
-                     size_height,
-                     num_inference_steps,
-                     padding_image_to_multiple,
-                     sampler_name,
-                 ],
-                 outputs=output_gallery,
-             )
-         return image2image_inteface
-
-     def build_main_interface(self) -> "gr.Blocks":
-         with gr.Blocks(
-             title=f"🎨 Xinference Stable Diffusion: {self.model_name} 🎨",
-             css="""
-             .center{
-                 display: flex;
-                 justify-content: center;
-                 align-items: center;
-                 padding: 0px;
-                 color: #9ea4b0 !important;
-             }
-             """,
-             analytics_enabled=False,
-         ) as app:
-             Markdown(
-                 f"""
-                 <h1 class="center" style='text-align: center; margin-bottom: 1rem'>🎨 Xinference Stable Diffusion: {self.model_name} 🎨</h1>
-                 """
-             )
-             Markdown(
-                 f"""
-                 <div class="center">
-                 Model ID: {self.model_uid}
-                 </div>
-                 """
-             )
-             if "text2image" in self.model_ability:
-                 with gr.Tab("Text to Image"):
-                     self.text2image_interface()
-             if "image2image" in self.model_ability:
-                 with gr.Tab("Image to Image"):
-                     self.image2image_interface()
-
-         return app
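
Note: the removed `xinference/core/image_interface.py` (apparently superseded in this release by the new `xinference/core/media_interface.py`) issued each REST request on a worker thread while polling `client.get_progress(request_id)` to drive the Gradio progress bar. The sketch below isolates that pattern; `do_request`, `poll_progress`, and `on_progress` are hypothetical placeholders, not part of the xinference client API.

```python
# Minimal sketch of the background-thread + progress-polling pattern used by the
# removed ImageInterface. All three callables are hypothetical placeholders.
import threading
import time


def run_with_progress(do_request, poll_progress, on_progress, interval=1.0):
    """Run `do_request` on a worker thread and report progress until it finishes."""
    result = None
    exc = None

    def worker():
        nonlocal result, exc
        try:
            result = do_request()
        except Exception as e:
            exc = e

    t = threading.Thread(target=worker)
    t.start()
    while t.is_alive():
        try:
            on_progress(poll_progress())  # e.g. a value in [0.0, 1.0]
        except Exception:
            on_progress(0.0)  # progress may not be registered yet
        time.sleep(interval)
    t.join()
    if exc:
        raise exc  # surface the request failure to the caller
    return result
```

In the removed file, `do_request` corresponds to `model.text_to_image(...)` or `model.image_to_image(...)`, `poll_progress` to `client.get_progress(request_id)["progress"]`, and `on_progress` to the `gr.Progress()` callback.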
xinference/model/llm/transformers/compression.py (deleted)
@@ -1,258 +0,0 @@
- # Copyright 2022-2023 XProbe Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import dataclasses
- import gc
- import glob
- import os
-
- import torch
- import torch.nn as nn
- from huggingface_hub import snapshot_download
- from torch import Tensor
- from torch.nn import functional as F
- from tqdm import tqdm
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
-
- from ....device_utils import empty_cache
-
-
- @dataclasses.dataclass
- class CompressionConfig:
-     """Group-wise quantization."""
-
-     num_bits: int
-     group_size: int
-     group_dim: int
-     symmetric: bool
-     enabled: bool = True
-
-
- default_compression_config = CompressionConfig(
-     num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True
- )
-
-
- class CLinear(nn.Module):
-     """Compressed Linear Layer."""
-
-     def __init__(self, weight=None, bias=None, device=None):
-         super().__init__()
-         if weight is None:
-             self.weight = None
-         elif isinstance(weight, Tensor):
-             self.weight = compress(weight.data.to(device), default_compression_config)
-         else:
-             self.weight = weight
-         self.bias = bias
-
-     def forward(self, input: Tensor) -> Tensor:
-         weight = decompress(self.weight, default_compression_config)
-         if self.bias is None:
-             return F.linear(input.to(weight.dtype), weight)
-         return F.linear(input.to(weight.dtype), weight, self.bias.to(weight.dtype))
-
-
- def get_compressed_list(module, prefix=""):
-     compressed_list = []
-     for attr_str in dir(module):
-         target_attr = getattr(module, attr_str)
-         if type(target_attr) == torch.nn.Linear:
-             full_name = (
-                 f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight"
-             )
-             compressed_list.append(full_name)
-     for name, child in module.named_children():
-         child_prefix = f"{prefix}.{name}" if prefix else name
-         for each in get_compressed_list(child, child_prefix):
-             compressed_list.append(each)
-     return compressed_list
-
-
- def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""):
-     for attr_str in dir(module):
-         target_attr = getattr(module, attr_str)
-         if type(target_attr) == torch.nn.Linear:
-             full_name = (
-                 f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight"
-             )
-             setattr(
-                 module,
-                 attr_str,
-                 CLinear(
-                     compressed_state_dict[full_name], target_attr.bias, target_device
-                 ),
-             )
-     for name, child in module.named_children():
-         child_prefix = f"{prefix}.{name}" if prefix else name
-         apply_compressed_weight(
-             child, compressed_state_dict, target_device, child_prefix
-         )
-
-
- def load_compress_model(
-     model_path: str,
-     device: str,
-     torch_dtype: torch.dtype,
-     use_fast: bool,
-     revision: str = "main",
- ):
-     from accelerate import init_empty_weights
-     from accelerate.utils import set_module_tensor_to_device
-
-     # partially load model
-     tokenizer = AutoTokenizer.from_pretrained(
-         model_path,
-         use_fast=use_fast,
-         trust_remote_code=True,
-         revision=revision,
-     )
-
-     with init_empty_weights():
-         config = AutoConfig.from_pretrained(
-             model_path,
-             low_cpu_mem_usage=True,
-             torch_dtype=torch_dtype,
-             trust_remote_code=True,
-             revision=revision,
-         )
-         model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
-         linear_weights = get_compressed_list(model)
-
-     if os.path.exists(model_path):
-         # `model_path` is a local folder
-         base_pattern = os.path.join(model_path, "pytorch_model*.bin")
-     else:
-         # `model_path` is a cached Hugging Face repo
-         model_path = snapshot_download(model_path, revision=revision)
-         base_pattern = os.path.join(model_path, "pytorch_model*.bin")
-
-     files = glob.glob(base_pattern)
-
-     compressed_state_dict = {}
-
-     for filename in tqdm(files):
-         tmp_state_dict = torch.load(filename, map_location=torch.device(device))
-         for name in tmp_state_dict:
-             if name in linear_weights:
-                 tensor = tmp_state_dict[name].to(device).data.to(torch_dtype)
-                 compressed_state_dict[name] = compress(
-                     tensor, default_compression_config
-                 )
-             else:
-                 compressed_state_dict[name] = tmp_state_dict[name].to(device)
-             tmp_state_dict[name] = None
-             tensor = None
-             gc.collect()
-             empty_cache()
-
-     for name in model.state_dict():
-         if name not in linear_weights:
-             set_module_tensor_to_device(
-                 model, name, device, value=compressed_state_dict[name]
-             )
-     apply_compressed_weight(model, compressed_state_dict, device)
-
-     model.to(device)
-
-     return model, tokenizer
-
-
- def compress(tensor, config):
-     """Simulate group-wise quantization."""
-     if not config.enabled:
-         return tensor
-
-     group_size, num_bits, group_dim, symmetric = (
-         config.group_size,
-         config.num_bits,
-         config.group_dim,
-         config.symmetric,
-     )
-     assert num_bits <= 8
-
-     original_shape = tensor.shape
-     num_groups = (original_shape[group_dim] + group_size - 1) // group_size
-     new_shape = (
-         original_shape[:group_dim]
-         + (num_groups, group_size)
-         + original_shape[group_dim + 1 :]
-     )
-
-     # Pad
-     pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
-     if pad_len != 0:
-         pad_shape = (
-             original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :]
-         )
-         tensor = torch.cat(
-             [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
-             dim=group_dim,
-         )
-     data = tensor.view(new_shape)
-
-     # Quantize
-     if symmetric:
-         B = 2 ** (num_bits - 1) - 1
-         scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
-         data = data * scale
-         data = data.clamp_(-B, B).round_().to(torch.int8)
-         return data, scale, original_shape
-     else:
-         B = 2**num_bits - 1
-         mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
-         mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]
-
-         scale = B / (mx - mn)
-         data = data - mn
-         data.mul_(scale)
-
-         data = data.clamp_(0, B).round_().to(torch.uint8)
-         return data, mn, scale, original_shape
-
-
- def decompress(packed_data, config):
-     """Simulate group-wise dequantization."""
-     if not config.enabled:
-         return packed_data
-
-     group_size, _, group_dim, symmetric = (
-         config.group_size,
-         config.num_bits,
-         config.group_dim,
-         config.symmetric,
-     )
-
-     # Dequantize
-     if symmetric:
-         data, scale, original_shape = packed_data
-         data = data / scale
-     else:
-         data, mn, scale, original_shape = packed_data
-         data = data / scale
-         data.add_(mn)
-
-     # Unpad
-     pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
-     if pad_len:
-         padded_original_shape = (
-             original_shape[:group_dim]
-             + (original_shape[group_dim] + pad_len,)
-             + original_shape[group_dim + 1 :]
-         )
-         data = data.reshape(padded_original_shape)
-         indices = [slice(0, x) for x in original_shape]
-         return data[indices].contiguous()
-     else:
-         return data.view(original_shape)
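
Note: the removed `compression.py` simulated group-wise 8-bit quantization of linear-layer weights via `compress`/`decompress` above (pad the grouped dimension, scale each group to the integer range, round, and invert on the fly in `CLinear.forward`). Below is a simplified, self-contained sketch of the symmetric round trip, with groups taken along the last dimension only; it is an illustration under those assumptions, not the removed module's API.

```python
# Simplified sketch of symmetric group-wise quantization in the spirit of the
# removed compress()/decompress(); groups are formed along the last dimension.
import torch
import torch.nn.functional as F


def quantize_groupwise(t: torch.Tensor, group_size: int = 256, num_bits: int = 8):
    orig_len = t.shape[-1]
    pad = (-orig_len) % group_size            # pad so the last dim splits into whole groups
    if pad:
        t = F.pad(t, (0, pad))
    groups = t.reshape(*t.shape[:-1], -1, group_size)
    max_q = 2 ** (num_bits - 1) - 1           # e.g. 127 for 8-bit symmetric
    scale = max_q / groups.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8)
    q = (groups * scale).round_().clamp_(-max_q, max_q).to(torch.int8)
    return q, scale, orig_len


def dequantize_groupwise(q: torch.Tensor, scale: torch.Tensor, orig_len: int):
    groups = q.to(scale.dtype) / scale        # back to float, undo the per-group scale
    flat = groups.reshape(*groups.shape[:-2], -1)
    return flat[..., :orig_len]               # drop the padding


w = torch.randn(4, 300)
q, scale, n = quantize_groupwise(w)
w_hat = dequantize_groupwise(q, scale, n)
assert w_hat.shape == w.shape                 # values are approximate, shape is exact
```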