sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +30 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +317 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +41 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -2
- sglang/lang/ir.py +74 -28
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +68 -9
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +280 -169
- sglang/srt/layers/logits_processor.py +106 -42
- sglang/srt/layers/radix_attention.py +53 -29
- sglang/srt/layers/token_attention.py +4 -1
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +144 -69
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +9 -4
- sglang/srt/managers/controller/model_runner.py +167 -55
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +156 -134
- sglang/srt/managers/detokenizer_manager.py +19 -21
- sglang/srt/managers/io_struct.py +11 -5
- sglang/srt/managers/tokenizer_manager.py +16 -14
- sglang/srt/model_config.py +89 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +12 -5
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +35 -25
- sglang/srt/openai_protocol.py +2 -2
- sglang/srt/server.py +69 -19
- sglang/srt/server_args.py +76 -43
- sglang/srt/utils.py +177 -35
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
- sglang-0.1.19.dist-info/RECORD +81 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
--- sglang/srt/managers/controller/infer_batch.py (0.1.17)
+++ sglang/srt/managers/controller/infer_batch.py (0.1.19)
@@ -1,14 +1,20 @@
 """Meta data for requests and batches"""
+
+import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import List, Union

 import numpy as np
 import torch

+from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

+INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+

 class ForwardMode(IntEnum):
     PREFILL = auto()
@@ -25,7 +31,7 @@ class BaseFinishReason:


 class FINISH_MATCHED_TOKEN(BaseFinishReason):
-    def __init__(self, matched: int):
+    def __init__(self, matched: Union[int, List[int]]):
         super().__init__()
         self.matched = matched

@@ -63,12 +69,15 @@ class Req:
     def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
         self.origin_input_text = origin_input_text
+        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
         self.origin_input_ids = origin_input_ids
-        self.
-        self.
-
-
-        self.
+        self.output_ids = []  # Each decode stage's output ids
+        self.input_ids = None  # input_ids = origin_input_ids + output_ids
+
+        # For incremental decode
+        self.decoded_text = ""
+        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
+        self.read_offset = None

         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
@@ -108,20 +117,54 @@ class Req:
         self.last_update_decode_tokens = 0

         # Constrained decoding
-        self.regex_fsm = None
-        self.regex_fsm_state = 0
-        self.jump_forward_map = None
+        self.regex_fsm: RegexGuide = None
+        self.regex_fsm_state: int = 0
+        self.jump_forward_map: JumpForwardMap = None

     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None

-
-
-
-
+    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
+    def init_detokenize_incrementally(self):
+        first_iter = self.surr_offset is None or self.read_offset is None
+
+        if first_iter:
+            self.read_offset = len(self.origin_input_ids_unpadded)
+            self.surr_offset = max(
+                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
+            )
+
+        all_ids = self.origin_input_ids_unpadded + self.output_ids
+        surr_ids = all_ids[self.surr_offset : self.read_offset]
+        read_ids = all_ids[self.surr_offset :]
+
+        return surr_ids, read_ids, len(all_ids)
+
+    def detokenize_incrementally(self, inplace: bool = True):
+        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+
+        surr_text = self.tokenizer.decode(
+            surr_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
+        )
+        new_text = self.tokenizer.decode(
+            read_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
         )
-
+
+        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
+            new_text = new_text[len(surr_text) :]
+            if inplace:
+                self.decoded_text += new_text
+                self.surr_offset = self.read_offset
+                self.read_offset = num_all_tokens
+
+            return True, new_text
+
+        return False, ""

     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
@@ -130,18 +173,17 @@ class Req:
         if self.finished():
             return

-        if (
-            len(self.prev_output_ids) + len(self.output_ids)
-            >= self.sampling_params.max_new_tokens
-        ):
-            self.finished_reason = FINISH_LENGTH(len(self.prev_output_ids) + len(self.output_ids))
+        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+            self.finished_reason = FINISH_LENGTH(len(self.output_ids))
             return

         if (
             self.output_ids[-1] == self.tokenizer.eos_token_id
             and not self.sampling_params.ignore_eos
         ):
-            self.finished_reason = FINISH_MATCHED_TOKEN(
+            self.finished_reason = FINISH_MATCHED_TOKEN(
+                matched=self.tokenizer.eos_token_id
+            )
             return

         if len(self.sampling_params.stop_strs) > 0:
@@ -150,61 +192,59 @@ class Req:
         )

         for stop_str in self.sampling_params.stop_strs:
-
-            if stop_str in tail_str or stop_str in self.prev_output_str:
+            if stop_str in tail_str or stop_str in self.decoded_text:
                 self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                 return

     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        cur_output_str = self.partial_decode(self.output_ids)
-
-        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
                 self.origin_input_ids_unpadded
             )

-        all_text = (
-            self.origin_input_text
-            + self.prev_output_str
-            + cur_output_str
-            + jump_forward_str
-        )
+        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
         prompt_tokens = len(self.origin_input_ids_unpadded)
-
-
-
-
-
-
-
+
+        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
+            # TODO(lsyin): fix token fusion
+            warnings.warn(
+                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
+            )
+            return False
+
+        old_output_ids = self.output_ids
+        self.output_ids = all_ids[prompt_tokens:]
+        self.decoded_text = self.decoded_text + jump_forward_str
+        self.surr_offset = prompt_tokens
+        self.read_offset = len(all_ids)
+
+        # NOTE: A trick to reduce the surrouding tokens decoding overhead
+        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
+            surr_text_ = self.tokenizer.decode(
+                all_ids[self.read_offset - i : self.read_offset]
+            )
+            if not surr_text_.endswith("�"):
+                self.surr_offset = self.read_offset - i
+                break

         self.regex_fsm_state = next_state

         if self.return_logprob:
             # For fast-forward part's logprobs
             k = 0
-            for i, old_id in enumerate(
-                if old_id == self.
+            for i, old_id in enumerate(old_output_ids):
+                if old_id == self.output_ids[i]:
                     k = k + 1
                 else:
                     break
             self.decode_token_logprobs = self.decode_token_logprobs[:k]
             self.decode_top_logprobs = self.decode_top_logprobs[:k]
             self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.
-
-        # print("=" * 100)
-        # print(f"Catch jump forward:\n{jump_forward_str}")
-        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
+            self.last_update_decode_tokens = len(self.output_ids) - k

-
-        # print("*" * 100)
+        return True

     def __repr__(self):
         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
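The Req hunks above replace the old prev_output_str / prev_output_ids bookkeeping with incremental detokenization: a small window of trailing tokens (from surr_offset to read_offset) is re-decoded on every step, and new text is surfaced only once it extends past that window and no longer ends in a partial UTF-8 sequence. Below is a minimal standalone sketch of the same surrounding-offset idea against a HuggingFace tokenizer; the IncrementalDecoder class and the gpt2 checkpoint are illustrative only, not sglang code.

# Minimal sketch of surrounding-offset incremental detokenization.
# Assumes the `transformers` package is installed; names are illustrative.
from transformers import AutoTokenizer

SURR_OFFSET = 5  # mirrors INIT_INCREMENTAL_DETOKENIZATION_OFFSET


class IncrementalDecoder:
    def __init__(self, tokenizer, prompt_ids):
        self.tokenizer = tokenizer
        self.all_ids = list(prompt_ids)
        # read_offset: first token whose text has not been surfaced yet
        self.read_offset = len(self.all_ids)
        # surr_offset: a few earlier tokens re-decoded every step so that
        # multi-token characters and leading spaces come out correctly
        self.surr_offset = max(self.read_offset - SURR_OFFSET, 0)
        self.decoded_text = ""

    def step(self, token_id):
        self.all_ids.append(token_id)
        surr_text = self.tokenizer.decode(self.all_ids[self.surr_offset : self.read_offset])
        new_text = self.tokenizer.decode(self.all_ids[self.surr_offset :])
        # Surface text only when it grew past the re-decoded prefix and does
        # not end in a partial UTF-8 sequence (decoded as U+FFFD, "�").
        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            delta = new_text[len(surr_text) :]
            self.decoded_text += delta
            self.surr_offset = self.read_offset
            self.read_offset = len(self.all_ids)
            return delta
        return ""


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    dec = IncrementalDecoder(tok, tok.encode("Hello"))
    for tid in tok.encode(", wörld! 你好"):
        print(repr(dec.step(tid)))

The remaining infer_batch.py hunks below apply the same machinery on the Batch side (jump-forward decoding and request retraction).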
@@ -263,6 +303,10 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0

+    # whether batch has at least 1 streaming request
+    def has_stream(self) -> bool:
+        return any(r.stream for r in self.reqs)
+
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
         bs = len(self.reqs)
@@ -380,7 +424,10 @@ class Batch:
         sorted_indices = [i for i in range(len(self.reqs))]
         # TODO(lsyin): improve the priority of retraction
         sorted_indices.sort(
-            key=lambda i: (
+            key=lambda i: (
+                len(self.reqs[i].output_ids),
+                -len(self.reqs[i].origin_input_ids),
+            ),
             reverse=True,
         )

@@ -402,14 +449,9 @@ class Batch:
             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)

-            cur_output_str = req.partial_decode(req.output_ids)
-            req.prev_output_str = req.prev_output_str + cur_output_str
-            req.prev_output_ids.extend(req.output_ids)
-
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
-            req.output_ids = []

             # For incremental logprobs
             req.last_update_decode_tokens = 0
@@ -427,18 +469,54 @@ class Batch:

         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
-
-
-
-
+                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
+                    req.regex_fsm_state
+                )
+                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
+                    suffix_bytes = []
+                    continuation_range = range(0x80, 0xC0)
+                    cur_state = req.regex_fsm_state
+                    while (
+                        len(jump_forward_bytes)
+                        and jump_forward_bytes[0][0] in continuation_range
+                    ):
+                        # continuation bytes
+                        byte_edge = jump_forward_bytes.pop(0)
+                        suffix_bytes.append(byte_edge[0])
+                        cur_state = byte_edge[1]
+
+                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
+
+                    # Current ids, for cache and revert
+                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
+                    cur_output_ids = req.output_ids
+
+                    req.output_ids.extend(suffix_ids)
+                    decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                    if not decode_res:
+                        req.output_ids = cur_output_ids
                         continue

-
-
+                    (
+                        jump_forward_str,
+                        next_state,
+                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)
+
+                    # Make the incrementally decoded text part of jump_forward_str
+                    # so that the UTF-8 will not corrupt
+                    jump_forward_str = new_text + jump_forward_str
+                    if not req.jump_forward_and_retokenize(
+                        jump_forward_str, next_state
+                    ):
+                        req.output_ids = cur_output_ids
+                        continue

                     # insert the old request into tree_cache
+                    if req_pool_indices_cpu is None:
+                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                     self.tree_cache.cache_req(
-                        token_ids=
+                        token_ids=cur_all_ids,
                         last_uncached_pos=len(req.prefix_indices),
                         req_pool_idx=req_pool_indices_cpu[i],
                     )
@@ -446,9 +524,6 @@ class Batch:
                     # unlock the last node
                     self.tree_cache.dec_lock_ref(req.last_node)

-                    # jump-forward
-                    req.jump_forward_and_retokenize(jump_forward_str, next_state)
-
                     # re-applying image padding
                     if req.pixel_values is not None:
                         (
@@ -582,7 +657,7 @@ class Batch:
             if req.regex_fsm is not None:
                 allowed_mask.zero_()
                 allowed_mask[
-                    req.regex_fsm.
+                    req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                 ] = 1
                 logits[i].masked_fill_(~allowed_mask, float("-inf"))

@@ -601,7 +676,7 @@ class Batch:
         batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
         for i, req in enumerate(self.reqs):
             if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )

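In the jump-forward hunk above, the byte-level FSM walk can stop in the middle of a multi-byte UTF-8 character, so the new code first drains any pending continuation bytes (0x80–0xBF) and maps them to SentencePiece-style <0xXX> byte-fallback tokens before jumping at the symbol level. A small sketch of that byte classification and token spelling follows; the helper names are illustrative and assume a Llama/SentencePiece-style tokenizer that exposes such byte tokens.

# Sketch of the byte classification and byte-token spelling used by the
# jump-forward path above; helper names are illustrative, not sglang APIs.
def is_utf8_continuation(byte: int) -> bool:
    # Continuation bytes of a multi-byte UTF-8 character look like 0b10xxxxxx.
    return 0x80 <= byte < 0xC0


def byte_token(byte: int) -> str:
    # SentencePiece byte-fallback tokens are spelled "<0xAB>" with upper-case hex.
    return f"<0x{byte:02X}>"


if __name__ == "__main__":
    # "é" is 0xC3 0xA9 in UTF-8; only the second byte is a continuation byte.
    for b in "é".encode("utf-8"):
        print(hex(b), is_utf8_continuation(b), byte_token(b))

The diff builds the token string with f"<0x{hex(b)[2:].upper()}>", which yields the same two-digit spelling here because continuation bytes are always at least 0x80.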
--- sglang/srt/managers/controller/manager_multi.py (0.1.17)
+++ sglang/srt/managers/controller/manager_multi.py (0.1.19)
@@ -13,15 +13,15 @@ import zmq
 import zmq.asyncio

 from sglang.global_config import global_config
+from sglang.srt.managers.controller.dp_worker import (
+    DataParallelWorkerThread,
+    start_data_parallel_worker,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.controller.dp_worker import (
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
-)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import get_exception_traceback

@@ -136,7 +136,7 @@ class Controller:
             self.recv_reqs = []
             if next_step_input:
                 await self.dispatching(next_step_input)
-            #else:
+            # else:
             #     logger.error("There is no live worker.")

             await asyncio.sleep(global_config.wait_for_new_request_delay)
--- sglang/srt/managers/controller/manager_single.py (0.1.17)
+++ sglang/srt/managers/controller/manager_single.py (0.1.19)
@@ -1,7 +1,8 @@
 """A controller that manages a group of tensor parallel workers."""
+
 import asyncio
 import logging
-import
+from concurrent.futures import ThreadPoolExecutor

 import uvloop
 import zmq
@@ -49,7 +50,9 @@ class ControllerSingle:
             # async sleep for receiving the subsequent request and avoiding cache miss
             slept = False
             if len(out_pyobjs) != 0:
-                has_finished = any(
+                has_finished = any(
+                    [obj.finished_reason is not None for obj in out_pyobjs]
+                )
                 if has_finished:
                     if self.request_dependency_delay > 0:
                         slept = True
@@ -73,8 +76,9 @@ def start_controller_process(
     )

     try:
+        tp_size_local = server_args.tp_size // server_args.nnodes
         model_client = ModelTpClient(
-
+            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
             server_args,
             port_args.model_port_args[0],
             model_overide_args,
@@ -87,6 +91,7 @@ def start_controller_process(
     pipe_writer.send("init ok")

     loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
     asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
     try:
@@ -94,4 +99,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        kill_parent_process()
+        kill_parent_process()
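The manager_single.py hunks also install a 256-thread pool as the event loop's default executor, so blocking calls dispatched with loop.run_in_executor(None, ...) are not throttled by asyncio's much smaller default pool. A minimal sketch of the effect; the blocking_call helper is illustrative, not sglang code.

# Sketch: widening the event loop's default executor lets many blocking
# calls run concurrently when dispatched with run_in_executor(None, ...).
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor


def blocking_call(i: int) -> int:
    time.sleep(0.1)  # stands in for a blocking RPC to a worker
    return i


async def main() -> None:
    loop = asyncio.get_running_loop()
    # Passing None to run_in_executor reuses the loop's default executor,
    # so this replaces asyncio's small default thread pool with a wide one.
    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
    results = await asyncio.gather(
        *(loop.run_in_executor(None, blocking_call, i) for i in range(64))
    )
    print(len(results))  # 64, finished in roughly one sleep interval


asyncio.run(main())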