xinference 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_compat.py +1 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +54 -1
- xinference/client/restful/restful_client.py +82 -2
- xinference/constants.py +3 -0
- xinference/core/chat_interface.py +297 -83
- xinference/core/model.py +24 -3
- xinference/core/progress_tracker.py +16 -8
- xinference/core/supervisor.py +51 -1
- xinference/core/worker.py +315 -47
- xinference/deploy/cmdline.py +33 -1
- xinference/model/audio/core.py +11 -1
- xinference/model/audio/megatts.py +105 -0
- xinference/model/audio/model_spec.json +24 -1
- xinference/model/audio/model_spec_modelscope.json +26 -1
- xinference/model/core.py +14 -0
- xinference/model/embedding/core.py +6 -1
- xinference/model/flexible/core.py +6 -1
- xinference/model/image/core.py +6 -1
- xinference/model/image/model_spec.json +17 -1
- xinference/model/image/model_spec_modelscope.json +17 -1
- xinference/model/llm/__init__.py +4 -6
- xinference/model/llm/core.py +5 -0
- xinference/model/llm/llama_cpp/core.py +46 -17
- xinference/model/llm/llm_family.json +530 -85
- xinference/model/llm/llm_family.py +24 -1
- xinference/model/llm/llm_family_modelscope.json +572 -1
- xinference/model/llm/mlx/core.py +16 -2
- xinference/model/llm/reasoning_parser.py +3 -3
- xinference/model/llm/sglang/core.py +111 -13
- xinference/model/llm/transformers/__init__.py +14 -0
- xinference/model/llm/transformers/core.py +31 -6
- xinference/model/llm/transformers/deepseek_vl.py +1 -1
- xinference/model/llm/transformers/deepseek_vl2.py +287 -0
- xinference/model/llm/transformers/gemma3.py +17 -2
- xinference/model/llm/transformers/intern_vl.py +28 -18
- xinference/model/llm/transformers/minicpmv26.py +21 -2
- xinference/model/llm/transformers/qwen-omni.py +308 -0
- xinference/model/llm/transformers/qwen2_audio.py +1 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -4
- xinference/model/llm/utils.py +37 -15
- xinference/model/llm/vllm/core.py +184 -8
- xinference/model/llm/vllm/distributed_executor.py +320 -0
- xinference/model/rerank/core.py +22 -12
- xinference/model/utils.py +118 -1
- xinference/model/video/core.py +6 -1
- xinference/thirdparty/deepseek_vl2/__init__.py +31 -0
- xinference/thirdparty/deepseek_vl2/models/__init__.py +26 -0
- xinference/thirdparty/deepseek_vl2/models/configuration_deepseek.py +210 -0
- xinference/thirdparty/deepseek_vl2/models/conversation.py +310 -0
- xinference/thirdparty/deepseek_vl2/models/modeling_deepseek.py +1975 -0
- xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py +697 -0
- xinference/thirdparty/deepseek_vl2/models/processing_deepseek_vl_v2.py +675 -0
- xinference/thirdparty/deepseek_vl2/models/siglip_vit.py +661 -0
- xinference/thirdparty/deepseek_vl2/serve/__init__.py +0 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/__init__.py +0 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/gradio_utils.py +83 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/overwrites.py +81 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/presets.py +115 -0
- xinference/thirdparty/deepseek_vl2/serve/app_modules/utils.py +333 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/avatar.png +0 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/custom.css +355 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/custom.js +22 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/favicon.ico +0 -0
- xinference/thirdparty/deepseek_vl2/serve/assets/simsun.ttc +0 -0
- xinference/thirdparty/deepseek_vl2/serve/inference.py +197 -0
- xinference/thirdparty/deepseek_vl2/utils/__init__.py +18 -0
- xinference/thirdparty/deepseek_vl2/utils/io.py +80 -0
- xinference/thirdparty/megatts3/__init__.py +0 -0
- xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
- xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
- xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
- xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
- xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
- xinference/types.py +10 -0
- xinference/utils.py +54 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
- xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
- xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
- xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
- xinference/web/ui/src/locales/en.json +2 -1
- xinference/web/ui/src/locales/zh.json +2 -1
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/METADATA +128 -115
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/RECORD +124 -63
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
- xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
- xinference/web/ui/build/static/js/main.3cea968e.js +0 -3
- xinference/web/ui/build/static/js/main.3cea968e.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/7f59e45e3f268ab8a4788b6fb024cf8dab088736dff22f5a3a39c122a83ab930.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/dcd60488509450bfff37bfff56de2c096d51de17dd00ec60d4db49c8b483ada1.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
- /xinference/web/ui/build/static/js/{main.3cea968e.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/vllm/core.py
CHANGED

@@ -13,12 +13,16 @@
 # limitations under the License.

 import asyncio
+import itertools
 import json
 import logging
 import multiprocessing
 import os
+import sys
+import threading
 import time
 import uuid
+from functools import partial
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -27,10 +31,14 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Type,
     TypedDict,
     Union,
 )

+import xoscar as xo
+from typing_extensions import NotRequired
+
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
@@ -73,6 +81,10 @@ class VLLMModelConfig(TypedDict, total=False):
     guided_decoding_backend: Optional[str]
     scheduling_policy: Optional[str]
     reasoning_content: bool
+    model_quantization: Optional[str]
+    mm_processor_kwargs: NotRequired[dict[str, Any]]
+    min_pixels: NotRequired[int]
+    max_pixels: NotRequired[int]


 class VLLMGenerateConfig(TypedDict, total=False):
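The four new VLLMModelConfig keys above are populated from launch-time kwargs. A minimal sketch of how they might be supplied through the RESTful client (the endpoint and model name are placeholders, and the extra kwargs are assumed to be forwarded into the vLLM model config, as this hunk and the _sanitize_model_config hunk below suggest):

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # placeholder endpoint
model_uid = client.launch_model(
    model_name="qwen2.5-vl-instruct",  # hypothetical vision model
    model_engine="vllm",
    model_quantization="awq",   # mapped onto vLLM's `quantization`
    min_pixels=256 * 28 * 28,   # folded into mm_processor_kwargs
    max_pixels=1280 * 28 * 28,  # (see _sanitize_model_config below)
)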
@@ -161,6 +173,9 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("QwQ-32B")
     VLLM_SUPPORTED_CHAT_MODELS.append("marco-o1")
     VLLM_SUPPORTED_CHAT_MODELS.append("deepseek-r1-distill-qwen")
+    VLLM_SUPPORTED_CHAT_MODELS.append("fin-r1")
+    VLLM_SUPPORTED_CHAT_MODELS.append("seallms-v3")
+    VLLM_SUPPORTED_CHAT_MODELS.append("skywork-or1-preview")

 if VLLM_INSTALLED and vllm.__version__ >= "0.3.2":
     VLLM_SUPPORTED_CHAT_MODELS.append("gemma-it")
@@ -196,6 +211,7 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.6.1":
     VLLM_SUPPORTED_VISION_MODEL_LIST.append("internvl2")
     VLLM_SUPPORTED_VISION_MODEL_LIST.append("InternVL2.5")
     VLLM_SUPPORTED_VISION_MODEL_LIST.append("InternVL2.5-MPO")
+    VLLM_SUPPORTED_VISION_MODEL_LIST.append("InternVL3")

 if VLLM_INSTALLED and vllm.__version__ >= "0.6.2":
     VLLM_SUPPORTED_CHAT_MODELS.append("minicpm3-4b")
@@ -220,6 +236,9 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.8.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("gemma-3-1b-it")
     VLLM_SUPPORTED_VISION_MODEL_LIST.append("gemma-3-it")

+if VLLM_INSTALLED and vllm.__version__ >= "0.8.4":
+    VLLM_SUPPORTED_CHAT_MODELS.append("glm4-0414")
+

 class VLLMModel(LLM):
     def __init__(
@@ -248,15 +267,59 @@ class VLLMModel(LLM):
         self.lora_modules = peft_model
         self.lora_requests: List[LoRARequest] = []
         self._xavier_config = None
+        # distributed inference
+        self._device_count = None
+        self._address = model_config.pop("address", None)  # type: ignore
+        self._n_worker = model_config.pop("n_worker", 1)  # type: ignore
+        self._shard = model_config.pop("shard", 0)  # type: ignore
+        self._driver_info = model_config.pop("driver_info", None)  # type: ignore
+        self._loading_thread: Optional[threading.Thread] = None
+        self._loading_error = None
+        # variables used for distributed inference and multiple GPUs
+        self._pool_addresses = None
+        self._worker_addresses: Optional[Dict[int, List[str]]] = None
+        self._all_worker_ready: Optional[threading.Event] = None
+        # used to call async
+        self._loop = None

     def set_xavier_config(self, value: Optional[Dict]):
         self._xavier_config = value  # type: ignore

+    def set_worker_addresses(self, shard: int, worker_addresses: List[str]):
+        assert self._worker_addresses is not None
+        self._worker_addresses[shard] = worker_addresses
+        if (
+            self._all_worker_ready is not None
+            and len(self._worker_addresses) == self._n_worker
+        ):
+            self._all_worker_ready.set()
+
+    @property
+    def driver_info(self) -> Optional[dict]:
+        return self._driver_info
+
+    @property
+    def need_create_pools(self):
+        return True
+
+    def set_pool_addresses(self, pool_addresses: List[str]):
+        self._pool_addresses = pool_addresses  # type: ignore
+
+    def get_pool_addresses(self) -> Optional[List[str]]:
+        return self._pool_addresses
+
+    def set_loop(self, loop: asyncio.AbstractEventLoop):
+        # loop will be passed into XinferenceDistributedExecutor,
+        # to call async method with asyncio.run_coroutine_threadsafe
+        self._loop = loop  # type: ignore
+
     def load(self):
         try:
             import vllm
+            from vllm.config import VllmConfig
             from vllm.engine.arg_utils import AsyncEngineArgs
             from vllm.engine.async_llm_engine import AsyncLLMEngine
+            from vllm.executor.executor_base import ExecutorBase
             from vllm.lora.request import LoRARequest
         except ImportError:
             error_message = "Failed to import module 'vllm'"
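The handshake implemented by these new fields and set_worker_addresses() works as follows: shard 0 registers its own actor-pool addresses and publishes driver_info, the remaining shards push their pool addresses to the shard-0 model actor, and a threading.Event releases the driver once every shard has reported. A standalone sketch of that rendezvous (class and method names here are illustrative, not part of the xinference API):

import threading
from typing import Dict, List, Optional

class ShardRendezvous:
    """Toy version of the driver-side bookkeeping shown above."""

    def __init__(self, n_worker: int):
        self._n_worker = n_worker
        self._worker_addresses: Dict[int, List[str]] = {}
        self._all_ready = threading.Event()

    def set_worker_addresses(self, shard: int, addresses: List[str]) -> None:
        # called once per shard, possibly from different threads
        self._worker_addresses[shard] = addresses
        if len(self._worker_addresses) == self._n_worker:
            self._all_ready.set()

    def wait_all(self, timeout: Optional[float] = None) -> bool:
        # the driver blocks here until every shard has registered
        return self._all_ready.wait(timeout)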
@@ -275,6 +338,7 @@ class VLLMModel(LLM):
         # we need to set it to fork to make cupy NCCL work
         multiprocessing.set_start_method("fork", force=True)

+        self._device_count = self._get_cuda_count()
         self._model_config = self._sanitize_model_config(self._model_config)
         reasoning_content = self._model_config.pop("reasoning_content")
@@ -320,6 +384,83 @@ class VLLMModel(LLM):
             self._engine = XavierEngine.from_engine_args(
                 engine_args, xavier_config=self._xavier_config
             )
+        elif self._n_worker > 1 or (
+            self._device_count > 1 and vllm.__version__ >= "0.7.0"
+        ):
+            from .distributed_executor import XinferenceDistributedExecutor
+
+            # model across multiple workers or GPUs
+            engine_args = AsyncEngineArgs(
+                model=self.model_path,
+                enable_lora=enable_lora,
+                max_loras=max_loras,
+                **self._model_config,
+            )
+
+            assert self._loop is not None
+            self._worker_addresses = {}
+
+            def _load():
+                try:
+                    assert self._pool_addresses
+
+                    if self._shard > 0:
+                        assert self._driver_info
+                        address = self._driver_info["address"]
+
+                        coro = xo.actor_ref(address, self.raw_model_uid)
+                        model_ref = asyncio.run_coroutine_threadsafe(
+                            coro, self._loop
+                        ).result()
+                        coro = model_ref.set_worker_addresses(
+                            self._shard, self._pool_addresses
+                        )
+                        asyncio.run_coroutine_threadsafe(coro, self._loop).result()
+                    else:
+                        self.set_worker_addresses(0, self._pool_addresses)
+                        self._driver_info = {"address": self._address}
+
+                        if self._n_worker > 1:
+                            self._all_worker_ready = threading.Event()
+                            # if model across workers, wait for other workers ready
+                            self._all_worker_ready.wait()
+
+                        # gather all worker addresses
+                        worker_addresses = list(
+                            itertools.chain(
+                                *[
+                                    self._worker_addresses[shard]
+                                    for shard in range(self._n_worker)
+                                ]
+                            )
+                        )
+                        assert worker_addresses
+                        loop = self._loop
+
+                        class XinferenceAsyncLLMEngine(AsyncLLMEngine):
+                            @classmethod
+                            def _get_executor_cls(
+                                cls, engine_config: VllmConfig
+                            ) -> Type[ExecutorBase]:
+                                return partial(  # type: ignore
+                                    XinferenceDistributedExecutor,
+                                    pool_addresses=worker_addresses,
+                                    n_worker=self._n_worker,
+                                    loop=loop,
+                                )
+
+                        self._engine = XinferenceAsyncLLMEngine.from_engine_args(
+                            engine_args
+                        )
+                except:
+                    logger.exception("Creating vllm engine failed")
+                    self._loading_error = sys.exc_info()
+
+            self._loading_thread = threading.Thread(target=_load)
+            self._loading_thread.start()
+            # wait some time for init finish
+            if self._shard == 0:
+                self._loading_thread.join(1)
         else:
             engine_args = AsyncEngineArgs(
                 model=self.model_path,
@@ -332,7 +473,14 @@ class VLLMModel(LLM):
         self._check_health_task = None
         if hasattr(self._engine, "check_health"):
             # vLLM introduced `check_health` since v0.4.1
-            self._check_health_task = asyncio.create_task(self._check_healthy())
+            self._check_health_task = self._loop.create_task(self._check_healthy())
+
+    def wait_for_load(self):
+        if self._loading_thread:
+            self._loading_thread.join()
+            if self._loading_error:
+                _, err, tb = self._loading_error
+                raise err.with_traceback(tb)

     def stop(self):
         # though the vLLM engine will shutdown when deleted,
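wait_for_load() pairs with the background loading thread introduced above: load() returns quickly while _load() runs on a thread, any failure is captured via sys.exc_info(), and wait_for_load() re-raises it with the original traceback. The pattern in isolation (a generic sketch, not xinference API):

import sys
import threading

class BackgroundLoader:
    """Run fn on a thread; surface its exception to the waiter."""

    def __init__(self, fn):
        self._error = None

        def _run():
            try:
                fn()
            except Exception:
                # keep (type, value, traceback) for later re-raise
                self._error = sys.exc_info()

        self._thread = threading.Thread(target=_run)
        self._thread.start()

    def wait(self):
        self._thread.join()
        if self._error:
            _, err, tb = self._error
            # re-raise in the caller's thread with the original traceback
            raise err.with_traceback(tb)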
@@ -341,9 +489,10 @@ class VLLMModel(LLM):
         logger.info("Stopping vLLM engine")
         if self._check_health_task:
             self._check_health_task.cancel()
-        if model_executor := getattr(self._engine.engine, "model_executor", None):
-            model_executor.shutdown()
-
+        if self._engine:
+            if model_executor := getattr(self._engine.engine, "model_executor", None):
+                model_executor.shutdown()
+        self._engine = None

     async def init_xavier(self):
         await self._engine.init_xavier()
@@ -374,22 +523,49 @@ class VLLMModel(LLM):
         if model_config is None:
             model_config = VLLMModelConfig()

-        cuda_count = self._get_cuda_count()
-
         model_config.setdefault("tokenizer_mode", "auto")
         model_config.setdefault("trust_remote_code", True)
-        model_config.setdefault("tensor_parallel_size", cuda_count)
+        model_config.setdefault("tensor_parallel_size", self._device_count)  # type: ignore
+        model_config.setdefault("pipeline_parallel_size", self._n_worker)  # type: ignore
         model_config.setdefault("block_size", 16)
         model_config.setdefault("swap_space", 4)
         model_config.setdefault("gpu_memory_utilization", 0.90)
         model_config.setdefault("max_num_seqs", 256)
-        model_config.setdefault("quantization", None)
+        if "model_quantization" in model_config:
+            model_config["quantization"] = model_config.pop("model_quantization")
+        else:
+            model_config.setdefault("quantization", None)
         model_config.setdefault("max_model_len", None)
         model_config.setdefault("guided_decoding_backend", "outlines")
         model_config.setdefault("reasoning_content", False)
         # Add scheduling policy if vLLM version is 0.6.3 or higher
         if vllm.__version__ >= "0.6.3":
             model_config.setdefault("scheduling_policy", "fcfs")
+        # init mm_processor_kwargs params
+        mm_processor_kwargs = model_config.get("mm_processor_kwargs", {})
+        if isinstance(mm_processor_kwargs, str):
+            try:
+                mm_processor_kwargs = json.loads(mm_processor_kwargs)
+            except json.JSONDecodeError:
+                logger.warning(
+                    "Failed to parse mm_processor_kwargs as JSON, using default empty dict"
+                )
+                mm_processor_kwargs = {}
+            except Exception as e:
+                logger.warning(
+                    f"Unexpected error parsing mm_processor_kwargs: {e}, using default empty dict"
+                )
+                mm_processor_kwargs = {}
+        pixel_params: Dict[str, int] = {}
+        if "min_pixels" in model_config:
+            pixel_params["min_pixels"] = model_config.pop("min_pixels")
+        if "max_pixels" in model_config:
+            pixel_params["max_pixels"] = model_config.pop("max_pixels")
+        if pixel_params or mm_processor_kwargs:
+            model_config["mm_processor_kwargs"] = {
+                **mm_processor_kwargs,
+                **pixel_params,
+            }
         return model_config

     @staticmethod
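A worked example of the new sanitize logic: mm_processor_kwargs is accepted either as a dict or as a JSON string, and min_pixels/max_pixels are popped out of the top-level config and folded into it (the values below are illustrative):

config = {
    "mm_processor_kwargs": '{"do_resize": true}',  # JSON string form is accepted
    "min_pixels": 256 * 28 * 28,
    "max_pixels": 1280 * 28 * 28,
}
# After _sanitize_model_config(config), the three keys collapse into one:
# config["mm_processor_kwargs"] == {
#     "do_resize": True,
#     "min_pixels": 200704,
#     "max_pixels": 1003520,
# }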
xinference/model/llm/vllm/distributed_executor.py
ADDED

@@ -0,0 +1,320 @@
+# Copyright 2022-2025 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import logging
+import os
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+
+import xoscar as xo
+from vllm.executor.executor_base import DistributedExecutorBase
+from vllm.utils import _run_task_with_lock, get_distributed_init_method
+from vllm.worker.worker_base import WorkerWrapperBase
+from xoscar.utils import get_next_port
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+    from vllm.model_executor.layers.sampler import SamplerOutput
+    from vllm.sequence import ExecuteModelRequest
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerActor(xo.StatelessActor):
+    def __init__(self, vllm_config: "VllmConfig", rpc_rank: int = 0, **kwargs):
+        super().__init__(**kwargs)
+        self._worker = WorkerWrapperBase(vllm_config, rpc_rank=rpc_rank)
+
+    async def __post_create__(self):
+        try:
+            # Change process title for model
+            import setproctitle
+
+            setproctitle.setproctitle(f"Xinf vLLM worker: {self._worker.rpc_rank}")
+        except ImportError:
+            pass
+
+    def __getattr__(self, item):
+        return getattr(self._worker, item)
+
+    @classmethod
+    def gen_uid(cls, rank):
+        return f"VllmWorker_{rank}"
+
+    def execute_method(self, method: Union[str, Callable], *args, **kwargs):
+        logger.debug(
+            "Calling method %s in vllm worker %s, args: %s, kwargs: %s",
+            method,
+            self.uid,
+            args,
+            kwargs,
+        )
+        if isinstance(method, str):
+            return getattr(self._worker, method)(*args, **kwargs)
+        else:
+            return method(self._worker, *args, **kwargs)
+
+
+class WorkerWrapper:
+    def __init__(
+        self,
+        loop: asyncio.AbstractEventLoop,
+        worker_actor_ref: xo.ActorRefType[WorkerActor],
+    ):
+        self._loop = loop
+        self._worker_actor_ref = worker_actor_ref
+
+    def execute_method(self, method: Union[str, Callable], *args, **kwargs):
+        coro = self._worker_actor_ref.execute_method(method, *args, **kwargs)
+        return asyncio.run_coroutine_threadsafe(coro, self._loop)
+
+    async def execute_method_async(self, method: Union[str, Callable], *args, **kwargs):
+        return await self._worker_actor_ref.execute_method(method, *args, **kwargs)
+
+    def kill(self):
+        coro = xo.destroy_actor(self._worker_actor_ref)
+        return asyncio.run_coroutine_threadsafe(coro, self._loop)
+
+
+class XinferenceDistributedExecutor(DistributedExecutorBase):
+    """Xoscar based distributed executor"""
+
+    use_ray: bool = False
+    _loop: asyncio.AbstractEventLoop
+    _pool_addresses: List[str]
+    _n_worker: int
+
+    def __init__(
+        self,
+        vllm_config: "VllmConfig",
+        pool_addresses: List[str],
+        n_worker: int,
+        loop: asyncio.AbstractEventLoop,
+        *args,
+        **kwargs,
+    ):
+        self._pool_addresses = pool_addresses
+        self._loop = loop
+        self._n_worker = n_worker
+        self._is_shutdown = False
+        super().__init__(vllm_config, *args, **kwargs)
+
+    def _init_executor(self) -> None:
+        # Create the parallel GPU workers.
+        world_size = self.parallel_config.world_size
+        tensor_parallel_size = self.parallel_config.tensor_parallel_size
+
+        self.driver_worker: Optional[WorkerActor] = None
+        # The remaining workers are Xoscar actors
+        self.workers: List[WorkerWrapper] = []
+
+        assert (
+            self._pool_addresses and len(self._pool_addresses) == world_size
+        ), f"Pool addresses(#{len(self._pool_addresses or [])} must be equal to worldsize(#{world_size})"
+
+        futures = []
+        for rank in range(world_size):
+            coro = xo.create_actor(
+                WorkerActor,
+                self.vllm_config,
+                rpc_rank=rank,
+                address=self._pool_addresses[rank],
+                uid=WorkerActor.gen_uid(rank),
+            )
+            futures.append(asyncio.run_coroutine_threadsafe(coro, self._loop))
+        refs = [fut.result() for fut in futures]
+        self.workers = [WorkerWrapper(self._loop, ref) for ref in refs[1:]]
+        self.driver_worker = WorkerActor(self.vllm_config, rpc_rank=0)
+
+        def driver_execute_method(*args, **kwargs):
+            func = partial(self.driver_worker.execute_method, *args, **kwargs)
+            return self._loop.run_in_executor(None, func)
+
+        self.driver_exec_method = driver_execute_method
+
+        # Set environment variables for the driver and workers.
+        all_args_to_update_environment_variables: List[Dict[str, str]] = [
+            dict() for _ in range(world_size)
+        ]
+
+        for args in all_args_to_update_environment_variables:
+            # some carry-over env vars from the driver
+            # TODO: refactor platform-specific env vars
+            for name in [
+                "VLLM_ATTENTION_BACKEND",
+                "TPU_CHIPS_PER_HOST_BOUNDS",
+                "TPU_HOST_BOUNDS",
+                "VLLM_USE_V1",
+                "VLLM_TRACE_FUNCTION",
+            ]:
+                if name in os.environ:
+                    args[name] = os.environ[name]
+
+        self._env_vars_for_all_workers = all_args_to_update_environment_variables
+
+        self._run_workers(
+            "update_environment_variables", self._env_vars_for_all_workers
+        )
+
+        all_kwargs = []
+        distributed_init_method = get_distributed_init_method(
+            self._pool_addresses[0].split(":", 1)[0], get_next_port()
+        )
+        for rank in range(world_size):
+            local_rank = rank % (world_size // self._n_worker)
+            kwargs = dict(
+                vllm_config=self.vllm_config,
+                local_rank=local_rank,
+                rank=rank,
+                distributed_init_method=distributed_init_method,
+                is_driver_worker=not self.parallel_config
+                or (rank % tensor_parallel_size == 0),
+            )
+            all_kwargs.append(kwargs)
+        self._run_workers("init_worker", all_kwargs)
+        self._run_workers("init_device")
+        self._run_workers(
+            "load_model",
+            max_concurrent_workers=self.parallel_config.max_parallel_loading_workers,
+        )
+
+        # This is the list of workers that are rank 0 of each TP group EXCEPT
+        # global rank 0. These are the workers that will broadcast to the
+        # rest of the workers.
+        self.tp_driver_workers: List[WorkerWrapper] = []
+        # This is the list of workers that are not drivers and not the first
+        # worker in a TP group. These are the workers that will be
+        # broadcasted to.
+        self.non_driver_workers: List[WorkerWrapper] = []
+
+        # Enforce rank order for correct rank to return final output.
+        for index, worker in enumerate(self.workers):
+            # The driver worker is rank 0 and not in self.workers.
+            rank = index + 1
+            if rank % self.parallel_config.tensor_parallel_size == 0:
+                self.tp_driver_workers.append(worker)
+            else:
+                self.non_driver_workers.append(worker)
+
+        self.pp_locks: Optional[List[asyncio.Lock]] = None
+
+    def _run_workers(
+        self,
+        method: Union[str, Callable],
+        *args,
+        async_run_tensor_parallel_workers_only: bool = False,
+        max_concurrent_workers: Optional[int] = None,
+        **kwargs,
+    ) -> Any:
+        if max_concurrent_workers:
+            raise NotImplementedError("max_concurrent_workers is not supported yet.")
+
+        workers = self.workers
+        if async_run_tensor_parallel_workers_only:
+            workers = self.non_driver_workers
+        worker_outputs = [
+            worker.execute_method(method, *args, **kwargs) for worker in workers
+        ]
+
+        if async_run_tensor_parallel_workers_only:
+            return worker_outputs
+
+        driver_worker_outputs = [
+            self.driver_worker.execute_method(method, *args, **kwargs)  # type: ignore
+        ]
+        return driver_worker_outputs + [output.result() for output in worker_outputs]
+
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
+        for result in parallel_worker_tasks:
+            result.get()
+
+    def check_health(self) -> None:
+        # Assume that the workers are healthy.
+        # TODO: check the health by checking if the workers all alive
+        return
+
+    def shutdown(self) -> None:
+        if self._is_shutdown:
+            return
+
+        try:
+            self._is_shutdown = True
+            futs = [worker.kill() for worker in self.workers]
+            _ = [fut.result() for fut in futs]
+        except (RuntimeError, ConnectionError, xo.ActorNotExist):
+            # event loop closed already, ignore
+            # or actor already removed
+            pass
+
+    def __del__(self):
+        return self.shutdown()
+
+    def _driver_execute_model(
+        self, execute_model_req: Optional["ExecuteModelRequest"]
+    ) -> Optional[List["SamplerOutput"]]:
+        return self.driver_worker.execute_method("execute_model", execute_model_req)  # type: ignore
+
+    async def _driver_execute_model_async(
+        self,
+        execute_model_req: Optional["ExecuteModelRequest"] = None,
+    ) -> List["SamplerOutput"]:
+        if not self.tp_driver_workers:
+            return await self.driver_exec_method("execute_model", execute_model_req)
+
+        if self.pp_locks is None:
+            # This locks each pipeline parallel stage so multiple virtual
+            # engines can't execute on the same stage at the same time
+            # We create the locks here to avoid creating them in the constructor
+            # which uses a different asyncio loop.
+            self.pp_locks = [
+                asyncio.Lock()
+                for _ in range(self.parallel_config.pipeline_parallel_size)
+            ]
+
+        tasks = [
+            asyncio.create_task(
+                _run_task_with_lock(
+                    self.driver_exec_method,
+                    self.pp_locks[0],
+                    "execute_model",
+                    execute_model_req,
+                )
+            )
+        ]
+        for pp_rank, driver_worker in enumerate(self.tp_driver_workers, start=1):
+            tasks.append(
+                asyncio.create_task(
+                    _run_task_with_lock(
+                        driver_worker.execute_method_async,
+                        self.pp_locks[pp_rank],
+                        "execute_model",
+                        execute_model_req,
+                    )
+                )
+            )
+
+        results = await asyncio.gather(*tasks)
+
+        # Only the last PP stage has the final results.
+        return results[-1]
+
+    async def _start_worker_execution_loop(self):
+        coros = [
+            worker.execute_method_async("start_worker_execution_loop")
+            for worker in self.non_driver_workers
+        ]
+        return await asyncio.gather(*coros)
xinference/model/rerank/core.py
CHANGED
@@ -29,7 +29,7 @@ import torch.nn as nn
 from ...constants import XINFERENCE_CACHE_DIR
 from ...device_utils import empty_cache
 from ...types import Document, DocumentObj, Rerank, RerankTokens
-from ..core import CacheableModelSpec, ModelDescription
+from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
 from ..utils import is_model_cached

 logger = logging.getLogger(__name__)
@@ -56,6 +56,7 @@ class RerankModelSpec(CacheableModelSpec):
     model_id: str
     model_revision: Optional[str]
     model_hub: str = "huggingface"
+    virtualenv: Optional[VirtualEnvSettings]


 class RerankModelDescription(ModelDescription):
@@ -69,6 +70,10 @@ class RerankModelDescription(ModelDescription):
         super().__init__(address, devices, model_path=model_path)
         self._model_spec = model_spec

+    @property
+    def spec(self):
+        return self._model_spec
+
     def to_dict(self):
         return {
             "model_type": "rerank",
@@ -106,9 +111,10 @@ def generate_rerank_description(model_spec: RerankModelSpec) -> Dict[str, List[D
     return res


-class _ModelWrapper:
+class _ModelWrapper(nn.Module):
     def __init__(self, module: nn.Module):
-        self._module = module
+        super().__init__()
+        self.model = module
         self._local_data = threading.local()

     @property
@@ -116,18 +122,22 @@ class _ModelWrapper:
         return getattr(self._local_data, "n_tokens", 0)

     @n_tokens.setter
-    def n_tokens(self, n_tokens):
-        self._local_data.n_tokens = n_tokens
-
-    def __getattr__(self, attr):
-        return getattr(self._module, attr)
+    def n_tokens(self, value):
+        self._local_data.n_tokens = value

-    def __call__(self, **kwargs):
-        attention_mask = kwargs["attention_mask"]
+    def forward(self, **kwargs):
+        attention_mask = kwargs.get("attention_mask")
         # when batching, the attention mask 1 means there is a token
         # thus we just sum up it to get the total number of tokens
-        self.n_tokens += attention_mask.sum().item()
-        return self._module(**kwargs)
+        if attention_mask is not None:
+            self.n_tokens += attention_mask.sum().item()
+        return self.model(**kwargs)
+
+    def __getattr__(self, attr):
+        try:
+            return super().__getattr__(attr)
+        except AttributeError:
+            return getattr(self.model, attr)


 class RerankModel:
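Since _ModelWrapper now subclasses nn.Module, a call on the wrapper routes through nn.Module.__call__ into forward(), where tokens are tallied from the attention mask before delegating to the wrapped model; the __getattr__ fallback keeps attribute access on the underlying module working. A toy check of the counting behavior (Dummy is a stand-in, not a real rerank backbone):

import torch
import torch.nn as nn

class Dummy(nn.Module):
    def forward(self, **kwargs):
        return kwargs["attention_mask"]

wrapper = _ModelWrapper(Dummy())
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
wrapper(attention_mask=mask)  # nn.Module.__call__ dispatches to forward()
print(wrapper.n_tokens)       # 5 -- the per-thread token tally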