xinference 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +108 -14
- xinference/client/restful/restful_client.py +78 -5
- xinference/constants.py +1 -0
- xinference/core/cache_tracker.py +48 -28
- xinference/core/event.py +5 -6
- xinference/core/model.py +59 -42
- xinference/core/scheduler.py +46 -18
- xinference/core/supervisor.py +73 -24
- xinference/core/worker.py +68 -2
- xinference/deploy/cmdline.py +86 -2
- xinference/deploy/test/test_cmdline.py +19 -10
- xinference/model/audio/__init__.py +14 -1
- xinference/model/audio/core.py +12 -1
- xinference/model/audio/custom.py +6 -4
- xinference/model/audio/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +34 -2
- xinference/model/llm/llm_family.json +8 -2
- xinference/model/llm/llm_family.py +86 -1
- xinference/model/llm/llm_family_csghub.json +66 -0
- xinference/model/llm/llm_family_modelscope.json +8 -2
- xinference/model/llm/pytorch/chatglm.py +41 -12
- xinference/model/llm/pytorch/core.py +128 -88
- xinference/model/llm/pytorch/glm4v.py +24 -3
- xinference/model/llm/pytorch/internlm2.py +15 -0
- xinference/model/llm/pytorch/qwen_vl.py +1 -1
- xinference/model/llm/pytorch/utils.py +69 -189
- xinference/model/llm/utils.py +27 -14
- xinference/model/llm/vllm/core.py +10 -4
- xinference/model/rerank/core.py +35 -6
- xinference/model/utils.py +8 -2
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +125 -0
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
- xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
- xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
- xinference/types.py +28 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.4bafd904.css +2 -0
- xinference/web/ui/build/static/css/main.4bafd904.css.map +1 -0
- xinference/web/ui/build/static/js/main.b80d9c08.js +3 -0
- xinference/web/ui/build/static/js/main.b80d9c08.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0c2fb5375667931c4a331c99e0d87dc145e8f327cea3f44d6e56f54c7c1d4020.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/16537795de12c61903b6110c241f62a7855b2d0fc1e7c3d1faa347267f3a6893.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/17b8f071491402d70b146532358b1a612226e5dc7b3e8755a1322d27b4680cee.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/395409bd005e19d48b437c48d88e5126c7865ba9631fe98535333c952e383dc5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/43991bb67c3136863e6fb37f796466b12eb547a1465408cc77820fddafb3bed3.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/{15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json → 935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json} +1 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a7109d4425e3d94ca2726fc7020fd33bf5030afd4c9cf4bf71e21776cd70646a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/METADATA +1 -1
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/RECORD +69 -56
- xinference/web/ui/build/static/css/main.54bca460.css +0 -2
- xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
- xinference/web/ui/build/static/js/main.551aa479.js +0 -3
- xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +0 -1
- /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.b80d9c08.js.LICENSE.txt} +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/LICENSE +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/WHEEL +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/entry_points.txt +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/utils.py
CHANGED
@@ -17,11 +17,9 @@ import logging
 import os
 import time
 import uuid
-from threading import Thread
 from typing import Dict, Iterable, Iterator, List, Optional, Tuple
 
 import torch
-from transformers import GenerationConfig, TextIteratorStreamer
 from transformers.cache_utils import DynamicCache
 from transformers.generation.logits_process import (
     LogitsProcessorList,
@@ -126,6 +124,7 @@ def generate_stream(
     stop_str = generate_config.get("stop", None)
     stop_token_ids = generate_config.get("stop_token_ids", None) or []
     stop_token_ids.append(tokenizer.eos_token_id)
+    chunk_id = str(uuid.uuid4())
 
     logits_processor = prepare_logits_processor(
         temperature, repetition_penalty, top_p, top_k
@@ -289,7 +288,7 @@ def generate_stream(
                     text=output, index=0, logprobs=None, finish_reason=None
                 )
                 completion_chunk = CompletionChunk(
-                    id=
+                    id=chunk_id,
                     object="text_completion",
                     created=int(time.time()),
                     model=model_uid,
@@ -327,7 +326,7 @@ def generate_stream(
     )
 
     completion_chunk = CompletionChunk(
-        id=
+        id=chunk_id,
         object="text_completion",
         created=int(time.time()),
         model=model_uid,
@@ -343,7 +342,7 @@ def generate_stream(
 
     if include_usage:
         completion_chunk = CompletionChunk(
-            id=
+            id=chunk_id,
             object="text_completion",
             created=int(time.time()),
             model=model_uid,
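
The hunks above also change how streamed chunks are identified: instead of minting a fresh id for every chunk, generate_stream now creates one chunk_id per request and stamps it on every CompletionChunk it yields, so clients can group all chunks of one stream by id. A minimal sketch of the pattern, assuming a plain dict in place of xinference's CompletionChunk type:

import time
import uuid

def stream_chunks(model_uid, pieces):
    # One id for the whole stream; every chunk of this generation shares it.
    chunk_id = str(uuid.uuid4())
    for text in pieces:
        yield {
            "id": chunk_id,
            "object": "text_completion",
            "created": int(time.time()),
            "model": model_uid,
            "choices": [{"text": text, "index": 0, "finish_reason": None}],
        }

for chunk in stream_chunks("my-model", ["Hel", "lo"]):
    print(chunk["id"], chunk["choices"][0]["text"])
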
@@ -362,178 +361,6 @@ def generate_stream(
     empty_cache()
 
 
-@torch.inference_mode()
-def generate_stream_falcon(
-    model_uid,
-    model,
-    tokenizer,
-    prompt,
-    device,
-    generate_config,
-    judge_sent_end=False,
-) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]:
-    context_len = get_context_length(model.config)
-    stream_interval = generate_config.get("stream_interval", 2)
-    stream = generate_config.get("stream", False)
-    stream_options = generate_config.pop("stream_options", None)
-    include_usage = (
-        stream_options["include_usage"] if isinstance(stream_options, dict) else False
-    )
-    len_prompt = len(prompt)
-
-    temperature = float(generate_config.get("temperature", 1.0))
-    repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
-    top_p = float(generate_config.get("top_p", 1.0))
-    top_k = int(generate_config.get("top_k", 50))  # -1 means disable
-    max_new_tokens = int(generate_config.get("max_tokens", max_tokens_field.default))
-    echo = bool(generate_config.get("echo", False))
-    stop_str = generate_config.get("stop", None)
-    stop_token_ids = generate_config.get("stop_token_ids", None) or []
-    stop_token_ids.append(tokenizer.eos_token_id)
-
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    input_ids = inputs["input_ids"]
-    attention_mask = inputs["attention_mask"]
-
-    max_src_len = context_len - max_new_tokens - 8
-
-    input_ids = input_ids[-max_src_len:]  # truncate from the left
-    attention_mask = attention_mask[-max_src_len:]  # truncate from the left
-    input_echo_len = len(input_ids)
-
-    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
-
-    generation_config = GenerationConfig(
-        max_new_tokens=max_new_tokens,
-        do_sample=temperature >= 1e-5,
-        temperature=temperature,
-        repetition_penalty=repetition_penalty,
-        no_repeat_ngram_size=10,
-        top_p=top_p,
-        top_k=top_k,
-        eos_token_id=stop_token_ids,
-    )
-
-    generation_kwargs = dict(
-        inputs=input_ids,
-        attention_mask=attention_mask,
-        streamer=streamer,
-        generation_config=generation_config,
-    )
-
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    if echo:
-        # means keep the prompt
-        output = prompt
-    else:
-        output = ""
-
-    last_output_length = 0
-    for i, new_text in enumerate(streamer):
-        output += new_text
-        if i % stream_interval == 0:
-            if echo:
-                rfind_start = len_prompt
-            else:
-                rfind_start = 0
-
-            partially_stopped = False
-            if stop_str:
-                if isinstance(stop_str, str):
-                    pos = output.rfind(stop_str, rfind_start)
-                    if pos != -1:
-                        output = output[:pos]
-                    else:
-                        partially_stopped = is_partial_stop(output, stop_str)
-                elif isinstance(stop_str, Iterable):
-                    for each_stop in stop_str:
-                        pos = output.rfind(each_stop, rfind_start)
-                        if pos != -1:
-                            output = output[:pos]
-                            break
-                        else:
-                            partially_stopped = is_partial_stop(output, each_stop)
-                            if partially_stopped:
-                                break
-                else:
-                    raise ValueError("Invalid stop field type.")
-
-            if stream:
-                output = output.strip("�")
-                tmp_output_length = len(output)
-                output = output[last_output_length:]
-                last_output_length = tmp_output_length
-
-            # prevent yielding partial stop sequence
-            if not partially_stopped:
-                completion_choice = CompletionChoice(
-                    text=output, index=0, logprobs=None, finish_reason=None
-                )
-                completion_chunk = CompletionChunk(
-                    id=str(uuid.uuid1()),
-                    object="text_completion",
-                    created=int(time.time()),
-                    model=model_uid,
-                    choices=[completion_choice],
-                )
-                completion_usage = CompletionUsage(
-                    prompt_tokens=input_echo_len,
-                    completion_tokens=i,
-                    total_tokens=(input_echo_len + i),
-                )
-
-                yield completion_chunk, completion_usage
-    output = output.strip()
-
-    # finish stream event, which contains finish reason
-    if i == max_new_tokens - 1:
-        finish_reason = "length"
-    elif partially_stopped:
-        finish_reason = None
-    else:
-        finish_reason = "stop"
-
-    completion_choice = CompletionChoice(
-        text=output, index=0, logprobs=None, finish_reason=finish_reason
-    )
-    completion_chunk = CompletionChunk(
-        id=str(uuid.uuid1()),
-        object="text_completion",
-        created=int(time.time()),
-        model=model_uid,
-        choices=[completion_choice],
-    )
-    completion_usage = CompletionUsage(
-        prompt_tokens=input_echo_len,
-        completion_tokens=i,
-        total_tokens=(input_echo_len + i),
-    )
-
-    yield completion_chunk, completion_usage
-
-    if include_usage:
-        completion_chunk = CompletionChunk(
-            id=str(uuid.uuid1()),
-            object="text_completion",
-            created=int(time.time()),
-            model=model_uid,
-            choices=[],
-        )
-        completion_usage = CompletionUsage(
-            prompt_tokens=input_echo_len,
-            completion_tokens=i,
-            total_tokens=(input_echo_len + i),
-        )
-        yield completion_chunk, completion_usage
-
-    # clean
-    gc.collect()
-    empty_cache()
-
-
 def _get_token_from_logits(
     req: InferenceRequest, i: int, logits, temperature, repetition_penalty, top_p, top_k
 ):
@@ -568,12 +395,15 @@ def _pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
     return [pad] * (max_len - len(x)) + x
 
 
-def _pad_seqs_inplace(seqs: List[List[int]], pad: int):
+def _pad_seqs_inplace(seqs: List[List[int]], reqs: List[InferenceRequest], pad: int):
     max_len = max(len(seq) for seq in seqs)
     n = len(seqs)
     i = 0
     while i < n:
+        prev_seq_len = len(seqs[i])
         seqs[i] = _pad_to_max_length(seqs[i], max_len, pad)
+        padding_len = len(seqs[i]) - prev_seq_len
+        reqs[i].padding_len = padding_len
         i += 1
 
 
@@ -586,6 +416,7 @@ def get_max_src_len(context_len: int, r: InferenceRequest) -> int:
 
 def _get_completion_chunk(
     output: str,
+    chunk_id: str,
     finish_reason: Optional[str],
     model_uid: str,
     r: InferenceRequest,
@@ -601,7 +432,7 @@ def _get_completion_chunk(
         else []
     )
     completion_chunk = CompletionChunk(
-        id=
+        id=chunk_id,
        object="text_completion",
        created=int(time.time()),
        model=model_uid,
@@ -617,14 +448,18 @@ def _get_completion_chunk(
 
 
 def _get_completion(
-    output: str,
+    output: str,
+    chunk_id: str,
+    finish_reason: Optional[str],
+    model_uid: str,
+    r: InferenceRequest,
 ):
     completion_choice = CompletionChoice(
         text=output, index=0, logprobs=None, finish_reason=finish_reason
     )
 
     completion_chunk = CompletionChunk(
-        id=
+        id=chunk_id,
         object="text_completion",
         created=int(time.time()),
         model=model_uid,
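
With the new signature above, left-padding a batch also records how many pad tokens were prepended to each request (reqs[i].padding_len), which the attention-mask helper introduced later in this diff consumes. A standalone sketch of the idea, using plain dicts in place of InferenceRequest:

from typing import Dict, List

def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
    # left-pad so the most recent tokens stay right-aligned
    return [pad] * (max_len - len(x)) + x

def pad_seqs_recording_padding(seqs: List[List[int]], reqs: List[Dict], pad: int = 0):
    max_len = max(len(s) for s in seqs)
    for i in range(len(seqs)):
        prev_len = len(seqs[i])
        seqs[i] = pad_to_max_length(seqs[i], max_len, pad)
        # remember how much padding this request received
        reqs[i]["padding_len"] = len(seqs[i]) - prev_len

seqs = [[5, 6, 7], [9]]
reqs = [{}, {}]
pad_seqs_recording_padding(seqs, reqs)
print(seqs)  # [[5, 6, 7], [0, 0, 9]]
print(reqs)  # [{'padding_len': 0}, {'padding_len': 2}]
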
@@ -674,6 +509,25 @@ def _merge_kv_cache(
     return ret_kv.to_legacy_cache()
 
 
+def _get_attention_mask_and_position_ids(kv, reqs: List[InferenceRequest]):
+    batch_size, seq_length, device = (
+        kv[0][0].shape[0],
+        kv[0][0].shape[2],
+        kv[0][0].device,
+    )
+    seq_length = seq_length + 1
+    position_ids = torch.as_tensor([[seq_length - 1]], dtype=torch.long, device=device)
+    attention_mask = torch.ones(
+        (batch_size, seq_length), dtype=torch.long, device=device
+    )
+    padding_lens = torch.as_tensor([r.padding_len for r in reqs])
+    mask = torch.arange(seq_length).expand(
+        batch_size, seq_length
+    ) < padding_lens.unsqueeze(1)
+    attention_mask[mask] = 0
+    return attention_mask, position_ids
+
+
 @torch.inference_mode()
 def _batch_inference_one_step_internal(
     req_list: List[InferenceRequest],
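
The helper added above derives the batch size and cached sequence length from the KV cache, then zeroes out the attention-mask positions that correspond to each request's left padding. A minimal sketch of the same masking trick, assuming only a per-row padding length instead of InferenceRequest objects:

import torch

def padding_aware_attention_mask(padding_lens, seq_length):
    # positions [0, padding_len) are pad tokens and must not be attended to
    batch_size = len(padding_lens)
    attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
    pad = torch.as_tensor(padding_lens).unsqueeze(1)               # (batch, 1)
    positions = torch.arange(seq_length).expand(batch_size, seq_length)
    attention_mask[positions < pad] = 0
    return attention_mask

print(padding_aware_attention_mask([0, 2], seq_length=5))
# tensor([[1, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1]])
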
@@ -682,7 +536,9 @@ def _batch_inference_one_step_internal(
     tokenizer,
     device,
     context_len: int,
+    stop_tokens: Tuple[int],
     decode_round: int = 16,
+    require_attention_mask: bool = False,
     bos_flag: str = "<bos_stream>",
     eos_flag: str = "<eos_stream>",
 ):
@@ -692,7 +548,8 @@ def _batch_inference_one_step_internal(
     if not valid_req_list:
         return
     generate_config_mapping: Dict[InferenceRequest, Tuple] = {
-        r: r.get_generate_configs(tokenizer.eos_token_id)
+        r: r.get_generate_configs(tokenizer.eos_token_id, stop_tokens)
+        for r in valid_req_list
     }
     s_time = time.time()
 
@@ -701,7 +558,7 @@ def _batch_inference_one_step_internal(
     decode_reqs = []
     for r in valid_req_list:
         if r.is_prefill:
-            prompts.append(r.full_prompt)
+            prompts.append(r.full_prompt if r.full_prompt is not None else r.prompt)
            prefill_reqs.append(r)
         else:
             decode_reqs.append(r)
@@ -714,7 +571,7 @@ def _batch_inference_one_step_internal(
             max_src_len = get_max_src_len(context_len, req)
             req.prompt_tokens = input_id[-max_src_len:]
             prompt_tokens.append(req.prompt_tokens)
-        _pad_seqs_inplace(prompt_tokens, 0)
+        _pad_seqs_inplace(prompt_tokens, valid_req_list, 0)
         out = model(torch.as_tensor(prompt_tokens, device=device), use_cache=True)
 
         logits = out.logits
@@ -756,10 +613,18 @@ def _batch_inference_one_step_internal(
     # here, only decode phase, just run some rounds
     for _i in range(decode_round):
         decode_tokens: List[List[int]] = [[r.new_tokens[-1]] for r in valid_req_list]
+        inf_kws = {}
+        if require_attention_mask:
+            attention_mask, position_ids = _get_attention_mask_and_position_ids(
+                past_key_values, valid_req_list
+            )
+            inf_kws["position_ids"] = position_ids
+            inf_kws["attention_mask"] = attention_mask
         out = model(
             input_ids=torch.as_tensor(decode_tokens, device=device),
             use_cache=True,
             past_key_values=past_key_values,
+            **inf_kws,
         )
         logits = out.logits
         past_key_values = out.past_key_values
@@ -846,7 +711,7 @@ def _batch_inference_one_step_internal(
                     r.last_output_length += len(output)
 
                     completion_chunk = _get_completion_chunk(
-                        output, r.finish_reason, model_uid, r, False
+                        output, r.chunk_id, r.finish_reason, model_uid, r, False
                     )
                     r.completion.append(completion_chunk)
                     if r.stopped:
@@ -859,7 +724,7 @@ def _batch_inference_one_step_internal(
                 if r.stopped and _i == decode_round - 1 and include_usage:
                     r.completion.append(
                         _get_completion_chunk(
-                            "", r.finish_reason, model_uid, r, True
+                            "", r.chunk_id, r.finish_reason, model_uid, r, True
                        )
                     )
             else:
@@ -878,7 +743,9 @@ def _batch_inference_one_step_internal(
                     if r not in output_mapping
                     else output_mapping[r]
                 )
-                completion = _get_completion(
+                completion = _get_completion(
+                    outputs, r.chunk_id, r.finish_reason, model_uid, r
+                )
                 r.completion = [completion]
 
     e_time = time.time()
@@ -894,12 +761,21 @@ def batch_inference_one_step(
     tokenizer,
     device,
     context_len: int,
+    stop_token_ids: Tuple[int],
+    require_attention_mask: bool = False,
 ):
     from ....core.model import OutOfMemoryError
 
     try:
         _batch_inference_one_step_internal(
-            req_list,
+            req_list,
+            model_uid,
+            model,
+            tokenizer,
+            device,
+            context_len,
+            stop_token_ids,
+            require_attention_mask=require_attention_mask,
         )
     except OutOfMemoryError:
         logger.exception(
@@ -911,4 +787,8 @@ def batch_inference_one_step(
         os._exit(1)
     except Exception as e:
         logger.exception(f"Internal error for batch inference: {e}.")
-        #
+        # If internal error happens, just skip all the requests in this batch.
+        # If not handle here, the client will hang.
+        for r in req_list:
+            r.stopped = True
+            r.error_msg = str(e)
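
The last hunk changes the failure mode of batch inference: on any unexpected exception, every request in the batch is marked stopped and carries the error message, so waiting clients are released instead of hanging. A small sketch of the same pattern, with a toy Request class standing in for InferenceRequest:

import logging
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)

@dataclass
class Request:
    prompt: str
    stopped: bool = False
    error_msg: str = ""
    completion: list = field(default_factory=list)

def run_batch(reqs, step):
    try:
        step(reqs)
    except Exception as e:
        logger.exception("Internal error for batch inference: %s.", e)
        # Mark every request finished with an error so callers stop waiting.
        for r in reqs:
            r.stopped = True
            r.error_msg = str(e)

def failing_step(reqs):
    raise RuntimeError("boom")

reqs = [Request("hi"), Request("hello")]
run_batch(reqs, failing_step)
print([(r.stopped, r.error_msg) for r in reqs])  # [(True, 'boom'), (True, 'boom')]
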
xinference/model/llm/utils.py
CHANGED
@@ -607,7 +607,7 @@ Begin!"""
         return arguments, None, None
 
     @staticmethod
-    def
+    def _eval_glm_chat_arguments(c, tools):
         if isinstance(c[0], str):
             return c[0], None, None
         return None, c[0]["name"], c[0]["parameters"]
@@ -659,9 +659,15 @@ Begin!"""
         family = model_family.model_family or model_family.model_name
         if family in ["gorilla-openfunctions-v1", "gorilla-openfunctions-v2"]:
             content, func, args = cls._eval_gorilla_openfunctions_arguments(c, tools)
-        elif "chatglm3"
-            content, func, args = cls.
-        elif family in [
+        elif family in ["chatglm3", "glm4-chat"]:
+            content, func, args = cls._eval_glm_chat_arguments(c, tools)
+        elif family in [
+            "qwen-chat",
+            "qwen1.5-chat",
+            "qwen1.5-moe-chat",
+            "qwen2-instruct",
+            "qwen2-moe-instruct",
+        ]:
             content, func, args = cls._eval_qwen_chat_arguments(c, tools)
         else:
             raise Exception(
@@ -676,28 +682,35 @@ Begin!"""
         Generates a filter function for Qwen series models to retain outputs after "\nFinal Answer:".
 
         Returns:
-            A function that takes tokens (string output by the model so far) as input
-            returns
+            A function that takes tokens (string output by the model so far) and delta (new tokens added) as input,
+            returns the part after "\nFinal Answer:" if found, else returns delta.
         """
         family = model_family.model_family or model_family.model_name
-        if family in [
+        if family in [
+            "qwen-chat",
+            "qwen1.5-chat",
+            "qwen1.5-moe-chat",
+            "qwen2-instruct",
+            "qwen2-moe-instruct",
+        ]:
             # Encapsulating function to reset 'found' after each call
             found = False
 
-            def
+            def process_tokens(tokens: str, delta: str):
                 nonlocal found
                 # Once "Final Answer:" is found, future tokens are allowed.
                 if found:
-                    return
+                    return delta
                 # Check if the token ends with "\nFinal Answer:" and update `found`.
-
+                final_answer_idx = tokens.lower().rfind("\nfinal answer:")
+                if final_answer_idx != -1:
                     found = True
-
+                    return tokens[final_answer_idx + len("\nfinal answer:") :]
+                return ""
 
-            return
+            return process_tokens
         else:
-
-            return lambda tokens: True
+            return lambda tokens, delta: delta
 
     @classmethod
     def _tool_calls_completion(cls, model_family, model_uid, c, tools):
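
The rewritten filter receives both the accumulated text and the newest delta: while the Qwen ReAct thought/action preamble is still streaming it returns an empty string, once "\nFinal Answer:" appears it returns only what follows the marker, and afterwards it passes every delta straight through. A standalone sketch of the same closure, without xinference's model-family dispatch:

def build_final_answer_filter():
    found = False

    def process_tokens(tokens: str, delta: str) -> str:
        nonlocal found
        if found:
            return delta                      # after the marker, stream as-is
        idx = tokens.lower().rfind("\nfinal answer:")
        if idx != -1:
            found = True
            return tokens[idx + len("\nfinal answer:"):]
        return ""                             # still inside the thought process

    return process_tokens

f = build_final_answer_filter()
print(repr(f("Thought: I need a tool", " tool")))             # ''
print(repr(f("Thought: ...\nFinal Answer: 42", " 42")))       # ' 42'
print(repr(f("Thought: ...\nFinal Answer: 42 is", " is")))    # ' is'
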
xinference/model/llm/vllm/core.py
CHANGED
@@ -444,7 +444,9 @@ class VLLMModel(LLM):
                 _content, func, args = ChatModelMixin._eval_tool_arguments(
                     self.model_family, chunk, tools
                 )
-                choice["text"] =
+                choice["text"] = tools_token_filter(
+                    tokens=previous_texts[0], delta=choice_delta
+                )
                 if func is not None:
                     choice["text"] = None
                     choice["finish_reason"] = "tool_calls"
@@ -458,9 +460,13 @@ class VLLMModel(LLM):
                         ),
                     )
                 ]
-
-
-
+            else:
+                # use a filter function to skip Qwen's react thought process
+                choice["text"] = tools_token_filter(
+                    tokens=previous_texts[0], delta=choice["text"]
+                )
+                if not choice["text"]:
+                    continue
             prompt_tokens = len(_request_output.prompt_token_ids)
             completion_tokens = sum(
                 len(output.token_ids) for output in _request_output.outputs
xinference/model/rerank/core.py
CHANGED
@@ -23,7 +23,7 @@ import numpy as np
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ...device_utils import empty_cache
-from ...types import Document, DocumentObj, Rerank
+from ...types import Document, DocumentObj, Rerank, RerankTokens
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import is_model_cached
 
@@ -121,11 +121,17 @@ class RerankModel:
         if model_spec.type == "unknown":
             model_spec.type = self._auto_detect_type(model_path)
 
+    @staticmethod
+    def _get_tokenizer(model_path):
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        return tokenizer
+
     @staticmethod
     def _auto_detect_type(model_path):
         """This method may not be stable due to the fact that the tokenizer name may be changed.
         Therefore, we only use this method for unknown model types."""
-        from transformers import AutoTokenizer
 
         type_mapper = {
             "LlamaTokenizerFast": "LLM-based layerwise",
@@ -133,12 +139,13 @@ class RerankModel:
             "XLMRobertaTokenizerFast": "normal",
         }
 
-        tokenizer =
+        tokenizer = RerankModel._get_tokenizer(model_path)
         rerank_type = type_mapper.get(type(tokenizer).__name__)
         if rerank_type is None:
-
-            f"Can't determine the rerank type based on the tokenizer {tokenizer}"
+            logger.warning(
+                f"Can't determine the rerank type based on the tokenizer {tokenizer}, use normal type by default."
             )
+            return "normal"
         return rerank_type
 
     def load(self):
@@ -185,6 +192,7 @@ class RerankModel:
         top_n: Optional[int],
         max_chunks_per_doc: Optional[int],
         return_documents: Optional[bool],
+        return_len: Optional[bool],
         **kwargs,
     ) -> Rerank:
         self._counter += 1
@@ -223,7 +231,28 @@ class RerankModel:
             )
             for arg in sim_scores_argsort
         ]
-
+        if return_len:
+            tokenizer = self._get_tokenizer(self._model_path)
+            input_len = sum([len(tokenizer.tokenize(t)) for t in documents])
+
+            # Rerank Model output is just score or documents
+            # while return_documents = True
+            output_len = input_len
+
+        # api_version, billed_units, warnings
+        # is for Cohere API compatibility, set to None
+        metadata = {
+            "api_version": None,
+            "billed_units": None,
+            "tokens": (
+                RerankTokens(input_tokens=input_len, output_tokens=output_len)
+                if return_len
+                else None
+            ),
+            "warnings": None,
+        }
+
+        return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)
 
 
 def get_cache_dir(model_spec: RerankModelSpec):
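
With return_len enabled, the rerank response now carries a Cohere-style meta block whose tokens field reports input and output token counts; because a rerank model only returns scores (plus, optionally, the documents themselves), output_tokens is reported equal to input_tokens. A hedged sketch of the counting step only, with a trivial whitespace tokenizer standing in for the model's real tokenizer:

def count_rerank_tokens(tokenizer, documents):
    input_tokens = sum(len(tokenizer.tokenize(d)) for d in documents)
    # scores-only output: report output tokens the same as input tokens
    return {"input_tokens": input_tokens, "output_tokens": input_tokens}

class WhitespaceTokenizer:
    def tokenize(self, text):
        return text.split()

print(count_rerank_tokens(WhitespaceTokenizer(), ["a first doc", "second doc"]))
# {'input_tokens': 5, 'output_tokens': 5}
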
xinference/model/utils.py
CHANGED
@@ -42,14 +42,20 @@ def is_locale_chinese_simplified() -> bool:
 
 
 def download_from_modelscope() -> bool:
-    if os.environ.get(XINFERENCE_ENV_MODEL_SRC)
-        return
+    if os.environ.get(XINFERENCE_ENV_MODEL_SRC):
+        return os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "modelscope"
     elif is_locale_chinese_simplified():
         return True
     else:
         return False
 
 
+def download_from_csghub() -> bool:
+    if os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "csghub":
+        return True
+    return False
+
+
 def symlink_local_file(path: str, local_dir: str, relpath: str) -> str:
     from huggingface_hub.file_download import _create_symlink
 
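
These helpers choose the download hub from the environment variable behind the XINFERENCE_ENV_MODEL_SRC constant (defined in xinference/constants.py), falling back to the simplified-Chinese locale check for ModelScope. A minimal sketch of the same dispatch; the variable name "XINFERENCE_MODEL_SRC" below is a stand-in for illustration, not a value taken from this diff:

import os

MODEL_SRC_ENV = "XINFERENCE_MODEL_SRC"  # assumed name, for illustration only

def download_from_modelscope_sketch(locale_is_zh_cn: bool) -> bool:
    src = os.environ.get(MODEL_SRC_ENV)
    if src:
        # an explicit setting always wins; only "modelscope" selects ModelScope
        return src == "modelscope"
    # otherwise fall back to the locale heuristic
    return locale_is_zh_cn

def download_from_csghub_sketch() -> bool:
    return os.environ.get(MODEL_SRC_ENV) == "csghub"

os.environ[MODEL_SRC_ENV] = "csghub"
print(download_from_modelscope_sketch(locale_is_zh_cn=True))  # False
print(download_from_csghub_sketch())                          # True
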
xinference/thirdparty/ChatTTS/experimental/__init__.py
File without changes
xinference/thirdparty/ChatTTS/experimental/llm.py
ADDED
@@ -0,0 +1,40 @@
+
+from openai import OpenAI
+
+prompt_dict = {
+    'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"},
+              {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
+              {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
+    'deepseek': [
+        {"role": "system", "content": "You are a helpful assistant"},
+        {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
+        {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
+    'deepseek_TN': [
+        {"role": "system", "content": "You are a helpful assistant"},
+        {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"},
+        {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"},
+        {"role": "user", "content": "We paid $123 for this desk."},
+        {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."},
+        {"role": "user", "content": "详询请拨打010-724654"},
+        {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"},
+        {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"},
+        {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"},
+    ],
+}
+
+class llm_api:
+    def __init__(self, api_key, base_url, model):
+        self.client = OpenAI(
+            api_key = api_key,
+            base_url = base_url,
+        )
+        self.model = model
+    def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs):
+
+        completion = self.client.chat.completions.create(
+            model = self.model,
+            messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},],
+            temperature = temperature,
+            **kwargs
+        )
+        return completion.choices[0].message.content

xinference/thirdparty/ChatTTS/infer/__init__.py
File without changes
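
The new ChatTTS helper wraps an OpenAI-compatible chat endpoint and prepends one of the canned prompt templates ('kimi', 'deepseek', or 'deepseek_TN' for text normalization) before the user's question. A hedged usage sketch; the import path is assumed from the file location, and the key, endpoint, and model name are placeholders, not values from the package:

# Assumed import path based on where the file lives in the wheel.
from xinference.thirdparty.ChatTTS.experimental.llm import llm_api

client = llm_api(
    api_key="sk-placeholder",                 # hypothetical key
    base_url="https://api.example.com/v1",    # any OpenAI-compatible endpoint
    model="my-chat-model",                    # hypothetical model name
)

# 'deepseek_TN' selects the text-normalization template, which rewrites digits
# and symbols into words so the output is TTS-friendly.
print(client.call("We paid $123 for this desk.", prompt_version="deepseek_TN"))
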