sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +13 -6
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +2 -4
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +40 -35
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +110 -74
- sglang/srt/managers/tokenizer_manager.py +24 -15
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +60 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +118 -141
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +6 -8
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +8 -43
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/{llama2.py → llama.py} +48 -26
- sglang/srt/models/llama_classification.py +14 -40
- sglang/srt/models/llama_embedding.py +7 -6
- sglang/srt/models/llava.py +38 -16
- sglang/srt/models/llavavid.py +7 -8
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mistral.py +2 -3
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +67 -58
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +70 -0
- sglang/test/test_utils.py +89 -1
- sglang/utils.py +38 -4
- sglang/version.py +1 -1
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.2.15.dist-info/RECORD +0 -118
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/schedule_batch.py
+++ b/sglang/srt/managers/schedule_batch.py
@@ -19,7 +19,7 @@ limitations under the License.

 import logging
 from dataclasses import dataclass
-from typing import
+from typing import List, Optional, Tuple, Union

 import torch

@@ -29,20 +29,19 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-
-if TYPE_CHECKING:
-    from sglang.srt.layers.sampler import SampleOutput
-
+from sglang.srt.server_args import ServerArgs

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access
 global_server_args_dict = {
-    "
-    "
-    "triton_attention_reduce_in_fp32":
-    "enable_mla":
+    "attention_backend": ServerArgs.attention_backend,
+    "sampling_backend": ServerArgs.sampling_backend,
+    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
+    "enable_mla": ServerArgs.enable_mla,
+    "torchao_config": ServerArgs.torchao_config,
 }


@@ -53,8 +52,8 @@ class BaseFinishReason:
     def __init__(self, is_error: bool = False):
         self.is_error = is_error

-    def
-        raise NotImplementedError(
+    def to_json(self):
+        raise NotImplementedError()


 class FINISH_MATCHED_TOKEN(BaseFinishReason):
@@ -62,40 +61,57 @@ class FINISH_MATCHED_TOKEN(BaseFinishReason):
         super().__init__()
         self.matched = matched

-    def
-        return
+    def to_json(self):
+        return {
+            "type": "stop", # to match OpenAI API's return value
+            "matched": self.matched,
+        }


-class
-    def __init__(self,
+class FINISH_MATCHED_STR(BaseFinishReason):
+    def __init__(self, matched: str):
         super().__init__()
-        self.
+        self.matched = matched

-    def
-        return
+    def to_json(self):
+        return {
+            "type": "stop", # to match OpenAI API's return value
+            "matched": self.matched,
+        }


-class
-    def __init__(self,
+class FINISH_LENGTH(BaseFinishReason):
+    def __init__(self, length: int):
         super().__init__()
-        self.
+        self.length = length

-    def
-        return
+    def to_json(self):
+        return {
+            "type": "length", # to match OpenAI API's return value
+            "length": self.length,
+        }


 class FINISH_ABORT(BaseFinishReason):
     def __init__(self):
         super().__init__(is_error=True)

-    def
-        return
+    def to_json(self):
+        return {
+            "type": "abort",
+        }


 class Req:
     """Store all inforamtion of a request."""

-    def __init__(
+    def __init__(
+        self,
+        rid: str,
+        origin_input_text: str,
+        origin_input_ids: Tuple[int],
+        lora_path: Optional[str] = None,
+    ):
         # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
@@ -103,10 +119,15 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = [] # Each decode stage's output ids
         self.fill_ids = None # fill_ids = origin_input_ids + output_ids
+        self.lora_path = lora_path

         # Memory info
         self.req_pool_idx = None

+        # Check finish
+        self.tokenizer = None
+        self.finished_reason = None
+
         # For incremental decoding
         # ----- | --------- read_ids -------|
         # ----- | surr_ids |
@@ -125,38 +146,43 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0

-        # For vision
+        # For vision inputs
         self.pixel_values = None
         self.image_sizes = None
         self.image_offsets = None
         self.pad_value = None
+        self.modalities = None

         # Prefix info
-        self.extend_input_len = 0
         self.prefix_indices = []
+        self.extend_input_len = 0
         self.last_node = None

         # Sampling parameters
         self.sampling_params = None
         self.stream = False

-        #
-        self.tokenizer = None
-        self.finished_reason = None
-
-        # Logprobs
+        # Logprobs (arguments)
         self.return_logprob = False
-        self.embedding = None
         self.logprob_start_len = 0
         self.top_logprobs_num = 0
+
+        # Logprobs (return value)
         self.normalized_prompt_logprob = None
         self.input_token_logprobs = None
         self.input_top_logprobs = None
         self.output_token_logprobs = []
         self.output_top_logprobs = []
+
+        # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens
         # and should be updated for the decode logprobs
         self.last_update_decode_tokens = 0
+        # The relative logprob_start_len in an extend batch
+        self.extend_logprob_start_len = 0
+
+        # Embedding
+        self.embedding = None

         # Constrained decoding
         self.regex_fsm: RegexGuide = None
@@ -178,19 +204,22 @@ class Req:
     def adjust_max_prefix_ids(self):
         self.fill_ids = self.origin_input_ids + self.output_ids
         input_len = len(self.fill_ids)
-
+
+        # FIXME: To work around some bugs in logprob computation, we need to ensure each
+        # request has at least one token. Later, we can relax this requirement and use `input_len`.
+        max_prefix_len = input_len - 1

         if self.sampling_params.max_new_tokens > 0:
             # Need at least one token to compute logits
             max_prefix_len = min(max_prefix_len, input_len - 1)

         if self.return_logprob:
-            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
             if self.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 max_prefix_len = min(max_prefix_len, input_len - 2)
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)

+        max_prefix_len = max(max_prefix_len, 0)
         return self.fill_ids[:max_prefix_len]

     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
@@ -330,6 +359,8 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool
     tree_cache: BasePrefixCache

+    forward_mode: ForwardMode = None
+
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
@@ -340,14 +371,19 @@ class ScheduleBatch:

     # For mixed chunekd prefill
     prefix_lens_cpu: List[int] = None
+    running_bs: int = None

     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

+    # Stream
+    has_stream: bool = False
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
+        has_stream = any(req.stream for req in reqs)

         return cls(
             reqs=reqs,
@@ -355,18 +391,15 @@ class ScheduleBatch:
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             return_logprob=return_logprob,
+            has_stream=has_stream,
         )

     def batch_size(self):
-        return len(self.reqs)
+        return len(self.reqs)

     def is_empty(self):
         return len(self.reqs) == 0

-    def has_stream(self) -> bool:
-        # Return whether batch has at least 1 streaming request
-        return any(r.stream for r in self.reqs)
-
     def alloc_req_slots(self, num_reqs):
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
@@ -393,6 +426,8 @@ class ScheduleBatch:
         return out_cache_loc

     def prepare_for_extend(self, vocab_size: int):
+        self.forward_mode = ForwardMode.EXTEND
+
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -407,8 +442,8 @@ class ScheduleBatch:
         for i, req in enumerate(reqs):
             req.req_pool_idx = req_pool_indices_cpu[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
-            ext_len = seq_len - pre_len
             seq_lens.append(seq_len)
+            assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
                 self.req_to_token_pool.req_to_token[req.req_pool_idx][
@@ -416,9 +451,19 @@ class ScheduleBatch:
                 ] = req.prefix_indices

             self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
-                out_cache_loc[pt : pt +
+                out_cache_loc[pt : pt + req.extend_input_len]
             )
-
+
+            # Compute the relative logprob_start_len in an extend batch
+            if req.logprob_start_len >= pre_len:
+                extend_logprob_start_len = min(
+                    req.logprob_start_len - pre_len, req.extend_input_len - 1
+                )
+            else:
+                extend_logprob_start_len = req.extend_input_len - 1
+
+            req.extend_logprob_start_len = extend_logprob_start_len
+            pt += req.extend_input_len

         # Set fields
         with torch.device("cuda"):
@@ -431,18 +476,13 @@ class ScheduleBatch:
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
         self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
-
+        self.extend_lens_cpu = [r.extend_input_len for r in reqs]
+        self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

     def mix_with_running(self, running_batch: "ScheduleBatch"):
-
-
-        prefix_lens_cpu.extend(
-            [
-                len(r.origin_input_ids) + len(r.output_ids) - 1
-                for r in running_batch.reqs
-            ]
-        )
+        self.forward_mode = ForwardMode.MIXED
+        running_bs = running_batch.batch_size()

         for req in running_batch.reqs:
             req.fill_ids = req.origin_input_ids + req.output_ids
@@ -450,12 +490,22 @@ class ScheduleBatch:

         input_ids = torch.cat([self.input_ids, running_batch.input_ids])
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
-        extend_num_tokens = self.extend_num_tokens +
+        extend_num_tokens = self.extend_num_tokens + running_bs
+
         self.merge(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
         self.extend_num_tokens = extend_num_tokens
-
+
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        self.prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
+        self.extend_lens_cpu.extend([1] * running_bs)
+        self.extend_logprob_start_lens_cpu.extend([0] * running_bs)

     def check_decode_mem(self):
         bs = self.batch_size()
@@ -622,6 +672,8 @@ class ScheduleBatch:
         return jump_forward_reqs

     def prepare_for_decode(self, input_ids=None):
+        self.forward_mode = ForwardMode.DECODE
+
         if input_ids is None:
             input_ids = [
                 r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
@@ -641,8 +693,6 @@ class ScheduleBatch:
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc

-        self.sampling_info.update_regex_vocab_mask(self)
-
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests
@@ -662,6 +712,7 @@ class ScheduleBatch:
         self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        self.has_stream = any(req.stream for req in self.reqs)

         self.sampling_info.filter(unfinished_indices, new_indices)

@@ -672,7 +723,6 @@ class ScheduleBatch:
         self.sampling_info.merge(other.sampling_info)

         self.reqs.extend(other.reqs)
-
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
@@ -683,18 +733,4 @@ class ScheduleBatch:
         self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-    def check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
-
-        return sample_output.batch_next_token_ids
+        self.has_stream = any(req.stream for req in self.reqs)
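The `extend_logprob_start_len` bookkeeping added in `prepare_for_extend` above converts a request's absolute `logprob_start_len` (a position in the full prompt) into an offset relative to the tokens actually extended in this batch, clamped so at least one extended token remains. A minimal standalone sketch of that arithmetic, with hypothetical numbers and a helper name that does not exist in the package:

```python
def relative_logprob_start_len(
    logprob_start_len: int, pre_len: int, extend_input_len: int
) -> int:
    """Mirror of the computation added in ScheduleBatch.prepare_for_extend.

    pre_len: number of prompt tokens already covered by the prefix cache.
    extend_input_len: number of new tokens fed to the model in this batch.
    """
    if logprob_start_len >= pre_len:
        # The requested start falls inside the extended part; clamp so at
        # least one extended token remains after it.
        return min(logprob_start_len - pre_len, extend_input_len - 1)
    # The requested start is already covered by the cached prefix.
    return extend_input_len - 1


# Example (hypothetical numbers): a 12-token prompt with an 8-token prefix
# cache hit extends 4 tokens; a logprob_start_len of 10 maps to offset 2.
assert relative_logprob_start_len(10, 8, 4) == 2
assert relative_logprob_start_len(3, 8, 4) == 3
```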
--- a/sglang/srt/managers/tokenizer_manager.py
+++ b/sglang/srt/managers/tokenizer_manager.py
@@ -18,6 +18,7 @@ limitations under the License.
 import asyncio
 import concurrent.futures
 import dataclasses
+import json
 import logging
 import multiprocessing as mp
 import os
@@ -77,7 +78,6 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args: dict = None,
     ):
         self.server_args = server_args

@@ -86,8 +86,8 @@ class TokenizerManager:
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

-        self.
-        self.
+        self.send_to_controller = context.socket(zmq.PUSH)
+        self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

         # Read model args
         self.model_path = server_args.model_path
@@ -95,7 +95,7 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_override_args=
+            model_override_args=json.loads(server_args.json_model_override_args),
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
@@ -188,6 +188,7 @@ class TokenizerManager:
             pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
                 obj.image_data if not_use_index else obj.image_data[index]
             )
+            modalities = obj.modalities
             return_logprob = (
                 obj.return_logprob if not_use_index else obj.return_logprob[index]
             )
@@ -196,8 +197,6 @@ class TokenizerManager:
                 if not_use_index
                 else obj.logprob_start_len[index]
             )
-            if return_logprob and logprob_start_len == -1:
-                logprob_start_len = len(input_ids) - 1
             top_logprobs_num = (
                 obj.top_logprobs_num
                 if not_use_index
@@ -243,14 +242,13 @@ class TokenizerManager:
             pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
                 obj.image_data[0]
             )
+            modalities = obj.modalities
             return_logprob = obj.return_logprob[0]
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]

         # Send to the controller
         if self.is_generation:
-            if return_logprob and logprob_start_len == -1:
-                logprob_start_len = len(input_ids) - 1
             tokenized_obj = TokenizedGenerateReqInput(
                 rid,
                 input_text,
@@ -263,6 +261,12 @@ class TokenizerManager:
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
+                modalities,
+                (
+                    obj.lora_path[index]
+                    if isinstance(obj.lora_path, list)
+                    else obj.lora_path
+                ),
             )
         else: # is embedding
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -271,7 +275,7 @@ class TokenizerManager:
                 input_ids,
                 sampling_params,
             )
-        self.
+        self.send_to_controller.send_pyobj(tokenized_obj)

         # Recv results
         event = asyncio.Event()
@@ -341,11 +345,10 @@ class TokenizerManager:
             sampling_params = self._get_sampling_params(obj.sampling_params[index])

             if self.is_generation:
-                if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
-                    obj.logprob_start_len[index] = len(input_ids) - 1
                 pixel_values, image_hashes, image_sizes = (
                     await self._get_pixel_values(obj.image_data[index])
                 )
+                modalities = obj.modalities

                 tokenized_obj = TokenizedGenerateReqInput(
                     rid,
@@ -359,6 +362,12 @@ class TokenizerManager:
                     obj.logprob_start_len[index],
                     obj.top_logprobs_num[index],
                     obj.stream,
+                    modalities,
+                    (
+                        obj.lora_path[index]
+                        if isinstance(obj.lora_path, list)
+                        else obj.lora_path
+                    ),
                 )
             else:
                 tokenized_obj = TokenizedEmbeddingReqInput(
@@ -367,7 +376,7 @@ class TokenizerManager:
                     input_ids,
                     sampling_params,
                 )
-            self.
+            self.send_to_controller.send_pyobj(tokenized_obj)

             event = asyncio.Event()
             state = ReqState([], False, event)
@@ -500,14 +509,14 @@ class TokenizerManager:

     def flush_cache(self):
         req = FlushCacheReq()
-        self.
+        self.send_to_controller.send_pyobj(req)

     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
         del self.rid_to_state[rid]
         req = AbortReq(rid)
-        self.
+        self.send_to_controller.send_pyobj(req)

     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
@@ -524,7 +533,7 @@ class TokenizerManager:
         # wait for the previous generation requests to finish
         while len(self.rid_to_state) > 0:
             await asyncio.sleep(0)
-        self.
+        self.send_to_controller.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         result = await self.model_update_result
         if result.success:
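Both `TokenizedGenerateReqInput` call sites in the hunks above pick a per-request LoRA path with the same expression, so a batched call can carry either one shared `lora_path` or a list with one entry per prompt. A small illustrative helper (the function name is hypothetical, not part of the package) capturing that selection:

```python
from typing import List, Optional, Union


def pick_lora_path(
    lora_path: Optional[Union[str, List[Optional[str]]]], index: int
) -> Optional[str]:
    """Select the LoRA path for request `index`, mirroring how the tokenizer
    manager fills the new lora_path field of TokenizedGenerateReqInput."""
    if isinstance(lora_path, list):
        # One adapter per prompt in a batched request.
        return lora_path[index]
    # A single shared adapter (or None) for every prompt.
    return lora_path


assert pick_lora_path("adapter-a", 1) == "adapter-a"
assert pick_lora_path(["adapter-a", "adapter-b"], 1) == "adapter-b"
assert pick_lora_path(None, 0) is None
```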