xinference 0.12.3__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +6 -6
- xinference/client/restful/restful_client.py +0 -2
- xinference/core/model.py +21 -4
- xinference/core/scheduler.py +2 -0
- xinference/core/worker.py +74 -45
- xinference/deploy/utils.py +33 -2
- xinference/model/llm/__init__.py +5 -0
- xinference/model/llm/llm_family.json +240 -1
- xinference/model/llm/llm_family.py +32 -8
- xinference/model/llm/llm_family_modelscope.json +192 -0
- xinference/model/llm/mlx/__init__.py +13 -0
- xinference/model/llm/mlx/core.py +408 -0
- xinference/model/llm/pytorch/chatglm.py +2 -9
- xinference/model/llm/pytorch/cogvlm2.py +206 -21
- xinference/model/llm/pytorch/core.py +213 -40
- xinference/model/llm/pytorch/glm4v.py +171 -15
- xinference/model/llm/pytorch/qwen_vl.py +168 -7
- xinference/model/llm/pytorch/utils.py +53 -62
- xinference/model/llm/utils.py +24 -5
- xinference/model/rerank/core.py +5 -0
- xinference/thirdparty/deepseek_vl/serve/__init__.py +13 -0
- xinference/thirdparty/deepseek_vl/serve/app_deepseek.py +510 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/__init__.py +13 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/gradio_utils.py +94 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/overwrites.py +81 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/presets.py +96 -0
- xinference/thirdparty/deepseek_vl/serve/app_modules/utils.py +229 -0
- xinference/thirdparty/deepseek_vl/serve/inference.py +170 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/main.0fb6f3ab.js +3 -0
- xinference/web/ui/build/static/js/main.0fb6f3ab.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0f6b391abec76271137faad13a3793fe7acc1024e8cd2269c147b653ecd3a73b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/30a0c79d8025d6441eb75b2df5bc2750a14f30119c869ef02570d294dff65c2f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/40486e655c3c5801f087e2cf206c0b5511aaa0dfdba78046b7181bf9c17e54c5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b5507cd57f16a3a230aa0128e39fe103e928de139ea29e2679e4c64dcbba3b3a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d779b915f83f9c7b5a72515b6932fdd114f1822cef90ae01cc0d12bca59abc2d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d87824cb266194447a9c0c69ebab2d507bfc3e3148976173760d18c035e9dd26.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +1 -0
- {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/METADATA +4 -1
- {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/RECORD +55 -44
- xinference/web/ui/build/static/js/main.77dd47c3.js +0 -3
- xinference/web/ui/build/static/js/main.77dd47c3.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0cd591866aa345566e0b63fb51ff2043e163a770af6fdc2f3bad395d046353e2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/37c1476717199863bbba1530e3513a9368f8f73001b75b4a85c2075956308027.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3fa1f69162f9c6dc0f6a6e21b64d49d6b8e6fa8dfa59a82cf829931c5f97d99f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46edc1fe657dfedb2e673148332bb442c6eb98f09f2592c389209e376510afa5.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/62e257ed9016471035fa1a7da57c9e2a4250974ed566b4d1295873d747c68eb2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/82db357f3fd5b32215d747ee593f69ff06c95ad6cde37f71a96c8290aaab64c0.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/bc6da27195ec4607bb472bf61f97c928ad4966fa64e4c2247661bedb7400abba.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e606671420d2937102c3c34b4b04056c11736408c1d3347b8cf42dfe61fb394b.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f118f99c22b713c678c1209c4e1dd43fe86e3f6e801a4c0c35d3bbf41fd05fe6.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +0 -1
- /xinference/web/ui/build/static/js/{main.77dd47c3.js.LICENSE.txt → main.0fb6f3ab.js.LICENSE.txt} +0 -0
- {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/LICENSE +0 -0
- {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/WHEEL +0 -0
- {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/core.py

@@ -18,6 +18,8 @@ import os
 from functools import lru_cache
 from typing import Iterable, Iterator, List, Optional, Tuple, Union
 
+import torch
+
 from ....core.scheduler import InferenceRequest
 from ....device_utils import (
     get_device_preferred_dtype,
@@ -43,7 +45,7 @@ from ...utils import select_device
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin
-from .utils import get_context_length, get_max_src_len
+from .utils import get_context_length, get_max_src_len, pad_prefill_tokens
 
 logger = logging.getLogger(__name__)
 
@@ -409,9 +411,171 @@ class PytorchModel(LLM):
         else:
             return generator_wrapper(prompt, generate_config)
 
+    def build_prefill_attention_mask(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Build attention mask for prefill phase.
+        Padding `0` on the left.
+        Note that the parameter `seq_length` is from `input_ids`.
+        """
+        data = []
+        for r in reqs:
+            real_len = seq_length - r.padding_len
+            x = torch.cat(
+                [
+                    torch.full((r.padding_len,), 0, dtype=torch.long),
+                    torch.ones((real_len,), dtype=torch.long),
+                ]
+            )
+            data.append(x)
+            r.extra_kwargs["attention_mask_seq_len"] = real_len
+        return torch.stack(data).to(self._device)
+
+    def build_decode_attention_mask(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Build attention mask for decode phase.
+        Note that the `seq_length` parameter is from merged kv_cache.
+        So we need pad `0` on the left again.
+        """
+        data = []
+        for r in reqs:
+            r.extra_kwargs["attention_mask_seq_len"] += 1
+            attention_mask_seq_len = r.extra_kwargs["attention_mask_seq_len"]
+            pad_len = seq_length - attention_mask_seq_len
+            x = torch.cat(
+                [
+                    torch.full((pad_len,), 0, dtype=torch.long),
+                    torch.ones((attention_mask_seq_len,), dtype=torch.long),
+                ]
+            )
+            data.append(x)
+        return torch.stack(data).to(self._device)
+
+    def build_prefill_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Build position ids for prefill phase.
+        Padding `0` on the left.
+        Note that the parameter `seq_length` is from `input_ids`.
+        Record the `max_position_id` on request for the decode phase.
+        """
+        res = []
+        for r in reqs:
+            real_seq_len = seq_length - r.padding_len
+            res.append(
+                torch.cat(
+                    [
+                        torch.full((r.padding_len,), 0, dtype=torch.long),
+                        torch.arange(0, real_seq_len, dtype=torch.long),
+                    ]
+                )
+            )
+            r.extra_kwargs["max_position_id"] = real_seq_len - 1
+        return torch.stack(res).to(self._device)
+
+    def build_decode_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Build position ids for decode phase.
+        For most models, just let the `max_position_id` in previous step += 1 and use the latest `max_position_id`
+        """
+        data = []
+        for r in reqs:
+            r.extra_kwargs["max_position_id"] += 1
+            data.append([r.extra_kwargs["max_position_id"]])
+        position_ids = torch.as_tensor(data, dtype=torch.long, device=self._device)
+        return position_ids
+
+    def build_prefill_token_type_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Build token_type_ids for prefill phase.
+        For most models, this is not required.
+        """
+        return None
+
+    def build_decode_token_type_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Build token_type_ids for decode phase.
+        For most models, this is not required.
+        """
+        return None
+
+    def build_prefill_inputs(self, prompts: List, req_list: List[InferenceRequest]):
+        """
+        Get inputs for inference. Models may have their own impl.
+        """
+        assert isinstance(prompts[0], str)
+        inputs = self._tokenizer(prompts, padding=False).input_ids
+        context_len = self.get_context_len()
+        input_ids = torch.as_tensor(
+            pad_prefill_tokens(inputs, context_len, req_list), device=self._device
+        )
+        return input_ids
+
+    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+        """
+        Get all inputs parameters for prefill phase. Models may have their own impl.
+        """
+        input_ids = self.build_prefill_inputs(prompts, req_list)
+        res = {"input_ids": input_ids}
+        batch_size, seq_len = input_ids.shape
+        attention_mask = self.build_prefill_attention_mask(
+            batch_size, seq_len, req_list
+        )
+        if attention_mask is not None:
+            res["attention_mask"] = attention_mask
+        position_ids = self.build_prefill_position_ids(batch_size, seq_len, req_list)
+        if position_ids is not None:
+            res["position_ids"] = position_ids
+        token_type_ids = self.build_prefill_token_type_ids(
+            batch_size, seq_len, req_list
+        )
+        if token_type_ids is not None:
+            res["token_type_ids"] = token_type_ids
+        return res
+
+    def build_decode_kwargs(
+        self,
+        prompts: List,
+        req_list: List[InferenceRequest],
+        batch_size: int,
+        seq_len: int,
+    ):
+        """
+        Get all inputs parameters for decode phase. Models may have their own impl.
+        """
+        res = {"input_ids": torch.as_tensor(prompts, device=self._device)}
+        attention_mask = self.build_decode_attention_mask(batch_size, seq_len, req_list)
+        if attention_mask is not None:
+            res["attention_mask"] = attention_mask
+        position_ids = self.build_decode_position_ids(batch_size, seq_len, req_list)
+        if position_ids is not None:
+            res["position_ids"] = position_ids
+        token_type_ids = self.build_decode_token_type_ids(batch_size, seq_len, req_list)
+        if token_type_ids is not None:
+            res["token_type_ids"] = token_type_ids
+        return res
+
     @staticmethod
-    def …
-    …
+    def get_batch_size_and_seq_len_indexes_from_kv() -> Tuple[int, int]:
+        """
+        From huggingface transformers document, the `pask_key_values` has the shape of
+        `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+        However, for some models, the shape may be changed.
+        """
+        return 0, 2
+
+    def get_dtype(self):
+        raise NotImplementedError("Not implemented.")
 
     @lru_cache
     def get_context_len(self):
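For orientation, here is a minimal standalone sketch (illustrative only, not code from the package) of the left-padding convention that `build_prefill_attention_mask` and `build_prefill_position_ids` implement for a toy batch of two requests; the `padding_lens` values are made up for the example.

```python
import torch

# Toy batch: two prompts batched to seq_length = 5, the second one two tokens
# shorter, so it carries a hypothetical padding_len of 2 (left padding).
seq_length = 5
padding_lens = [0, 2]

attention_masks, position_ids = [], []
for pad in padding_lens:
    real_len = seq_length - pad
    # 0 over the left padding, 1 over the real tokens
    attention_masks.append(
        torch.cat([torch.full((pad,), 0, dtype=torch.long),
                   torch.ones((real_len,), dtype=torch.long)])
    )
    # position ids restart at 0 where the real tokens begin
    position_ids.append(
        torch.cat([torch.full((pad,), 0, dtype=torch.long),
                   torch.arange(0, real_len, dtype=torch.long)])
    )

print(torch.stack(attention_masks))
# tensor([[1, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1]])
print(torch.stack(position_ids))
# tensor([[0, 1, 2, 3, 4],
#         [0, 0, 0, 1, 2]])
```

During decode, `build_decode_attention_mask` extends each row by one more `1` per generated token, and `build_decode_position_ids` simply increments the `max_position_id` recorded at prefill time.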
@@ -426,28 +590,38 @@ class PytorchModel(LLM):
     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
         # check some parameters
         for r in req_list:
-            r.sanitized_generate_config
-            if
-            … (remaining removed lines truncated in the diff view)
+            try:
+                if r.sanitized_generate_config is None:
+                    r.sanitized_generate_config = self.prepare_sanitize_generate_config(
+                        r
+                    )
+                if r.is_prefill:
+                    # check some generate params
+                    max_src_len = get_max_src_len(self.get_context_len(), r)  # type: ignore
+                    if max_src_len < 0:
+                        r.stopped = True
+                        r.error_msg = "Max tokens exceeds model's max length"
+                        continue
+                    if r.stream_interval <= 0:
+                        r.stopped = True
+                        r.error_msg = "`stream_interval` must be greater than 0"
+                        continue
+                    stop_str = r.sanitized_generate_config.get("stop", None)
+                    if stop_str and (
+                        not (
+                            isinstance(stop_str, str) or isinstance(stop_str, Iterable)
+                        )
+                    ):
+                        r.stopped = True
+                        r.error_msg = "Invalid `stop` field type"
+                        continue
+            # Catch exception here. If not catch exception, the request would hang.
+            except Exception as e:
+                logger.exception(f"prepare inference error with {e}")
+                r.stopped = True
+                r.error_msg = str(e)
+
+    def get_builtin_stop_token_ids(self) -> Tuple:
         return (
             tuple(self.model_family.prompt_style.stop_token_ids)
             if self.model_family.prompt_style
@@ -494,17 +668,8 @@ class PytorchModel(LLM):
         from .utils import batch_inference_one_step
 
         self.prepare_batch_inference(req_list)
-        context_len = self.get_context_len()
-        assert isinstance(context_len, int)
         batch_inference_one_step(
-            req_list,
-            self.model_uid,
-            self._model,
-            self._tokenizer,
-            self._device,
-            context_len,
-            self._get_builtin_stop_token_ids(),
-            require_attention_mask=self.require_attention_mask(),
+            self, req_list, self.model_uid, self._model, self._tokenizer
         )
         self.handle_batch_inference_results(req_list)
 
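Note that `batch_inference_one_step` now receives the model object itself instead of the device, context length, stop token ids, and attention-mask flag, presumably so the batching helper can rely on the model's own `build_prefill_kwargs` / `build_decode_kwargs` hooks and its KV-cache layout convention. As a purely hypothetical illustration of the `(0, 2)` default returned by `get_batch_size_and_seq_len_indexes_from_kv`:

```python
import torch

# Illustration only: a cached key/value tensor laid out as
# (batch_size, num_heads, sequence_length, embed_size_per_head).
cached_key = torch.zeros(3, 8, 17, 64)  # hypothetical shapes
batch_dim, seq_dim = 0, 2               # the default indexes from the diff above
print(cached_key.shape[batch_dim], cached_key.shape[seq_dim])  # -> 3 17
```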
@@ -696,14 +861,20 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
         super().prepare_batch_inference(req_list)
         for r in req_list:
-            … (removed lines truncated in the diff view)
+            try:
+                if not r.stopped and r.is_prefill:
+                    r.full_prompt = self._get_full_prompt(
+                        r.prompt, r.system_prompt, r.chat_history, None
+                    )
+            except Exception as e:
+                logger.exception(f"prepare inference error with {e}")
+                r.stopped = True
+                r.error_msg = str(e)
 
     def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
         for req in req_list:
-            if req.
-            if req.
+            if req.error_msg is None and req.completion:
+                if req.stream:
                     results = []
                     for i, c in enumerate(req.completion):
                         if c == "<bos_stream>":
@@ -722,3 +893,5 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
                             self._get_final_chat_completion_chunk(req.completion[-1])
                         )
                     req.completion = results
+                else:
+                    req.completion[0] = self._to_chat_completion(req.completion[0])
xinference/model/llm/pytorch/glm4v.py

@@ -14,6 +14,7 @@
 import base64
 import logging
 import time
+import typing
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
@@ -24,6 +25,7 @@ import requests
 import torch
 from PIL import Image
 
+from ....core.scheduler import InferenceRequest
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
@@ -36,6 +38,7 @@ from ....types import (
 from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import get_max_src_len
 
 logger = logging.getLogger(__name__)
 
@@ -69,7 +72,6 @@ class Glm4VModel(PytorchChatModel):
         if quantization != "none":
             if self._device == "cuda" and self._is_linux():
                 kwargs["device_map"] = "auto"
-                self._device = "auto"
             if quantization == "4-bit":
                 kwargs["load_in_4bit"] = True
             elif quantization == "8-bit":
@@ -137,9 +139,6 @@ class Glm4VModel(PytorchChatModel):
                 fut = executor.submit(_load_image, image_url)
                 image_futures.append(fut)
         images = [fut.result() for fut in image_futures]
-        # images = []
-        # for image_url in image_urls:
-        #     images.append(_load_image(image_url))
         text = " ".join(texts)
         if len(images) == 0:
             return text, []
@@ -149,19 +148,11 @@
             raise RuntimeError("Only one image per message is supported")
         return content, []
 
-    def chat(
+    def _get_chat_msgs(
         self,
         prompt: Union[str, List[Dict]],
-        system_prompt: Optional[str] = None,
         chat_history: Optional[List[ChatCompletionMessage]] = None,
-        generate_config: Optional[PytorchGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        from transformers import TextIteratorStreamer
-
-        if not generate_config:
-            generate_config = {}
-
-        stream = generate_config.get("stream", False)
+    ):
         content, images_chat = self._message_content_to_chat(prompt)
 
         msgs = []
@@ -170,7 +161,7 @@ class Glm4VModel(PytorchChatModel):
         for h in chat_history or []:
             role = h["role"]
             content_h, images_tmp = self._message_content_to_chat(h["content"])
-            if images_tmp
+            if images_tmp:
                 images_history = images_tmp
             if len(query_to_response) == 0 and role == "user":
                 query_to_response.append({"role": "user", "content": content_h})
@@ -185,6 +176,22 @@ class Glm4VModel(PytorchChatModel):
         elif len(images_history) > 0:
             image = images_history[0]
             msgs.append({"role": "user", "content": content, "image": image})
+        return msgs
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        from transformers import TextIteratorStreamer
+
+        if not generate_config:
+            generate_config = {}
+
+        stream = generate_config.get("stream", False)
+        msgs = self._get_chat_msgs(prompt, chat_history)
 
         inputs = self._tokenizer.apply_chat_template(
             msgs,
@@ -282,3 +289,152 @@
             )
             chunk["usage"] = completion_usage
             yield chunk
+
+    def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
+        msgs = self._get_chat_msgs(prompt, chat_history)
+        inputs = self._tokenizer.apply_chat_template(
+            msgs,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_tensors="pt",
+            return_dict=True,
+        )
+        return {
+            "input_ids": inputs.input_ids.squeeze(0),
+            "images": inputs.images.squeeze(0),
+        }
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        Refer to https://huggingface.co/THUDM/glm-4v-9b/blob/main/generation_config.json
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.8
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.8
+        return raw_config
+
+    def build_prefill_inputs(self, prompts: List, req_list: List[InferenceRequest]):
+        context_len = self.get_context_len()
+        assert isinstance(prompts[0], dict)
+        images = []
+        max_length = float("-inf")
+        for i, feature in enumerate(prompts):
+            req = req_list[i]
+            if "images" in feature:
+                images.append(feature.pop("images", None))
+            max_src_len = get_max_src_len(context_len, req)
+            input_ids = feature["input_ids"][-max_src_len:]
+            req.prompt_tokens = input_ids.tolist()
+            feature["input_ids"] = input_ids
+            max_length = max(len(input_ids), max_length)
+
+        def pad_to_max_length_internal(feature, max_len, idx):
+            padding_length = max_len - len(feature["input_ids"])
+            req_list[idx].padding_len = padding_length
+            feature["input_ids"] = torch.cat(
+                [torch.full((padding_length,), 0), feature["input_ids"]]
+            )
+            return feature
+
+        features = [
+            pad_to_max_length_internal(feature, max_length, i)
+            for i, feature in enumerate(prompts)
+        ]
+        batch = {
+            key: torch.stack([feature[key] for feature in features])
+            for key in features[0].keys()
+        }
+        if images:
+            batch["images"] = torch.stack(images).to(self._device)
+        batch["input_ids"] = batch["input_ids"].to(self._device)
+        return batch
+
+    @staticmethod
+    def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
+        """
+        Copied from https://huggingface.co/THUDM/glm-4v-9b/blob/main/modeling_chatglm.py
+        """
+        if images_list is None or len(images_list) == 0:
+            return True
+        for image_list in images_list:
+            if image_list is not None:
+                return False
+        return True
+
+    @typing.no_type_check
+    def get_full_attention_mask(
+        self, attention_mask, input_ids, images, req_list: List[InferenceRequest]
+    ):
+        """
+        Modified according to https://huggingface.co/THUDM/glm-4v-9b/blob/main/modeling_chatglm.py
+        """
+        image_size: int = self._model.config.vision_config["image_size"]
+        patch_size: int = self._model.config.vision_config["patch_size"]
+        num_patches = (image_size // patch_size // 2) ** 2
+        new_attention_masks = []
+
+        # if not image, use this default id
+        eoi_token_pos = 6
+        boi_token_pos = 4
+
+        for i in range(len(input_ids)):
+            input_id = input_ids[i].tolist()
+            req = req_list[i]
+            if not self.is_empty(images):
+                _boi_token_pos, _eoi_token_pos = input_id.index(
+                    self._model.config.boi_token_id
+                ), input_id.index(self._model.config.eoi_token_id)
+            else:
+                _boi_token_pos = boi_token_pos + req.padding_len
+                _eoi_token_pos = eoi_token_pos + req.padding_len
+            assert eoi_token_pos - boi_token_pos == 2
+            new_attention_masks.append(
+                torch.cat(
+                    (
+                        attention_mask[i, : _boi_token_pos + 1],
+                        attention_mask.new_ones(num_patches),
+                        attention_mask[i, _eoi_token_pos:],
+                    )
+                )
+            )
+        attention_mask = torch.stack(new_attention_masks, dim=0).to(self._device)
+        return attention_mask
+
+    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+        batch = self.build_prefill_inputs(prompts, req_list)
+        batch_size, seq_len = batch["input_ids"].shape
+        attention_mask = self.build_prefill_attention_mask(
+            batch_size, seq_len, req_list
+        )
+        if attention_mask is not None:
+            full_attention_mask = self.get_full_attention_mask(
+                attention_mask, batch["input_ids"], batch["images"], req_list
+            )
+            batch["attention_mask"] = full_attention_mask
+            for r in req_list:
+                r.extra_kwargs["attention_mask_seq_len"] = full_attention_mask.shape[1]
+        position_ids = self.build_prefill_position_ids(batch_size, seq_len, req_list)
+        if position_ids is not None:
+            batch["position_ids"] = position_ids
+        return batch
+
+    def build_decode_attention_mask(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        max_seq_len = max(r.extra_kwargs["attention_mask_seq_len"] for r in reqs)
+
+        new_attention_mask = []
+        for r in reqs:
+            attn_mask_seq_len = r.extra_kwargs["attention_mask_seq_len"]
+            pad_len = max_seq_len - attn_mask_seq_len
+            new_attention_mask.append(
+                torch.cat(
+                    [torch.full((pad_len,), 0), torch.ones((attn_mask_seq_len + 1,))]
+                )
+            )
+            r.extra_kwargs["attention_mask_seq_len"] += 1
+        return torch.stack(new_attention_mask, dim=0).to(self._device)
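The `build_decode_attention_mask` override above appears to account for `get_full_attention_mask` widening each prefill mask by `num_patches` vision positions: the per-request `attention_mask_seq_len` counters, rather than the merged KV length, drive the decode mask. A standalone sketch of that decode-step growth (toy counter values, not xinference code):

```python
import torch

# Each request remembers how many positions of its mask row are real
# (attention_mask_seq_len); a decode step pads every row to the longest
# counter and appends a single 1 for the newly generated token.
attention_mask_seq_lens = [7, 5]  # hypothetical per-request counters
max_seq_len = max(attention_mask_seq_lens)

rows = []
for n in attention_mask_seq_lens:
    pad_len = max_seq_len - n
    rows.append(
        torch.cat([torch.zeros(pad_len, dtype=torch.long),
                   torch.ones(n + 1, dtype=torch.long)])
    )
decode_mask = torch.stack(rows)
print(decode_mask)
# tensor([[1, 1, 1, 1, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1, 1, 1]])
```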