xinference 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (94)
  1. xinference/_compat.py +22 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +91 -6
  4. xinference/client/restful/restful_client.py +39 -0
  5. xinference/core/model.py +41 -13
  6. xinference/deploy/cmdline.py +3 -1
  7. xinference/deploy/test/test_cmdline.py +56 -0
  8. xinference/isolation.py +24 -0
  9. xinference/model/audio/__init__.py +12 -0
  10. xinference/model/audio/core.py +26 -4
  11. xinference/model/audio/f5tts.py +195 -0
  12. xinference/model/audio/fish_speech.py +71 -35
  13. xinference/model/audio/model_spec.json +88 -0
  14. xinference/model/audio/model_spec_modelscope.json +9 -0
  15. xinference/model/audio/whisper_mlx.py +208 -0
  16. xinference/model/embedding/core.py +322 -6
  17. xinference/model/embedding/model_spec.json +8 -1
  18. xinference/model/embedding/model_spec_modelscope.json +9 -1
  19. xinference/model/llm/__init__.py +4 -2
  20. xinference/model/llm/llm_family.json +479 -53
  21. xinference/model/llm/llm_family_modelscope.json +423 -17
  22. xinference/model/llm/mlx/core.py +230 -50
  23. xinference/model/llm/sglang/core.py +2 -0
  24. xinference/model/llm/transformers/chatglm.py +9 -5
  25. xinference/model/llm/transformers/core.py +1 -0
  26. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  27. xinference/model/llm/transformers/utils.py +16 -8
  28. xinference/model/llm/utils.py +23 -1
  29. xinference/model/llm/vllm/core.py +89 -2
  30. xinference/thirdparty/f5_tts/__init__.py +0 -0
  31. xinference/thirdparty/f5_tts/api.py +166 -0
  32. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  33. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  34. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  35. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  36. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  37. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  38. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  39. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  40. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  41. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  42. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  43. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  44. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  45. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  46. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  47. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  48. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  49. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  50. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  51. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  52. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  53. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  54. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  55. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  56. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  57. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  58. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  59. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  60. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  61. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  62. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  63. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  64. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  65. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  66. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  67. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  68. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  69. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  70. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  71. xinference/thirdparty/f5_tts/train/README.md +77 -0
  72. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  73. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  74. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  75. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  76. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  77. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  78. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  79. xinference/thirdparty/f5_tts/train/train.py +75 -0
  80. xinference/types.py +2 -1
  81. xinference/web/ui/build/asset-manifest.json +3 -3
  82. xinference/web/ui/build/index.html +1 -1
  83. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  84. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  86. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/METADATA +39 -18
  87. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/RECORD +92 -39
  88. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/WHEEL +1 -1
  89. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  91. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  92. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
  93. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
  94. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
xinference/model/llm/mlx/core.py

@@ -168,6 +168,9 @@ class MLXModel(LLM):
  return False
  if "generate" not in llm_family.model_ability:
  return False
+ if "chat" in llm_family.model_ability or "vision" in llm_family.model_ability:
+ # do not process chat or vision
+ return False
  return True

  def _get_prompt_cache(self, prompt, lora_name: Optional[str] = None):
@@ -191,18 +194,35 @@ class MLXModel(LLM):
  self._prompt_cache.tokens.extend(prompt)
  return prompt

- def _generate_stream(self, prompt: str, kwargs: MLXGenerateConfig):
- import mlx.core as mx
- from mlx_lm.utils import generate_step
+ def _generate_stream_inner(self, **kwargs):
+ from mlx_lm.utils import make_sampler, stream_generate
+
+ sampler = make_sampler(
+ temp=kwargs.pop("temperature"), top_p=kwargs.pop("top_p")
+ )
+ prompt_token_ids = kwargs.pop("prompt_token_ids")
+ yield from stream_generate(
+ self._model, self._tokenizer, prompt_token_ids, sampler=sampler, **kwargs
+ )
+
+ def _prepare_inputs(
+ self, prompt: Union[str, Dict[str, Any]], kwargs
+ ) -> Tuple[Any, int]:
+ prompt_token_ids = self._tokenizer.encode(prompt)
+ prompt_token_ids = self._get_prompt_cache(
+ prompt_token_ids, kwargs.get("lora_name")
+ )
+ return prompt_token_ids, len(prompt_token_ids)

- model = self._model
+ def _generate_stream(
+ self, prompt: Union[str, Dict[str, Any]], kwargs: MLXGenerateConfig
+ ):
  model_uid = self.model_uid
  tokenizer = self._tokenizer
  max_tokens = kwargs["max_tokens"]
  chunk_id = str(uuid.uuid4())
  stop_token_ids = kwargs.get("stop_token_ids", [])
  stream = kwargs.get("stream", False)
- lora_name = kwargs.get("lora_name")
  stream_options = kwargs.pop("stream_options", None)
  include_usage = (
  stream_options["include_usage"]
@@ -210,40 +230,28 @@ class MLXModel(LLM):
  else False
  )

- prompt_token_ids = tokenizer.encode(prompt)
- prompt_token_ids = self._get_prompt_cache(prompt_token_ids, lora_name)
- prompt_tokens = mx.array(prompt_token_ids)
- input_echo_len = len(prompt_tokens)
+ prompt_token_ids, input_echo_len = self._prepare_inputs(prompt, kwargs)

  i = 0
  start = time.time()
  output = ""
  tokens = []
- for (token, _), i in zip(
- generate_step(
- prompt_tokens,
- model,
- temp=kwargs["temperature"],
+ for chunk_resp, i in zip(
+ self._generate_stream_inner(
+ prompt_token_ids=prompt_token_ids,
+ max_tokens=max_tokens,
+ temperature=kwargs["temperature"],
+ top_p=kwargs["top_p"],
  repetition_penalty=kwargs["repetition_penalty"],
  repetition_context_size=kwargs["repetition_context_size"],
- top_p=kwargs["top_p"],
- logit_bias=kwargs["logit_bias"],
- prompt_cache=self._prompt_cache.cache, # type: ignore
+ prompt_cache=self._prompt_cache.cache if self._prompt_cache else None, # type: ignore
  ),
  range(max_tokens),
  ):
+ token = chunk_resp.token
  tokens.append(token)
- if token == tokenizer.eos_token_id or token in stop_token_ids: # type: ignore
- break
-
- # Yield the last segment if streaming
- out = tokenizer.decode(
- token,
- skip_special_tokens=True,
- spaces_between_special_tokens=False,
- clean_up_tokenization_spaces=True,
- )

+ out = chunk_resp.text
  if stream:
  # this special character is mainly for qwen
  out = out.strip("�")
@@ -267,11 +275,15 @@ class MLXModel(LLM):
  total_tokens=(input_echo_len + i),
  ), completion_usage

+ if token == tokenizer.eos_token_id or token in stop_token_ids: # type: ignore
+ break
+
  logger.info(
  f"Average generation speed: {i / (time.time() - start):.2f} tokens/s."
  )

- self._prompt_cache.tokens.extend(tokens) # type: ignore
+ if self._prompt_cache:
+ self._prompt_cache.tokens.extend(tokens) # type: ignore

  if i == max_tokens - 1:
  finish_reason = "length"
@@ -315,10 +327,12 @@ class MLXModel(LLM):
  yield completion_chunk, completion_usage

  def generate(
- self, prompt: str, generate_config: Optional[MLXGenerateConfig] = None
+ self,
+ prompt: Union[str, Dict[str, Any]],
+ generate_config: Optional[MLXGenerateConfig] = None,
  ) -> Union[Completion, Iterator[CompletionChunk]]:
  def generator_wrapper(
- prompt: str, generate_config: MLXGenerateConfig
+ prompt: Union[str, Dict[str, Any]], generate_config: MLXGenerateConfig
  ) -> Iterator[CompletionChunk]:
  for completion_chunk, completion_usage in self._generate_stream(
  prompt,
@@ -357,26 +371,6 @@ class MLXModel(LLM):


  class MLXChatModel(MLXModel, ChatModelMixin):
- def __init__(
- self,
- model_uid: str,
- model_family: "LLMFamilyV1",
- model_spec: "LLMSpecV1",
- quantization: str,
- model_path: str,
- model_config: Optional[MLXModelConfig] = None,
- peft_model: Optional[List[LoRA]] = None,
- ):
- super().__init__(
- model_uid,
- model_family,
- model_spec,
- quantization,
- model_path,
- model_config,
- peft_model,
- )
-
  def _sanitize_generate_config(
  self,
  generate_config: Optional[MLXGenerateConfig],
@@ -403,6 +397,9 @@ class MLXChatModel(MLXModel, ChatModelMixin):
  return False
  if "chat" not in llm_family.model_ability:
  return False
+ if "vision" in llm_family.model_ability:
+ # do not process vision
+ return False
  return True

  def chat(
@@ -433,3 +430,186 @@ class MLXChatModel(MLXModel, ChatModelMixin):
  if tools:
  return self._tool_calls_completion(self.model_family, self.model_uid, c)
  return self._to_chat_completion(c)
+
+
+ class MLXVisionModel(MLXModel, ChatModelMixin):
+ @classmethod
+ def match(
+ cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+ ) -> bool:
+ if llm_spec.model_format not in ["mlx"]:
+ return False
+ if sys.platform != "darwin" or platform.processor() != "arm":
+ # only work for Mac M chips
+ return False
+ if "vision" not in llm_family.model_ability:
+ return False
+ return True
+
+ def _load_model(self, **kwargs):
+ try:
+ from mlx_vlm import load
+ except ImportError:
+ error_message = "Failed to import module 'mlx_vlm'"
+ installation_guide = [
+ "Please make sure 'mlx_vlm' is installed. ",
+ "You can install it by `pip install mlx_vlm`\n",
+ ]
+
+ raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+ return load(self.model_path)
+
+ def load(self):
+ kwargs = {}
+ kwargs["revision"] = self._model_config.get(
+ "revision", self.model_spec.model_revision
+ )
+ kwargs["trust_remote_code"] = self._model_config.get("trust_remote_code")
+ kwargs["cache_limit_gb"] = self._model_config.pop("cache_limit_gb", None)
+
+ self._model, self._processor = self._load_model(**kwargs)
+ self._tokenizer = self._processor.tokenizer
+
+ def _generate_stream_inner(self, **kwargs):
+ import mlx.core as mx
+ from mlx_lm.utils import GenerationResponse
+ from mlx_vlm.utils import generate_step
+
+ max_tokens = kwargs.pop("max_tokens")
+ inputs = kwargs["prompt_token_ids"]
+ input_ids, pixel_values, mask = inputs[:3]
+
+ kwargs = {
+ k: v
+ for k, v in zip(
+ [
+ "image_grid_thw",
+ "image_sizes",
+ "aspect_ratio_ids",
+ "aspect_ratio_mask",
+ "cross_attention_mask",
+ ],
+ inputs[3:],
+ )
+ }
+
+ tokenizer = self._processor.tokenizer
+ detokenizer = self._processor.detokenizer
+
+ detokenizer.reset()
+ tic = time.perf_counter()
+ for (token, logprobs), n in zip(
+ generate_step(input_ids, self._model, pixel_values, mask, **kwargs),
+ range(max_tokens),
+ ):
+ if n == 0:
+ prompt_time = time.perf_counter() - tic
+ prompt_tps = len(input_ids) / prompt_time
+ tic = time.perf_counter()
+ if token == tokenizer.eos_token_id:
+ break
+ detokenizer.add_token(token)
+
+ # Yield the last segment if streaming
+ yield GenerationResponse(
+ text=detokenizer.last_segment,
+ token=token,
+ logprobs=logprobs,
+ prompt_tokens=len(input_ids),
+ prompt_tps=prompt_tps,
+ generation_tokens=n + 1,
+ generation_tps=(n + 1) / (time.perf_counter() - tic),
+ peak_memory=mx.metal.get_peak_memory() / 1e9,
+ )
+
+ detokenizer.finalize()
+ yield GenerationResponse(
+ text=detokenizer.last_segment,
+ token=token,
+ logprobs=logprobs,
+ prompt_tokens=len(input_ids),
+ prompt_tps=prompt_tps,
+ generation_tokens=n + 1,
+ generation_tps=(n + 1) / (time.perf_counter() - tic),
+ peak_memory=mx.metal.get_peak_memory() / 1e9,
+ )
+
+ def _prepare_inputs(
+ self, prompt: Union[str, Dict[str, Any]], kwargs
+ ) -> Tuple[Any, int]:
+ from mlx_vlm import prepare_inputs
+
+ prompt_str = prompt.get("prompt") # type: ignore
+ images = prompt.get("multi_modal_data", {}).get("image") # type: ignore
+ if images and not isinstance(images, list):
+ images = [images]
+ if hasattr(self._model.config, "image_token_index"):
+ image_token_index = self._model.config.image_token_index
+ else:
+ image_token_index = None
+
+ inputs = prepare_inputs(
+ None,
+ self._processor,
+ images,
+ prompt_str,
+ image_token_index,
+ kwargs.get("resize_shape"),
+ )
+ input_ids = inputs[0]
+ return inputs, len(input_ids)
+
+ def chat(
+ self,
+ messages: List[Dict],
+ generate_config: Optional[MLXGenerateConfig] = None,
+ ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+ messages = self._transform_messages(messages) # type: ignore
+ tools = generate_config.pop("tools", []) if generate_config else None
+
+ model_family = self.model_family.model_family or self.model_family.model_name
+
+ if "internvl2" not in model_family.lower():
+ from qwen_vl_utils import process_vision_info
+
+ full_context_kwargs = {}
+ if tools and model_family in QWEN_TOOL_CALL_FAMILY:
+ full_context_kwargs["tools"] = tools
+ assert self.model_family.chat_template is not None
+ prompt = self.get_full_context(
+ messages, self.model_family.chat_template, **full_context_kwargs
+ )
+ images, video_inputs = process_vision_info(messages)
+ if video_inputs:
+ raise ValueError("Not support video input now.")
+ else:
+ prompt, images = self.get_specific_prompt(model_family, messages) # type: ignore
+
+ if not images:
+ inputs = {
+ "prompt": prompt,
+ }
+ elif len(images) == 1:
+ inputs = {
+ "prompt": prompt,
+ "multi_modal_data": {"image": images[-1]}, # type: ignore
+ }
+ else:
+ inputs = {
+ "prompt": prompt,
+ "multi_modal_data": {"image": images}, # type: ignore
+ }
+ generate_config = self._sanitize_generate_config(generate_config)
+
+ stream = generate_config.get("stream", False)
+ if stream:
+ it = self.generate(inputs, generate_config)
+ assert isinstance(it, Iterator)
+ return self._to_chat_completion_chunks(it)
+ else:
+ c = self.generate(inputs, generate_config)
+ assert not isinstance(c, Iterator)
+ if tools:
+ return self._tool_calls_completion(self.model_family, self.model_uid, c)
+ return self._to_chat_completion(c)
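For orientation on the streaming rewrite above: MLXModel no longer drives mlx_lm's low-level generate_step itself; it builds a sampler with make_sampler and consumes stream_generate, whose response objects carry the text and token fields read as chunk_resp in the diff. A minimal standalone sketch of that newer API follows; the model repo id and generation settings are illustrative placeholders, and the import paths simply mirror the ones this diff adds (they may differ between mlx_lm releases).

# Sketch only: assumes an mlx_lm release exposing load, make_sampler and
# stream_generate under the names imported in this diff; the repo id is a placeholder.
from mlx_lm import load
from mlx_lm.utils import make_sampler, stream_generate

model, tokenizer = load("mlx-community/some-model-4bit")  # placeholder repo id
prompt_token_ids = tokenizer.encode("Explain KV caching in one sentence.")
sampler = make_sampler(temp=0.7, top_p=0.9)

# Each yielded response carries .text (the newly decoded segment) and .token,
# matching how _generate_stream reads chunk_resp.text / chunk_resp.token above.
for resp in stream_generate(
    model, tokenizer, prompt_token_ids, sampler=sampler, max_tokens=128
):
    print(resp.text, end="", flush=True)
print()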
xinference/model/llm/sglang/core.py

@@ -75,6 +75,7 @@ SGLANG_SUPPORTED_CHAT_MODELS = [
  "llama-2-chat",
  "llama-3-instruct",
  "llama-3.1-instruct",
+ "llama-3.3-instruct",
  "qwen-chat",
  "qwen1.5-chat",
  "qwen2-instruct",
@@ -89,6 +90,7 @@ SGLANG_SUPPORTED_CHAT_MODELS = [
  "deepseek-v2-chat-0628",
  "qwen2.5-instruct",
  "qwen2.5-coder-instruct",
+ "QwQ-32B-Preview",
  ]

xinference/model/llm/transformers/chatglm.py

@@ -61,7 +61,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):

  def _load_model(self, **kwargs):
  try:
- from transformers import AutoModel, AutoTokenizer
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  except ImportError:
  error_message = "Failed to import module 'transformers'"
  installation_guide = [
@@ -77,7 +77,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
  encode_special_tokens=True,
  revision=kwargs["revision"],
  )
- model = AutoModel.from_pretrained(
+ model = AutoModelForCausalLM.from_pretrained(
  self.model_path,
  **kwargs,
  )
@@ -232,9 +232,11 @@ class ChatglmPytorchChatModel(PytorchChatModel):
  content = {
  "name": function_name,
  "arguments": json.dumps(
- arguments_json
- if isinstance(arguments_json, dict)
- else arguments,
+ (
+ arguments_json
+ if isinstance(arguments_json, dict)
+ else arguments
+ ),
  ensure_ascii=False,
  ),
  }
@@ -331,6 +333,8 @@ class ChatglmPytorchChatModel(PytorchChatModel):
  max_new_tokens = generate_config.get("max_tokens")
  if max_new_tokens is not None:
  kwargs["max_new_tokens"] = int(max_new_tokens)
+ else:
+ kwargs["max_new_tokens"] = 1024
  do_sample = generate_config.get("do_sample")
  if do_sample is not None:
  kwargs["do_sample"] = bool(do_sample)
xinference/model/llm/transformers/core.py

@@ -68,6 +68,7 @@ NON_DEFAULT_MODEL_LIST: List[str] = [
  "deepseek-v2-chat",
  "deepseek-v2.5",
  "deepseek-v2-chat-0628",
+ "glm-edge-v",
  ]

xinference/model/llm/transformers/glm_edge_v.py (new file)

@@ -0,0 +1,230 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import uuid
+ from concurrent.futures import ThreadPoolExecutor
+ from threading import Thread
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+
+ import torch
+
+ from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
+ from ...utils import select_device
+ from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from ..utils import (
+ _decode_image_without_rgb,
+ generate_chat_completion,
+ generate_completion_chunk,
+ )
+ from .core import PytorchChatModel, PytorchGenerateConfig
+ from .utils import cache_clean
+
+ logger = logging.getLogger(__name__)
+
+
+ class GlmEdgeVModel(PytorchChatModel):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._device = None
+ self._tokenizer = None
+ self._model = None
+ self._processor = None
+
+ @classmethod
+ def match(
+ cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+ ) -> bool:
+ family = model_family.model_family or model_family.model_name
+ if "glm-edge-v" in family.lower():
+ return True
+ return False
+
+ def load(self):
+ from transformers import AutoImageProcessor, AutoModelForCausalLM, AutoTokenizer
+
+ device = self._pytorch_model_config.get("device", "auto")
+ self._device = select_device(device)
+
+ kwargs = {"device_map": self._device}
+ quantization = self.quantization
+
+ # referenced from PytorchModel.load
+ if quantization != "none":
+ if self._device == "cuda" and self._is_linux():
+ kwargs["device_map"] = "auto"
+ if quantization == "4-bit":
+ kwargs["load_in_4bit"] = True
+ elif quantization == "8-bit":
+ kwargs["load_in_8bit"] = True
+ else:
+ raise ValueError(
+ f"Quantization {quantization} is not supported in temporary"
+ )
+ else:
+ if quantization != "8-bit":
+ raise ValueError(
+ f"Only 8-bit quantization is supported if it is not linux system or cuda device"
+ )
+
+ processor = AutoImageProcessor.from_pretrained(
+ self.model_path, trust_remote_code=True
+ )
+ self._processor = processor
+
+ model = AutoModelForCausalLM.from_pretrained(
+ self.model_path,
+ trust_remote_code=True,
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ )
+
+ self._model = model
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ self.model_path, trust_remote_code=True
+ )
+ self._tokenizer = tokenizer
+
+ @staticmethod
+ def _get_processed_msgs(
+ messages: List[Dict],
+ ) -> Tuple[List[Dict[str, Any]], List[Any]]:
+ res = []
+ img = []
+ for message in messages:
+ role = message["role"]
+ content = message["content"]
+ if isinstance(content, str):
+ res.append({"role": role, "content": content})
+ else:
+ texts = []
+ image_urls = []
+ for c in content:
+ c_type = c.get("type")
+ if c_type == "text":
+ texts.append(c["text"])
+ else:
+ assert (
+ c_type == "image_url"
+ ), "Please follow the image input of the OpenAI API."
+ image_urls.append(c["image_url"]["url"])
+ if len(image_urls) > 1:
+ raise RuntimeError("Only one image per message is supported")
+ image_futures = []
+ with ThreadPoolExecutor() as executor:
+ for image_url in image_urls:
+ fut = executor.submit(_decode_image_without_rgb, image_url)
+ image_futures.append(fut)
+ images = [fut.result() for fut in image_futures]
+ assert len(images) <= 1
+ text = " ".join(texts)
+ img.extend(images)
+ if images:
+ res.append(
+ {
+ "role": role,
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": text},
+ ],
+ }
+ )
+ else:
+ res.append({"role": role, "content": text})
+ return res, img
+
+ @cache_clean
+ def chat(
+ self,
+ messages: List[Dict],
+ generate_config: Optional[PytorchGenerateConfig] = None,
+ ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+ from transformers import TextIteratorStreamer
+
+ if not generate_config:
+ generate_config = {}
+
+ stream = generate_config.get("stream", False)
+ msgs, imgs = self._get_processed_msgs(messages)
+
+ inputs = self._tokenizer.apply_chat_template(
+ msgs,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_tensors="pt",
+ return_dict=True,
+ ) # chat mode
+ inputs = inputs.to(self._model.device)
+
+ generate_kwargs = {
+ **inputs,
+ }
+ if len(imgs) > 0:
+ generate_kwargs["pixel_values"] = torch.tensor(
+ self._processor(imgs[-1]).pixel_values
+ ).to(self._model.device)
+ stop_str = "<|endoftext|>"
+
+ if stream:
+ streamer = TextIteratorStreamer(
+ tokenizer=self._tokenizer,
+ timeout=60,
+ skip_prompt=True,
+ skip_special_tokens=True,
+ )
+ generate_kwargs = {
+ **generate_kwargs,
+ "streamer": streamer,
+ }
+ t = Thread(target=self._model.generate, kwargs=generate_kwargs)
+ t.start()
+
+ it = self.chat_stream(streamer, stop_str)
+ return self._to_chat_completion_chunks(it)
+ else:
+ with torch.no_grad():
+ outputs = self._model.generate(**generate_kwargs)
+ outputs = outputs[0][len(inputs["input_ids"][0]) :]
+ response = self._tokenizer.decode(outputs)
+ if response.endswith(stop_str):
+ response = response[: -len(stop_str)]
+ return generate_chat_completion(self.model_uid, response)
+
+ def chat_stream(self, streamer, stop_str) -> Iterator[CompletionChunk]:
+ completion_id = str(uuid.uuid1())
+ for new_text in streamer:
+ if not new_text.endswith(stop_str):
+ yield generate_completion_chunk(
+ chunk_text=new_text,
+ finish_reason=None,
+ chunk_id=completion_id,
+ model_uid=self.model_uid,
+ prompt_tokens=-1,
+ completion_tokens=-1,
+ total_tokens=-1,
+ has_choice=True,
+ has_content=True,
+ )
+
+ yield generate_completion_chunk(
+ chunk_text=None,
+ finish_reason="stop",
+ chunk_id=completion_id,
+ model_uid=self.model_uid,
+ prompt_tokens=-1,
+ completion_tokens=-1,
+ total_tokens=-1,
+ has_choice=True,
+ has_content=False,
+ )
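
A usage note on the new GlmEdgeVModel: _get_processed_msgs takes OpenAI-style chat messages and allows at most one image_url part per message (text-only messages pass through unchanged). Below is a minimal sketch of a conforming request; the commented client call at the end is illustrative only and assumes a running xinference server with a launched glm-edge-v model, so treat the uid, port, and URL as placeholders.

# OpenAI-style multimodal message accepted by _get_processed_msgs:
# at most one image_url part per message, any number of text parts.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {
                "type": "image_url",
                # http(s) URL or base64 data URI, decoded via _decode_image_without_rgb
                "image_url": {"url": "https://example.com/cat.png"},
            },
        ],
    },
]

# Illustrative client call (uid and endpoint are placeholders):
# from xinference.client import Client
# model = Client("http://127.0.0.1:9997").get_model("my-glm-edge-v-uid")
# print(model.chat(messages=messages, generate_config={"stream": False}))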