xinference 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (85)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +108 -14
  3. xinference/client/restful/restful_client.py +78 -5
  4. xinference/constants.py +1 -0
  5. xinference/core/cache_tracker.py +48 -28
  6. xinference/core/event.py +5 -6
  7. xinference/core/model.py +59 -42
  8. xinference/core/scheduler.py +46 -18
  9. xinference/core/supervisor.py +73 -24
  10. xinference/core/worker.py +68 -2
  11. xinference/deploy/cmdline.py +86 -2
  12. xinference/deploy/test/test_cmdline.py +19 -10
  13. xinference/model/audio/__init__.py +14 -1
  14. xinference/model/audio/core.py +12 -1
  15. xinference/model/audio/custom.py +6 -4
  16. xinference/model/audio/model_spec_modelscope.json +20 -0
  17. xinference/model/llm/__init__.py +34 -2
  18. xinference/model/llm/llm_family.json +8 -2
  19. xinference/model/llm/llm_family.py +86 -1
  20. xinference/model/llm/llm_family_csghub.json +66 -0
  21. xinference/model/llm/llm_family_modelscope.json +8 -2
  22. xinference/model/llm/pytorch/chatglm.py +41 -12
  23. xinference/model/llm/pytorch/core.py +128 -88
  24. xinference/model/llm/pytorch/glm4v.py +24 -3
  25. xinference/model/llm/pytorch/internlm2.py +15 -0
  26. xinference/model/llm/pytorch/qwen_vl.py +1 -1
  27. xinference/model/llm/pytorch/utils.py +69 -189
  28. xinference/model/llm/utils.py +27 -14
  29. xinference/model/llm/vllm/core.py +10 -4
  30. xinference/model/rerank/core.py +35 -6
  31. xinference/model/utils.py +8 -2
  32. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  33. xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
  34. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  35. xinference/thirdparty/ChatTTS/infer/api.py +125 -0
  36. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  37. xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
  38. xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
  39. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  40. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
  41. xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
  42. xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
  43. xinference/types.py +28 -0
  44. xinference/web/ui/build/asset-manifest.json +6 -6
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/css/main.4bafd904.css +2 -0
  47. xinference/web/ui/build/static/css/main.4bafd904.css.map +1 -0
  48. xinference/web/ui/build/static/js/main.b80d9c08.js +3 -0
  49. xinference/web/ui/build/static/js/main.b80d9c08.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/0c2fb5375667931c4a331c99e0d87dc145e8f327cea3f44d6e56f54c7c1d4020.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/16537795de12c61903b6110c241f62a7855b2d0fc1e7c3d1faa347267f3a6893.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/17b8f071491402d70b146532358b1a612226e5dc7b3e8755a1322d27b4680cee.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/395409bd005e19d48b437c48d88e5126c7865ba9631fe98535333c952e383dc5.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/43991bb67c3136863e6fb37f796466b12eb547a1465408cc77820fddafb3bed3.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/{15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json → 935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json} +1 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/a7109d4425e3d94ca2726fc7020fd33bf5030afd4c9cf4bf71e21776cd70646a.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
  63. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/METADATA +1 -1
  64. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/RECORD +69 -56
  65. xinference/web/ui/build/static/css/main.54bca460.css +0 -2
  66. xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
  67. xinference/web/ui/build/static/js/main.551aa479.js +0 -3
  68. xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
  71. xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +0 -1
  72. xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +0 -1
  73. xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
  74. xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +0 -1
  75. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
  76. xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
  78. xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +0 -1
  79. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +0 -1
  81. /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.b80d9c08.js.LICENSE.txt} +0 -0
  82. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/LICENSE +0 -0
  83. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/WHEEL +0 -0
  84. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/entry_points.txt +0 -0
  85. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/utils.py CHANGED
@@ -17,11 +17,9 @@ import logging
  import os
  import time
  import uuid
- from threading import Thread
  from typing import Dict, Iterable, Iterator, List, Optional, Tuple

  import torch
- from transformers import GenerationConfig, TextIteratorStreamer
  from transformers.cache_utils import DynamicCache
  from transformers.generation.logits_process import (
  LogitsProcessorList,
@@ -126,6 +124,7 @@ def generate_stream(
  stop_str = generate_config.get("stop", None)
  stop_token_ids = generate_config.get("stop_token_ids", None) or []
  stop_token_ids.append(tokenizer.eos_token_id)
+ chunk_id = str(uuid.uuid4())

  logits_processor = prepare_logits_processor(
  temperature, repetition_penalty, top_p, top_k
@@ -289,7 +288,7 @@ def generate_stream(
  text=output, index=0, logprobs=None, finish_reason=None
  )
  completion_chunk = CompletionChunk(
- id=str(uuid.uuid1()),
+ id=chunk_id,
  object="text_completion",
  created=int(time.time()),
  model=model_uid,
@@ -327,7 +326,7 @@ def generate_stream(
  )

  completion_chunk = CompletionChunk(
- id=str(uuid.uuid1()),
+ id=chunk_id,
  object="text_completion",
  created=int(time.time()),
  model=model_uid,
@@ -343,7 +342,7 @@ def generate_stream(

  if include_usage:
  completion_chunk = CompletionChunk(
- id=str(uuid.uuid1()),
+ id=chunk_id,
  object="text_completion",
  created=int(time.time()),
  model=model_uid,
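
The three hunks above switch streaming from a fresh uuid.uuid1() per chunk to a single chunk_id generated once per request, so every chunk of one streamed completion carries the same id, as OpenAI-compatible clients expect. A minimal sketch of the resulting pattern (standalone names, not the actual xinference code):

    import time
    import uuid

    def stream_chunks(pieces, model_uid):
        chunk_id = str(uuid.uuid4())  # one id for the whole stream
        for text in pieces:
            yield {
                "id": chunk_id,  # identical on every chunk of this completion
                "object": "text_completion",
                "created": int(time.time()),
                "model": model_uid,
                "choices": [{"index": 0, "text": text, "finish_reason": None}],
            }

    # every emitted chunk carries the same "id":
    # [c["id"] for c in stream_chunks(["Hel", "lo"], "my-model")]
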
@@ -362,178 +361,6 @@ def generate_stream(
  empty_cache()


- @torch.inference_mode()
- def generate_stream_falcon(
- model_uid,
- model,
- tokenizer,
- prompt,
- device,
- generate_config,
- judge_sent_end=False,
- ) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]:
- context_len = get_context_length(model.config)
- stream_interval = generate_config.get("stream_interval", 2)
- stream = generate_config.get("stream", False)
- stream_options = generate_config.pop("stream_options", None)
- include_usage = (
- stream_options["include_usage"] if isinstance(stream_options, dict) else False
- )
- len_prompt = len(prompt)
-
- temperature = float(generate_config.get("temperature", 1.0))
- repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
- top_p = float(generate_config.get("top_p", 1.0))
- top_k = int(generate_config.get("top_k", 50)) # -1 means disable
- max_new_tokens = int(generate_config.get("max_tokens", max_tokens_field.default))
- echo = bool(generate_config.get("echo", False))
- stop_str = generate_config.get("stop", None)
- stop_token_ids = generate_config.get("stop_token_ids", None) or []
- stop_token_ids.append(tokenizer.eos_token_id)
-
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
- input_ids = inputs["input_ids"]
- attention_mask = inputs["attention_mask"]
-
- max_src_len = context_len - max_new_tokens - 8
-
- input_ids = input_ids[-max_src_len:] # truncate from the left
- attention_mask = attention_mask[-max_src_len:] # truncate from the left
- input_echo_len = len(input_ids)
-
- decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
-
- generation_config = GenerationConfig(
- max_new_tokens=max_new_tokens,
- do_sample=temperature >= 1e-5,
- temperature=temperature,
- repetition_penalty=repetition_penalty,
- no_repeat_ngram_size=10,
- top_p=top_p,
- top_k=top_k,
- eos_token_id=stop_token_ids,
- )
-
- generation_kwargs = dict(
- inputs=input_ids,
- attention_mask=attention_mask,
- streamer=streamer,
- generation_config=generation_config,
- )
-
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
- thread.start()
-
- if echo:
- # means keep the prompt
- output = prompt
- else:
- output = ""
-
- last_output_length = 0
- for i, new_text in enumerate(streamer):
- output += new_text
- if i % stream_interval == 0:
- if echo:
- rfind_start = len_prompt
- else:
- rfind_start = 0
-
- partially_stopped = False
- if stop_str:
- if isinstance(stop_str, str):
- pos = output.rfind(stop_str, rfind_start)
- if pos != -1:
- output = output[:pos]
- else:
- partially_stopped = is_partial_stop(output, stop_str)
- elif isinstance(stop_str, Iterable):
- for each_stop in stop_str:
- pos = output.rfind(each_stop, rfind_start)
- if pos != -1:
- output = output[:pos]
- break
- else:
- partially_stopped = is_partial_stop(output, each_stop)
- if partially_stopped:
- break
- else:
- raise ValueError("Invalid stop field type.")
-
- if stream:
- output = output.strip("�")
- tmp_output_length = len(output)
- output = output[last_output_length:]
- last_output_length = tmp_output_length
-
- # prevent yielding partial stop sequence
- if not partially_stopped:
- completion_choice = CompletionChoice(
- text=output, index=0, logprobs=None, finish_reason=None
- )
- completion_chunk = CompletionChunk(
- id=str(uuid.uuid1()),
- object="text_completion",
- created=int(time.time()),
- model=model_uid,
- choices=[completion_choice],
- )
- completion_usage = CompletionUsage(
- prompt_tokens=input_echo_len,
- completion_tokens=i,
- total_tokens=(input_echo_len + i),
- )
-
- yield completion_chunk, completion_usage
- output = output.strip()
-
- # finish stream event, which contains finish reason
- if i == max_new_tokens - 1:
- finish_reason = "length"
- elif partially_stopped:
- finish_reason = None
- else:
- finish_reason = "stop"
-
- completion_choice = CompletionChoice(
- text=output, index=0, logprobs=None, finish_reason=finish_reason
- )
- completion_chunk = CompletionChunk(
- id=str(uuid.uuid1()),
- object="text_completion",
- created=int(time.time()),
- model=model_uid,
- choices=[completion_choice],
- )
- completion_usage = CompletionUsage(
- prompt_tokens=input_echo_len,
- completion_tokens=i,
- total_tokens=(input_echo_len + i),
- )
-
- yield completion_chunk, completion_usage
-
- if include_usage:
- completion_chunk = CompletionChunk(
- id=str(uuid.uuid1()),
- object="text_completion",
- created=int(time.time()),
- model=model_uid,
- choices=[],
- )
- completion_usage = CompletionUsage(
- prompt_tokens=input_echo_len,
- completion_tokens=i,
- total_tokens=(input_echo_len + i),
- )
- yield completion_chunk, completion_usage
-
- # clean
- gc.collect()
- empty_cache()
-
-
  def _get_token_from_logits(
  req: InferenceRequest, i: int, logits, temperature, repetition_penalty, top_p, top_k
  ):
@@ -568,12 +395,15 @@ def _pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
  return [pad] * (max_len - len(x)) + x


- def _pad_seqs_inplace(seqs: List[List[int]], pad: int):
+ def _pad_seqs_inplace(seqs: List[List[int]], reqs: List[InferenceRequest], pad: int):
  max_len = max(len(seq) for seq in seqs)
  n = len(seqs)
  i = 0
  while i < n:
+ prev_seq_len = len(seqs[i])
  seqs[i] = _pad_to_max_length(seqs[i], max_len, pad)
+ padding_len = len(seqs[i]) - prev_seq_len
+ reqs[i].padding_len = padding_len
  i += 1


@@ -586,6 +416,7 @@ def get_max_src_len(context_len: int, r: InferenceRequest) -> int:

  def _get_completion_chunk(
  output: str,
+ chunk_id: str,
  finish_reason: Optional[str],
  model_uid: str,
  r: InferenceRequest,
@@ -601,7 +432,7 @@ def _get_completion_chunk(
  else []
  )
  completion_chunk = CompletionChunk(
- id=str(uuid.uuid1()),
+ id=chunk_id,
  object="text_completion",
  created=int(time.time()),
  model=model_uid,
@@ -617,14 +448,18 @@ def _get_completion_chunk(


  def _get_completion(
- output: str, finish_reason: Optional[str], model_uid: str, r: InferenceRequest
+ output: str,
+ chunk_id: str,
+ finish_reason: Optional[str],
+ model_uid: str,
+ r: InferenceRequest,
  ):
  completion_choice = CompletionChoice(
  text=output, index=0, logprobs=None, finish_reason=finish_reason
  )

  completion_chunk = CompletionChunk(
- id=str(uuid.uuid1()),
+ id=chunk_id,
  object="text_completion",
  created=int(time.time()),
  model=model_uid,
@@ -674,6 +509,25 @@ def _merge_kv_cache(
  return ret_kv.to_legacy_cache()


+ def _get_attention_mask_and_position_ids(kv, reqs: List[InferenceRequest]):
+ batch_size, seq_length, device = (
+ kv[0][0].shape[0],
+ kv[0][0].shape[2],
+ kv[0][0].device,
+ )
+ seq_length = seq_length + 1
+ position_ids = torch.as_tensor([[seq_length - 1]], dtype=torch.long, device=device)
+ attention_mask = torch.ones(
+ (batch_size, seq_length), dtype=torch.long, device=device
+ )
+ padding_lens = torch.as_tensor([r.padding_len for r in reqs])
+ mask = torch.arange(seq_length).expand(
+ batch_size, seq_length
+ ) < padding_lens.unsqueeze(1)
+ attention_mask[mask] = 0
+ return attention_mask, position_ids
+
+
  @torch.inference_mode()
  def _batch_inference_one_step_internal(
  req_list: List[InferenceRequest],
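
Together, the padding_len bookkeeping added to _pad_seqs_inplace and the new _get_attention_mask_and_position_ids helper let the batched decode step mask out the left padding that was prepended during prefill. A rough, self-contained sketch of the same masking idea (hypothetical names; the real helper derives the cached length from the KV cache shape):

    import torch

    def decode_step_mask(padding_lens, cached_len):
        # One new token is being decoded, so the effective length is cache + 1.
        seq_len = cached_len + 1
        batch = len(padding_lens)
        attention_mask = torch.ones((batch, seq_len), dtype=torch.long)
        pads = torch.as_tensor(padding_lens)
        # Positions that fall inside a request's left padding get masked out.
        attention_mask[torch.arange(seq_len).expand(batch, seq_len) < pads.unsqueeze(1)] = 0
        position_ids = torch.as_tensor([[seq_len - 1]], dtype=torch.long)
        return attention_mask, position_ids

    # Two requests, the first left-padded by 2 tokens, with 4 cached positions:
    # decode_step_mask([2, 0], cached_len=4)[0]
    # -> tensor([[0, 0, 1, 1, 1],
    #            [1, 1, 1, 1, 1]])
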
@@ -682,7 +536,9 @@ def _batch_inference_one_step_internal(
  tokenizer,
  device,
  context_len: int,
+ stop_tokens: Tuple[int],
  decode_round: int = 16,
+ require_attention_mask: bool = False,
  bos_flag: str = "<bos_stream>",
  eos_flag: str = "<eos_stream>",
  ):
@@ -692,7 +548,8 @@ def _batch_inference_one_step_internal(
  if not valid_req_list:
  return
  generate_config_mapping: Dict[InferenceRequest, Tuple] = {
- r: r.get_generate_configs(tokenizer.eos_token_id) for r in valid_req_list
+ r: r.get_generate_configs(tokenizer.eos_token_id, stop_tokens)
+ for r in valid_req_list
  }
  s_time = time.time()

@@ -701,7 +558,7 @@ def _batch_inference_one_step_internal(
  decode_reqs = []
  for r in valid_req_list:
  if r.is_prefill:
- prompts.append(r.full_prompt)
+ prompts.append(r.full_prompt if r.full_prompt is not None else r.prompt)
  prefill_reqs.append(r)
  else:
  decode_reqs.append(r)
@@ -714,7 +571,7 @@ def _batch_inference_one_step_internal(
  max_src_len = get_max_src_len(context_len, req)
  req.prompt_tokens = input_id[-max_src_len:]
  prompt_tokens.append(req.prompt_tokens)
- _pad_seqs_inplace(prompt_tokens, 0)
+ _pad_seqs_inplace(prompt_tokens, valid_req_list, 0)
  out = model(torch.as_tensor(prompt_tokens, device=device), use_cache=True)

  logits = out.logits
@@ -756,10 +613,18 @@ def _batch_inference_one_step_internal(
  # here, only decode phase, just run some rounds
  for _i in range(decode_round):
  decode_tokens: List[List[int]] = [[r.new_tokens[-1]] for r in valid_req_list]
+ inf_kws = {}
+ if require_attention_mask:
+ attention_mask, position_ids = _get_attention_mask_and_position_ids(
+ past_key_values, valid_req_list
+ )
+ inf_kws["position_ids"] = position_ids
+ inf_kws["attention_mask"] = attention_mask
  out = model(
  input_ids=torch.as_tensor(decode_tokens, device=device),
  use_cache=True,
  past_key_values=past_key_values,
+ **inf_kws,
  )
  logits = out.logits
  past_key_values = out.past_key_values
@@ -846,7 +711,7 @@ def _batch_inference_one_step_internal(
  r.last_output_length += len(output)

  completion_chunk = _get_completion_chunk(
- output, r.finish_reason, model_uid, r, False
+ output, r.chunk_id, r.finish_reason, model_uid, r, False
  )
  r.completion.append(completion_chunk)
  if r.stopped:
@@ -859,7 +724,7 @@ def _batch_inference_one_step_internal(
  if r.stopped and _i == decode_round - 1 and include_usage:
  r.completion.append(
  _get_completion_chunk(
- "", r.finish_reason, model_uid, r, True
+ "", r.chunk_id, r.finish_reason, model_uid, r, True
  )
  )
  else:
@@ -878,7 +743,9 @@ def _batch_inference_one_step_internal(
  if r not in output_mapping
  else output_mapping[r]
  )
- completion = _get_completion(outputs, r.finish_reason, model_uid, r)
+ completion = _get_completion(
+ outputs, r.chunk_id, r.finish_reason, model_uid, r
+ )
  r.completion = [completion]

  e_time = time.time()
@@ -894,12 +761,21 @@ def batch_inference_one_step(
  tokenizer,
  device,
  context_len: int,
+ stop_token_ids: Tuple[int],
+ require_attention_mask: bool = False,
  ):
  from ....core.model import OutOfMemoryError

  try:
  _batch_inference_one_step_internal(
- req_list, model_uid, model, tokenizer, device, context_len
+ req_list,
+ model_uid,
+ model,
+ tokenizer,
+ device,
+ context_len,
+ stop_token_ids,
+ require_attention_mask=require_attention_mask,
  )
  except OutOfMemoryError:
  logger.exception(
@@ -911,4 +787,8 @@ def batch_inference_one_step(
  os._exit(1)
  except Exception as e:
  logger.exception(f"Internal error for batch inference: {e}.")
- # TODO: handle this
+ # If internal error happens, just skip all the requests in this batch.
+ # If not handle here, the client will hang.
+ for r in req_list:
+ r.stopped = True
+ r.error_msg = str(e)
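
The new error path marks every request in the failed batch as stopped and records the exception text, so a caller polling those requests gets an answer instead of waiting forever. A hypothetical consumer-side sketch of what that enables (not the actual scheduler code):

    import time

    def wait_for(req, poll_interval=0.01):
        # Block until the batch loop marks the request finished or failed.
        while not req.stopped:
            time.sleep(poll_interval)
        if getattr(req, "error_msg", None):
            raise RuntimeError(f"batch inference failed: {req.error_msg}")
        return req.completion
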
xinference/model/llm/utils.py CHANGED
@@ -607,7 +607,7 @@ Begin!"""
  return arguments, None, None

  @staticmethod
- def _eval_chatglm3_arguments(c, tools):
+ def _eval_glm_chat_arguments(c, tools):
  if isinstance(c[0], str):
  return c[0], None, None
  return None, c[0]["name"], c[0]["parameters"]
@@ -659,9 +659,15 @@ Begin!"""
  family = model_family.model_family or model_family.model_name
  if family in ["gorilla-openfunctions-v1", "gorilla-openfunctions-v2"]:
  content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
- elif "chatglm3" == family:
- content, func, args = cls._eval_chatglm3_arguments(c, tools)
- elif family in ["qwen-chat", "qwen1.5-chat"]:
+ elif family in ["chatglm3", "glm4-chat"]:
+ content, func, args = cls._eval_glm_chat_arguments(c, tools)
+ elif family in [
+ "qwen-chat",
+ "qwen1.5-chat",
+ "qwen1.5-moe-chat",
+ "qwen2-instruct",
+ "qwen2-moe-instruct",
+ ]:
  content, func, args = cls._eval_qwen_chat_arguments(c, tools)
  else:
  raise Exception(
@@ -676,28 +682,35 @@ Begin!"""
  Generates a filter function for Qwen series models to retain outputs after "\nFinal Answer:".

  Returns:
- A function that takes tokens (string output by the model so far) as input
- returns True if current token is after "\nFinal Answer:", else False.
+ A function that takes tokens (string output by the model so far) and delta (new tokens added) as input,
+ returns the part after "\nFinal Answer:" if found, else returns delta.
  """
  family = model_family.model_family or model_family.model_name
- if family in ["qwen-chat", "qwen1.5-chat"]:
+ if family in [
+ "qwen-chat",
+ "qwen1.5-chat",
+ "qwen1.5-moe-chat",
+ "qwen2-instruct",
+ "qwen2-moe-instruct",
+ ]:
  # Encapsulating function to reset 'found' after each call
  found = False

- def process_token(tokens: str):
+ def process_tokens(tokens: str, delta: str):
  nonlocal found
  # Once "Final Answer:" is found, future tokens are allowed.
  if found:
- return True
+ return delta
  # Check if the token ends with "\nFinal Answer:" and update `found`.
- if tokens.endswith("\nFinal Answer:"):
+ final_answer_idx = tokens.lower().rfind("\nfinal answer:")
+ if final_answer_idx != -1:
  found = True
- return False
+ return tokens[final_answer_idx + len("\nfinal answer:") :]
+ return ""

- return process_token
+ return process_tokens
  else:
- # For other families, allow all tokens.
- return lambda tokens: True
+ return lambda tokens, delta: delta

  @classmethod
  def _tool_calls_completion(cls, model_family, model_uid, c, tools):
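
The filter returned above changes from a per-call boolean gate over the accumulated text to a function of (tokens, delta) that returns the printable part of the new text, emitting everything after "\nFinal Answer:" for Qwen-style ReAct output and suppressing the thought process before it. An illustrative usage sketch (the factory is paraphrased here rather than imported):

    def make_tools_token_filter():
        found = False

        def process_tokens(tokens: str, delta: str) -> str:
            nonlocal found
            if found:
                return delta  # already past "Final Answer:", pass text through
            idx = tokens.lower().rfind("\nfinal answer:")
            if idx != -1:
                found = True
                return tokens[idx + len("\nfinal answer:"):]
            return ""  # still inside the ReAct thought process

        return process_tokens

    f = make_tools_token_filter()
    f("Thought: I should call a tool", " a tool")            # -> "" (suppressed)
    f("Thought: ...\nFinal Answer: 42", " 42")               # -> " 42"
    f("Thought: ...\nFinal Answer: 42 exactly", " exactly")  # -> " exactly"
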
xinference/model/llm/vllm/core.py CHANGED
@@ -444,7 +444,9 @@ class VLLMModel(LLM):
  _content, func, args = ChatModelMixin._eval_tool_arguments(
  self.model_family, chunk, tools
  )
- choice["text"] = choice_delta
+ choice["text"] = tools_token_filter(
+ tokens=previous_texts[0], delta=choice_delta
+ )
  if func is not None:
  choice["text"] = None
  choice["finish_reason"] = "tool_calls"
@@ -458,9 +460,13 @@ class VLLMModel(LLM):
  ),
  )
  ]
- # use a filter function to skip Qwen's react thought process
- elif not tools_token_filter(previous_texts[0]):
- continue
+ else:
+ # use a filter function to skip Qwen's react thought process
+ choice["text"] = tools_token_filter(
+ tokens=previous_texts[0], delta=choice["text"]
+ )
+ if not choice["text"]:
+ continue
  prompt_tokens = len(_request_output.prompt_token_ids)
  completion_tokens = sum(
  len(output.token_ids) for output in _request_output.outputs
xinference/model/rerank/core.py CHANGED
@@ -23,7 +23,7 @@ import numpy as np

  from ...constants import XINFERENCE_CACHE_DIR
  from ...device_utils import empty_cache
- from ...types import Document, DocumentObj, Rerank
+ from ...types import Document, DocumentObj, Rerank, RerankTokens
  from ..core import CacheableModelSpec, ModelDescription
  from ..utils import is_model_cached

@@ -121,11 +121,17 @@ class RerankModel:
  if model_spec.type == "unknown":
  model_spec.type = self._auto_detect_type(model_path)

+ @staticmethod
+ def _get_tokenizer(model_path):
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ return tokenizer
+
  @staticmethod
  def _auto_detect_type(model_path):
  """This method may not be stable due to the fact that the tokenizer name may be changed.
  Therefore, we only use this method for unknown model types."""
- from transformers import AutoTokenizer

  type_mapper = {
  "LlamaTokenizerFast": "LLM-based layerwise",
@@ -133,12 +139,13 @@ class RerankModel:
  "XLMRobertaTokenizerFast": "normal",
  }

- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ tokenizer = RerankModel._get_tokenizer(model_path)
  rerank_type = type_mapper.get(type(tokenizer).__name__)
  if rerank_type is None:
- raise Exception(
- f"Can't determine the rerank type based on the tokenizer {tokenizer}"
+ logger.warning(
+ f"Can't determine the rerank type based on the tokenizer {tokenizer}, use normal type by default."
  )
+ return "normal"
  return rerank_type

  def load(self):
@@ -185,6 +192,7 @@ class RerankModel:
  top_n: Optional[int],
  max_chunks_per_doc: Optional[int],
  return_documents: Optional[bool],
+ return_len: Optional[bool],
  **kwargs,
  ) -> Rerank:
  self._counter += 1
@@ -223,7 +231,28 @@ class RerankModel:
  )
  for arg in sim_scores_argsort
  ]
- return Rerank(id=str(uuid.uuid1()), results=docs)
+ if return_len:
+ tokenizer = self._get_tokenizer(self._model_path)
+ input_len = sum([len(tokenizer.tokenize(t)) for t in documents])
+
+ # Rerank Model output is just score or documents
+ # while return_documents = True
+ output_len = input_len
+
+ # api_version, billed_units, warnings
+ # is for Cohere API compatibility, set to None
+ metadata = {
+ "api_version": None,
+ "billed_units": None,
+ "tokens": (
+ RerankTokens(input_tokens=input_len, output_tokens=output_len)
+ if return_len
+ else None
+ ),
+ "warnings": None,
+ }
+
+ return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)


  def get_cache_dir(model_spec: RerankModelSpec):
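
When return_len is requested, the rerank call now reports Cohere-style metadata whose tokens field counts the tokenized input documents (the output count mirrors the input, since a reranker only returns scores or echoed documents). A rough sketch of that token accounting, assuming a Hugging Face tokenizer; the model path below is only an example:

    from transformers import AutoTokenizer

    def count_rerank_tokens(model_path, documents):
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Sum the token count of every candidate document.
        input_tokens = sum(len(tokenizer.tokenize(d)) for d in documents)
        return {"input_tokens": input_tokens, "output_tokens": input_tokens}

    # count_rerank_tokens("BAAI/bge-reranker-base", ["first doc", "second doc"])
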
xinference/model/utils.py CHANGED
@@ -42,14 +42,20 @@ def is_locale_chinese_simplified() -> bool:


  def download_from_modelscope() -> bool:
- if os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "modelscope":
- return True
+ if os.environ.get(XINFERENCE_ENV_MODEL_SRC):
+ return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "modelscope"
  elif is_locale_chinese_simplified():
  return True
  else:
  return False


+ def download_from_csghub() -> bool:
+ if os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "csghub":
+ return True
+ return False
+
+
  def symlink_local_file(path: str, local_dir: str, relpath: str) -> str:
  from huggingface_hub.file_download import _create_symlink

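
download_from_modelscope now treats an explicitly set XINFERENCE_ENV_MODEL_SRC value as authoritative (so pointing it at another source disables the locale-based ModelScope fallback), and download_from_csghub adds CSGHub as a selectable source. A minimal sketch of the resulting selection order, assuming the environment variable behind the constant is named XINFERENCE_MODEL_SRC and with a simplified locale check:

    import locale
    import os

    MODEL_SRC_ENV = "XINFERENCE_MODEL_SRC"  # assumed name behind XINFERENCE_ENV_MODEL_SRC

    def pick_model_source() -> str:
        # An explicit setting always wins over the locale heuristic.
        src = os.environ.get(MODEL_SRC_ENV)
        if src:
            return src  # e.g. "modelscope", "csghub", or "huggingface"
        lang, _ = locale.getdefaultlocale()
        return "modelscope" if lang == "zh_CN" else "huggingface"
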
File without changes
xinference/thirdparty/ChatTTS/experimental/llm.py ADDED
@@ -0,0 +1,40 @@
+
+ from openai import OpenAI
+
+ prompt_dict = {
+ 'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"},
+ {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
+ {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
+ 'deepseek': [
+ {"role": "system", "content": "You are a helpful assistant"},
+ {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
+ {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
+ 'deepseek_TN': [
+ {"role": "system", "content": "You are a helpful assistant"},
+ {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"},
+ {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"},
+ {"role": "user", "content": "We paid $123 for this desk."},
+ {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."},
+ {"role": "user", "content": "详询请拨打010-724654"},
+ {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"},
+ {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"},
+ {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"},
+ ],
+ }
+
+ class llm_api:
+ def __init__(self, api_key, base_url, model):
+ self.client = OpenAI(
+ api_key = api_key,
+ base_url = base_url,
+ )
+ self.model = model
+ def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs):
+
+ completion = self.client.chat.completions.create(
+ model = self.model,
+ messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},],
+ temperature = temperature,
+ **kwargs
+ )
+ return completion.choices[0].message.content
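
The new experimental helper wraps any OpenAI-compatible chat endpoint to pre-process text for ChatTTS (the bundled prompts ask the model to spell out digits and keep punctuation to commas and periods). A hypothetical usage sketch; the key, endpoint and model name are placeholders:

    # from xinference.thirdparty.ChatTTS.experimental.llm import llm_api

    client = llm_api(
        api_key="sk-...",                        # placeholder credential
        base_url="https://api.example.com/v1",   # any OpenAI-compatible endpoint
        model="some-chat-model",                 # placeholder model name
    )
    normalized = client.call("罗森宣布将于7月24日退市", prompt_version="deepseek_TN")
    # -> digits spelled out, punctuation limited to commas and periods, ready for TTS
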
File without changes