xinference 0.12.3__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (71)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +6 -6
  3. xinference/client/restful/restful_client.py +0 -2
  4. xinference/core/model.py +21 -4
  5. xinference/core/scheduler.py +2 -0
  6. xinference/core/worker.py +74 -45
  7. xinference/deploy/utils.py +33 -2
  8. xinference/model/llm/__init__.py +5 -0
  9. xinference/model/llm/llm_family.json +240 -1
  10. xinference/model/llm/llm_family.py +32 -8
  11. xinference/model/llm/llm_family_modelscope.json +192 -0
  12. xinference/model/llm/mlx/__init__.py +13 -0
  13. xinference/model/llm/mlx/core.py +408 -0
  14. xinference/model/llm/pytorch/chatglm.py +2 -9
  15. xinference/model/llm/pytorch/cogvlm2.py +206 -21
  16. xinference/model/llm/pytorch/core.py +213 -40
  17. xinference/model/llm/pytorch/glm4v.py +171 -15
  18. xinference/model/llm/pytorch/qwen_vl.py +168 -7
  19. xinference/model/llm/pytorch/utils.py +53 -62
  20. xinference/model/llm/utils.py +24 -5
  21. xinference/model/rerank/core.py +5 -0
  22. xinference/thirdparty/deepseek_vl/serve/__init__.py +13 -0
  23. xinference/thirdparty/deepseek_vl/serve/app_deepseek.py +510 -0
  24. xinference/thirdparty/deepseek_vl/serve/app_modules/__init__.py +13 -0
  25. xinference/thirdparty/deepseek_vl/serve/app_modules/gradio_utils.py +94 -0
  26. xinference/thirdparty/deepseek_vl/serve/app_modules/overwrites.py +81 -0
  27. xinference/thirdparty/deepseek_vl/serve/app_modules/presets.py +96 -0
  28. xinference/thirdparty/deepseek_vl/serve/app_modules/utils.py +229 -0
  29. xinference/thirdparty/deepseek_vl/serve/inference.py +170 -0
  30. xinference/web/ui/build/asset-manifest.json +3 -3
  31. xinference/web/ui/build/index.html +1 -1
  32. xinference/web/ui/build/static/js/main.0fb6f3ab.js +3 -0
  33. xinference/web/ui/build/static/js/main.0fb6f3ab.js.map +1 -0
  34. xinference/web/ui/node_modules/.cache/babel-loader/0f6b391abec76271137faad13a3793fe7acc1024e8cd2269c147b653ecd3a73b.json +1 -0
  35. xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +1 -0
  36. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +1 -0
  37. xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +1 -0
  38. xinference/web/ui/node_modules/.cache/babel-loader/30a0c79d8025d6441eb75b2df5bc2750a14f30119c869ef02570d294dff65c2f.json +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/40486e655c3c5801f087e2cf206c0b5511aaa0dfdba78046b7181bf9c17e54c5.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +1 -0
  41. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +1 -0
  42. xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +1 -0
  43. xinference/web/ui/node_modules/.cache/babel-loader/b5507cd57f16a3a230aa0128e39fe103e928de139ea29e2679e4c64dcbba3b3a.json +1 -0
  44. xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +1 -0
  45. xinference/web/ui/node_modules/.cache/babel-loader/d779b915f83f9c7b5a72515b6932fdd114f1822cef90ae01cc0d12bca59abc2d.json +1 -0
  46. xinference/web/ui/node_modules/.cache/babel-loader/d87824cb266194447a9c0c69ebab2d507bfc3e3148976173760d18c035e9dd26.json +1 -0
  47. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +1 -0
  49. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/METADATA +4 -1
  50. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/RECORD +55 -44
  51. xinference/web/ui/build/static/js/main.77dd47c3.js +0 -3
  52. xinference/web/ui/build/static/js/main.77dd47c3.js.map +0 -1
  53. xinference/web/ui/node_modules/.cache/babel-loader/0cd591866aa345566e0b63fb51ff2043e163a770af6fdc2f3bad395d046353e2.json +0 -1
  54. xinference/web/ui/node_modules/.cache/babel-loader/37c1476717199863bbba1530e3513a9368f8f73001b75b4a85c2075956308027.json +0 -1
  55. xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +0 -1
  56. xinference/web/ui/node_modules/.cache/babel-loader/3fa1f69162f9c6dc0f6a6e21b64d49d6b8e6fa8dfa59a82cf829931c5f97d99f.json +0 -1
  57. xinference/web/ui/node_modules/.cache/babel-loader/46edc1fe657dfedb2e673148332bb442c6eb98f09f2592c389209e376510afa5.json +0 -1
  58. xinference/web/ui/node_modules/.cache/babel-loader/62e257ed9016471035fa1a7da57c9e2a4250974ed566b4d1295873d747c68eb2.json +0 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/82db357f3fd5b32215d747ee593f69ff06c95ad6cde37f71a96c8290aaab64c0.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/bc6da27195ec4607bb472bf61f97c928ad4966fa64e4c2247661bedb7400abba.json +0 -1
  63. xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +0 -1
  64. xinference/web/ui/node_modules/.cache/babel-loader/e606671420d2937102c3c34b4b04056c11736408c1d3347b8cf42dfe61fb394b.json +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/f118f99c22b713c678c1209c4e1dd43fe86e3f6e801a4c0c35d3bbf41fd05fe6.json +0 -1
  66. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +0 -1
  67. /xinference/web/ui/build/static/js/{main.77dd47c3.js.LICENSE.txt → main.0fb6f3ab.js.LICENSE.txt} +0 -0
  68. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/LICENSE +0 -0
  69. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/WHEEL +0 -0
  70. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/entry_points.txt +0 -0
  71. {xinference-0.12.3.dist-info → xinference-0.13.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/core.py

@@ -18,6 +18,8 @@ import os
 from functools import lru_cache
 from typing import Iterable, Iterator, List, Optional, Tuple, Union
 
+import torch
+
 from ....core.scheduler import InferenceRequest
 from ....device_utils import (
     get_device_preferred_dtype,
@@ -43,7 +45,7 @@ from ...utils import select_device
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin
-from .utils import get_context_length, get_max_src_len
+from .utils import get_context_length, get_max_src_len, pad_prefill_tokens
 
 logger = logging.getLogger(__name__)
 
@@ -409,9 +411,171 @@ class PytorchModel(LLM):
         else:
             return generator_wrapper(prompt, generate_config)
 
+    def build_prefill_attention_mask(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Build attention mask for the prefill phase.
+        Padding `0` on the left.
+        Note that the parameter `seq_length` is from `input_ids`.
+        """
+        data = []
+        for r in reqs:
+            real_len = seq_length - r.padding_len
+            x = torch.cat(
+                [
+                    torch.full((r.padding_len,), 0, dtype=torch.long),
+                    torch.ones((real_len,), dtype=torch.long),
+                ]
+            )
+            data.append(x)
+            r.extra_kwargs["attention_mask_seq_len"] = real_len
+        return torch.stack(data).to(self._device)
+
+    def build_decode_attention_mask(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Build attention mask for the decode phase.
+        Note that the `seq_length` parameter is from the merged kv_cache,
+        so we need to pad `0` on the left again.
+        """
+        data = []
+        for r in reqs:
+            r.extra_kwargs["attention_mask_seq_len"] += 1
+            attention_mask_seq_len = r.extra_kwargs["attention_mask_seq_len"]
+            pad_len = seq_length - attention_mask_seq_len
+            x = torch.cat(
+                [
+                    torch.full((pad_len,), 0, dtype=torch.long),
+                    torch.ones((attention_mask_seq_len,), dtype=torch.long),
+                ]
+            )
+            data.append(x)
+        return torch.stack(data).to(self._device)
+
+    def build_prefill_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Build position ids for the prefill phase.
+        Padding `0` on the left.
+        Note that the parameter `seq_length` is from `input_ids`.
+        Record the `max_position_id` on the request for the decode phase.
+        """
+        res = []
+        for r in reqs:
+            real_seq_len = seq_length - r.padding_len
+            res.append(
+                torch.cat(
+                    [
+                        torch.full((r.padding_len,), 0, dtype=torch.long),
+                        torch.arange(0, real_seq_len, dtype=torch.long),
+                    ]
+                )
+            )
+            r.extra_kwargs["max_position_id"] = real_seq_len - 1
+        return torch.stack(res).to(self._device)
+
+    def build_decode_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Build position ids for the decode phase.
+        For most models, just increment the `max_position_id` from the previous
+        step and use the latest value.
+        """
+        data = []
+        for r in reqs:
+            r.extra_kwargs["max_position_id"] += 1
+            data.append([r.extra_kwargs["max_position_id"]])
+        position_ids = torch.as_tensor(data, dtype=torch.long, device=self._device)
+        return position_ids
+
+    def build_prefill_token_type_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Build token_type_ids for the prefill phase.
+        For most models, this is not required.
+        """
+        return None
+
+    def build_decode_token_type_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Build token_type_ids for the decode phase.
+        For most models, this is not required.
+        """
+        return None
+
+    def build_prefill_inputs(self, prompts: List, req_list: List[InferenceRequest]):
+        """
+        Get inputs for inference. Models may have their own impl.
+        """
+        assert isinstance(prompts[0], str)
+        inputs = self._tokenizer(prompts, padding=False).input_ids
+        context_len = self.get_context_len()
+        input_ids = torch.as_tensor(
+            pad_prefill_tokens(inputs, context_len, req_list), device=self._device
+        )
+        return input_ids
+
+    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+        """
+        Get all input parameters for the prefill phase. Models may have their own impl.
+        """
+        input_ids = self.build_prefill_inputs(prompts, req_list)
+        res = {"input_ids": input_ids}
+        batch_size, seq_len = input_ids.shape
+        attention_mask = self.build_prefill_attention_mask(
+            batch_size, seq_len, req_list
+        )
+        if attention_mask is not None:
+            res["attention_mask"] = attention_mask
+        position_ids = self.build_prefill_position_ids(batch_size, seq_len, req_list)
+        if position_ids is not None:
+            res["position_ids"] = position_ids
+        token_type_ids = self.build_prefill_token_type_ids(
+            batch_size, seq_len, req_list
+        )
+        if token_type_ids is not None:
+            res["token_type_ids"] = token_type_ids
+        return res
+
+    def build_decode_kwargs(
+        self,
+        prompts: List,
+        req_list: List[InferenceRequest],
+        batch_size: int,
+        seq_len: int,
+    ):
+        """
+        Get all input parameters for the decode phase. Models may have their own impl.
+        """
+        res = {"input_ids": torch.as_tensor(prompts, device=self._device)}
+        attention_mask = self.build_decode_attention_mask(batch_size, seq_len, req_list)
+        if attention_mask is not None:
+            res["attention_mask"] = attention_mask
+        position_ids = self.build_decode_position_ids(batch_size, seq_len, req_list)
+        if position_ids is not None:
+            res["position_ids"] = position_ids
+        token_type_ids = self.build_decode_token_type_ids(batch_size, seq_len, req_list)
+        if token_type_ids is not None:
+            res["token_type_ids"] = token_type_ids
+        return res
+
     @staticmethod
-    def require_attention_mask():
-        return False
+    def get_batch_size_and_seq_len_indexes_from_kv() -> Tuple[int, int]:
+        """
+        Per the huggingface transformers documentation, `past_key_values` has the
+        shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
+        However, for some models, the shape may differ.
+        """
+        return 0, 2
+
+    def get_dtype(self):
+        raise NotImplementedError("Not implemented.")
 
     @lru_cache
     def get_context_len(self):
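
The new build_prefill_* and build_decode_* hooks above implement left padding so that requests of different prompt lengths can share one batch: each request records its padding_len, the attention mask carries zeros over the padded positions, and position ids start counting from the first real token. A minimal standalone sketch of that padding scheme, with illustrative values that are not part of the xinference API:

import torch

seq_length = 5           # padded length of input_ids in the batch
padding_lens = [2, 0]    # per-request left padding, like r.padding_len

attention_masks, position_ids = [], []
for pad in padding_lens:
    real_len = seq_length - pad
    # 0 over the left padding, 1 over the real tokens
    attention_masks.append(
        torch.cat([torch.zeros(pad, dtype=torch.long), torch.ones(real_len, dtype=torch.long)])
    )
    # positions 0..real_len-1, with 0 repeated under the padding
    position_ids.append(
        torch.cat([torch.zeros(pad, dtype=torch.long), torch.arange(real_len, dtype=torch.long)])
    )

print(torch.stack(attention_masks))  # tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])
print(torch.stack(position_ids))     # tensor([[0, 0, 0, 1, 2], [0, 1, 2, 3, 4]])

During decode each request then just increments its recorded max_position_id (so the first generated tokens in this example would use position ids [[3], [5]]), and get_batch_size_and_seq_len_indexes_from_kv returning (0, 2) points at the batch and sequence dimensions of past_key_values.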
@@ -426,28 +590,38 @@ class PytorchModel(LLM):
     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
         # check some parameters
         for r in req_list:
-            if r.sanitized_generate_config is None:
-                r.sanitized_generate_config = self.prepare_sanitize_generate_config(r)
-            if r.is_prefill:
-                # check some generate params
-                max_src_len = get_max_src_len(self.get_context_len(), r)  # type: ignore
-                if max_src_len < 0:
-                    r.stopped = True
-                    r.error_msg = "Max tokens exceeds model's max length"
-                    continue
-                if r.stream_interval <= 0:
-                    r.stopped = True
-                    r.error_msg = "`stream_interval` must be greater than 0"
-                    continue
-                stop_str = r.sanitized_generate_config.get("stop", None)
-                if stop_str and (
-                    not (isinstance(stop_str, str) or isinstance(stop_str, Iterable))
-                ):
-                    r.stopped = True
-                    r.error_msg = "Invalid `stop` field type"
-                    continue
-
-    def _get_builtin_stop_token_ids(self) -> Tuple:
+            try:
+                if r.sanitized_generate_config is None:
+                    r.sanitized_generate_config = self.prepare_sanitize_generate_config(
+                        r
+                    )
+                if r.is_prefill:
+                    # check some generate params
+                    max_src_len = get_max_src_len(self.get_context_len(), r)  # type: ignore
+                    if max_src_len < 0:
+                        r.stopped = True
+                        r.error_msg = "Max tokens exceeds model's max length"
+                        continue
+                    if r.stream_interval <= 0:
+                        r.stopped = True
+                        r.error_msg = "`stream_interval` must be greater than 0"
+                        continue
+                    stop_str = r.sanitized_generate_config.get("stop", None)
+                    if stop_str and (
+                        not (
+                            isinstance(stop_str, str) or isinstance(stop_str, Iterable)
+                        )
+                    ):
+                        r.stopped = True
+                        r.error_msg = "Invalid `stop` field type"
+                        continue
+            # Catch the exception here; otherwise the request would hang.
+            except Exception as e:
+                logger.exception(f"prepare inference error with {e}")
+                r.stopped = True
+                r.error_msg = str(e)
+
+    def get_builtin_stop_token_ids(self) -> Tuple:
         return (
             tuple(self.model_family.prompt_style.stop_token_ids)
             if self.model_family.prompt_style
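
The reworked prepare_batch_inference wraps the per-request checks in try/except so that a failing request is marked stopped with an error message instead of raising and leaving the whole batch hanging. A rough standalone sketch of that pattern; FakeRequest is a stand-in for illustration, not xinference's InferenceRequest:

import logging

logger = logging.getLogger(__name__)

class FakeRequest:
    def __init__(self, prompt):
        self.prompt = prompt
        self.stopped = False
        self.error_msg = None

def prepare(req_list):
    for r in req_list:
        try:
            # any per-request validation that may raise
            if not r.prompt:
                raise ValueError("empty prompt")
        except Exception as e:
            # fail only this request; the rest of the batch keeps going
            logger.exception(f"prepare inference error with {e}")
            r.stopped = True
            r.error_msg = str(e)

reqs = [FakeRequest("hello"), FakeRequest("")]
prepare(reqs)
print([(r.stopped, r.error_msg) for r in reqs])  # [(False, None), (True, 'empty prompt')]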
@@ -494,17 +668,8 @@ class PytorchModel(LLM):
         from .utils import batch_inference_one_step
 
         self.prepare_batch_inference(req_list)
-        context_len = self.get_context_len()
-        assert isinstance(context_len, int)
         batch_inference_one_step(
-            req_list,
-            self.model_uid,
-            self._model,
-            self._tokenizer,
-            self._device,
-            context_len,
-            self._get_builtin_stop_token_ids(),
-            require_attention_mask=self.require_attention_mask(),
+            self, req_list, self.model_uid, self._model, self._tokenizer
         )
         self.handle_batch_inference_results(req_list)
 
@@ -696,14 +861,20 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
     def prepare_batch_inference(self, req_list: List[InferenceRequest]):
         super().prepare_batch_inference(req_list)
         for r in req_list:
-            r.full_prompt = self._get_full_prompt(
-                r.prompt, r.system_prompt, r.chat_history, None
-            )
+            try:
+                if not r.stopped and r.is_prefill:
+                    r.full_prompt = self._get_full_prompt(
+                        r.prompt, r.system_prompt, r.chat_history, None
+                    )
+            except Exception as e:
+                logger.exception(f"prepare inference error with {e}")
+                r.stopped = True
+                r.error_msg = str(e)
 
     def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
         for req in req_list:
-            if req.stream and req.error_msg is None:
-                if req.completion:
+            if req.error_msg is None and req.completion:
+                if req.stream:
                     results = []
                     for i, c in enumerate(req.completion):
                         if c == "<bos_stream>":
@@ -722,3 +893,5 @@ class PytorchChatModel(PytorchModel, ChatModelMixin):
                             self._get_final_chat_completion_chunk(req.completion[-1])
                         )
                     req.completion = results
+                else:
+                    req.completion[0] = self._to_chat_completion(req.completion[0])
xinference/model/llm/pytorch/glm4v.py

@@ -14,6 +14,7 @@
 import base64
 import logging
 import time
+import typing
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
@@ -24,6 +25,7 @@ import requests
 import torch
 from PIL import Image
 
+from ....core.scheduler import InferenceRequest
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
@@ -36,6 +38,7 @@ from ....types import (
 from ...utils import select_device
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from .core import PytorchChatModel, PytorchGenerateConfig
+from .utils import get_max_src_len
 
 logger = logging.getLogger(__name__)
 
@@ -69,7 +72,6 @@ class Glm4VModel(PytorchChatModel):
         if quantization != "none":
             if self._device == "cuda" and self._is_linux():
                 kwargs["device_map"] = "auto"
-                self._device = "auto"
             if quantization == "4-bit":
                 kwargs["load_in_4bit"] = True
             elif quantization == "8-bit":
@@ -137,9 +139,6 @@ class Glm4VModel(PytorchChatModel):
                 fut = executor.submit(_load_image, image_url)
                 image_futures.append(fut)
         images = [fut.result() for fut in image_futures]
-        # images = []
-        # for image_url in image_urls:
-        #     images.append(_load_image(image_url))
         text = " ".join(texts)
         if len(images) == 0:
             return text, []
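
In the hunk above, image URLs referenced in a message are fetched concurrently through a ThreadPoolExecutor, and the leftover commented-out sequential loop is dropped. A hedged sketch of that concurrent pattern in isolation; load_image and the URLs are illustrative stand-ins, not xinference helpers:

from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import requests
from PIL import Image

def load_image(url: str) -> Image.Image:
    # download one image and decode it with PIL
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    return Image.open(BytesIO(resp.content)).convert("RGB")

image_urls = ["https://example.com/a.png", "https://example.com/b.png"]
with ThreadPoolExecutor() as executor:
    futures = [executor.submit(load_image, url) for url in image_urls]
    images = [fut.result() for fut in futures]  # results come back in input order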
@@ -149,19 +148,11 @@ class Glm4VModel(PytorchChatModel):
                 raise RuntimeError("Only one image per message is supported")
         return content, []
 
-    def chat(
+    def _get_chat_msgs(
         self,
         prompt: Union[str, List[Dict]],
-        system_prompt: Optional[str] = None,
         chat_history: Optional[List[ChatCompletionMessage]] = None,
-        generate_config: Optional[PytorchGenerateConfig] = None,
-    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        from transformers import TextIteratorStreamer
-
-        if not generate_config:
-            generate_config = {}
-
-        stream = generate_config.get("stream", False)
+    ):
         content, images_chat = self._message_content_to_chat(prompt)
 
         msgs = []
@@ -170,7 +161,7 @@ class Glm4VModel(PytorchChatModel):
         for h in chat_history or []:
             role = h["role"]
             content_h, images_tmp = self._message_content_to_chat(h["content"])
-            if images_tmp != []:
+            if images_tmp:
                 images_history = images_tmp
             if len(query_to_response) == 0 and role == "user":
                 query_to_response.append({"role": "user", "content": content_h})
@@ -185,6 +176,22 @@ class Glm4VModel(PytorchChatModel):
             elif len(images_history) > 0:
                 image = images_history[0]
                 msgs.append({"role": "user", "content": content, "image": image})
+        return msgs
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        from transformers import TextIteratorStreamer
+
+        if not generate_config:
+            generate_config = {}
+
+        stream = generate_config.get("stream", False)
+        msgs = self._get_chat_msgs(prompt, chat_history)
 
         inputs = self._tokenizer.apply_chat_template(
             msgs,
@@ -282,3 +289,152 @@ class Glm4VModel(PytorchChatModel):
             )
             chunk["usage"] = completion_usage
             yield chunk
+
+    def _get_full_prompt(self, prompt, system_prompt, chat_history, tools):
+        msgs = self._get_chat_msgs(prompt, chat_history)
+        inputs = self._tokenizer.apply_chat_template(
+            msgs,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_tensors="pt",
+            return_dict=True,
+        )
+        return {
+            "input_ids": inputs.input_ids.squeeze(0),
+            "images": inputs.images.squeeze(0),
+        }
+
+    def prepare_sanitize_generate_config(self, req: InferenceRequest):
+        """
+        Refer to https://huggingface.co/THUDM/glm-4v-9b/blob/main/generation_config.json
+        """
+        raw_config = req.inference_kwargs.get("raw_params", {})
+        temperature = raw_config.get("temperature", None)
+        if temperature is None:
+            raw_config["temperature"] = 0.8
+        top_p = raw_config.get("top_p", None)
+        if top_p is None:
+            raw_config["top_p"] = 0.8
+        return raw_config
+
+    def build_prefill_inputs(self, prompts: List, req_list: List[InferenceRequest]):
+        context_len = self.get_context_len()
+        assert isinstance(prompts[0], dict)
+        images = []
+        max_length = float("-inf")
+        for i, feature in enumerate(prompts):
+            req = req_list[i]
+            if "images" in feature:
+                images.append(feature.pop("images", None))
+            max_src_len = get_max_src_len(context_len, req)
+            input_ids = feature["input_ids"][-max_src_len:]
+            req.prompt_tokens = input_ids.tolist()
+            feature["input_ids"] = input_ids
+            max_length = max(len(input_ids), max_length)
+
+        def pad_to_max_length_internal(feature, max_len, idx):
+            padding_length = max_len - len(feature["input_ids"])
+            req_list[idx].padding_len = padding_length
+            feature["input_ids"] = torch.cat(
+                [torch.full((padding_length,), 0), feature["input_ids"]]
+            )
+            return feature
+
+        features = [
+            pad_to_max_length_internal(feature, max_length, i)
+            for i, feature in enumerate(prompts)
+        ]
+        batch = {
+            key: torch.stack([feature[key] for feature in features])
+            for key in features[0].keys()
+        }
+        if images:
+            batch["images"] = torch.stack(images).to(self._device)
+        batch["input_ids"] = batch["input_ids"].to(self._device)
+        return batch
+
+    @staticmethod
+    def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
+        """
+        Copied from https://huggingface.co/THUDM/glm-4v-9b/blob/main/modeling_chatglm.py
+        """
+        if images_list is None or len(images_list) == 0:
+            return True
+        for image_list in images_list:
+            if image_list is not None:
+                return False
+        return True
+
+    @typing.no_type_check
+    def get_full_attention_mask(
+        self, attention_mask, input_ids, images, req_list: List[InferenceRequest]
+    ):
+        """
+        Modified according to https://huggingface.co/THUDM/glm-4v-9b/blob/main/modeling_chatglm.py
+        """
+        image_size: int = self._model.config.vision_config["image_size"]
+        patch_size: int = self._model.config.vision_config["patch_size"]
+        num_patches = (image_size // patch_size // 2) ** 2
+        new_attention_masks = []
+
+        # if there is no image, use these default ids
+        eoi_token_pos = 6
+        boi_token_pos = 4
+
+        for i in range(len(input_ids)):
+            input_id = input_ids[i].tolist()
+            req = req_list[i]
+            if not self.is_empty(images):
+                _boi_token_pos, _eoi_token_pos = input_id.index(
+                    self._model.config.boi_token_id
+                ), input_id.index(self._model.config.eoi_token_id)
+            else:
+                _boi_token_pos = boi_token_pos + req.padding_len
+                _eoi_token_pos = eoi_token_pos + req.padding_len
+            assert eoi_token_pos - boi_token_pos == 2
+            new_attention_masks.append(
+                torch.cat(
+                    (
+                        attention_mask[i, : _boi_token_pos + 1],
+                        attention_mask.new_ones(num_patches),
+                        attention_mask[i, _eoi_token_pos:],
+                    )
+                )
+            )
+        attention_mask = torch.stack(new_attention_masks, dim=0).to(self._device)
+        return attention_mask
+
+    def build_prefill_kwargs(self, prompts: List, req_list: List[InferenceRequest]):
+        batch = self.build_prefill_inputs(prompts, req_list)
+        batch_size, seq_len = batch["input_ids"].shape
+        attention_mask = self.build_prefill_attention_mask(
+            batch_size, seq_len, req_list
+        )
+        if attention_mask is not None:
+            full_attention_mask = self.get_full_attention_mask(
+                attention_mask, batch["input_ids"], batch["images"], req_list
+            )
+            batch["attention_mask"] = full_attention_mask
+            for r in req_list:
+                r.extra_kwargs["attention_mask_seq_len"] = full_attention_mask.shape[1]
+        position_ids = self.build_prefill_position_ids(batch_size, seq_len, req_list)
+        if position_ids is not None:
+            batch["position_ids"] = position_ids
+        return batch
+
+    def build_decode_attention_mask(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        max_seq_len = max(r.extra_kwargs["attention_mask_seq_len"] for r in reqs)
+
+        new_attention_mask = []
+        for r in reqs:
+            attn_mask_seq_len = r.extra_kwargs["attention_mask_seq_len"]
+            pad_len = max_seq_len - attn_mask_seq_len
+            new_attention_mask.append(
+                torch.cat(
+                    [torch.full((pad_len,), 0), torch.ones((attn_mask_seq_len + 1,))]
+                )
+            )
+            r.extra_kwargs["attention_mask_seq_len"] += 1
+        return torch.stack(new_attention_mask, dim=0).to(self._device)
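
get_full_attention_mask above splices extra mask entries between the BOI and EOI tokens so the mask also covers the image patch embeddings, with the patch count given by (image_size // patch_size // 2) ** 2 (for example, 1120 // 14 // 2 squared would be 1600 patches if the vision config carried those values). A toy, standalone illustration of that splice; every size and position below is made up for the example:

import torch

attention_mask = torch.tensor([0, 0, 1, 1, 1, 1, 1, 1])  # left-padded text mask for one request
boi_token_pos, eoi_token_pos = 4, 6                       # BOI/EOI positions inside the mask
num_patches = 4                                           # stands in for (image_size // patch_size // 2) ** 2

full_mask = torch.cat(
    (
        attention_mask[: boi_token_pos + 1],   # everything up to and including BOI
        attention_mask.new_ones(num_patches),  # one mask entry per image patch
        attention_mask[eoi_token_pos:],        # EOI and the remaining text
    )
)
print(full_mask)  # tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1])

The decode-phase build_decode_attention_mask override then grows each request's attention_mask_seq_len by one per step, so the mask stays aligned with the merged KV cache that now includes the image patches.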