sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +133 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +32 -21
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +133 -30
- sglang/srt/managers/scheduler.py +273 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +27 -13
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +208 -77
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +124 -28
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +99 -9
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -32,16 +32,33 @@ import psutil
 import setproctitle
 import torch
 import zmq
+from torch.distributed import barrier
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
+from sglang.srt.disaggregation.decode import (
+    DecodePreallocQueue,
+    DecodeTransferQueue,
+    SchedulerDisaggregationDecodeMixin,
+)
+from sglang.srt.disaggregation.prefill import (
+    PrefillBootstrapQueue,
+    SchedulerDisaggregationPrefillMixin,
+)
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    ReqToMetadataIdxAllocator,
+)
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
+    ExpertDistributionReq,
+    ExpertDistributionReqOutput,
     FlushCacheReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -59,6 +76,8 @@ from sglang.srt.managers.io_struct import (
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
+    RpcReqInput,
+    RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
@@ -72,7 +91,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
-    …
+    MultimodalInputs,
     Req,
     ScheduleBatch,
     global_server_args_dict,
@@ -98,6 +117,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
@@ -111,6 +131,8 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
+expert_distribution_recorder = ExpertDistributionRecorder()
+
 logger = logging.getLogger(__name__)
 
 # Test retract decode for debugging purposes
@@ -133,7 +155,11 @@ class EmbeddingBatchResult:
     bid: int
 
 
-class Scheduler(SchedulerOutputProcessorMixin):
+class Scheduler(
+    SchedulerOutputProcessorMixin,
+    SchedulerDisaggregationDecodeMixin,
+    SchedulerDisaggregationPrefillMixin,
+):
     """A scheduler that manages a tensor parallel GPU worker."""
 
     def __init__(
@@ -193,8 +219,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name, False
             )
+
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
         else:
             self.recv_from_tokenizer = None
+            self.recv_from_rpc = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
@@ -376,9 +407,16 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
+                (RpcReqInput, self.handle_rpc_request),
+                (ExpertDistributionReq, self.expert_distribution_handle),
             ]
         )
 
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.init_disaggregation()
+
     def init_tokenizer(self):
         server_args = self.server_args
 
@@ -435,6 +473,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 tp_cache_group=self.tp_worker.get_tp_cpu_group(),
                 page_size=self.page_size,
+                hicache_ratio=server_args.hicache_ratio,
             )
         else:
             self.tree_cache = RadixCache(
@@ -478,7 +517,74 @@ class Scheduler(SchedulerOutputProcessorMixin):
             },
         )
 
-    …
+    def init_disaggregation(self):
+        if (
+            self.disaggregation_mode == DisaggregationMode.DECODE
+        ):  # *2 for the headroom.
+            buffer_size = (self.req_to_token_pool.size) * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            # The decode requests polling kv cache
+            self.disagg_decode_transfer_queue = DecodeTransferQueue(
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+            )
+
+            # The decode requests pending for pre-allocation
+            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                scheduler=self,
+                transfer_queue=self.disagg_decode_transfer_queue,
+                tree_cache=self.tree_cache,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+            )
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # *2 for the headroom.
+            buffer_size = self.max_running_requests * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
+                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+            )
+            # The prefill requests that are in the middle of kv sending
+            self.disagg_prefill_infight_queue: List[Req] = []
+
+    @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
         while True:
@@ -498,7 +604,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
             self.last_batch = batch
 
-    @…
+    @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         self.result_queue = deque()
@@ -538,6 +644,70 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
             self.last_batch = batch
 
+    @torch.no_grad()
+    def event_loop_normal_disagg_prefill(self):
+        """A normal scheduler loop for prefill worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result_disagg_prefill(batch, result)
+
+            if len(self.disagg_prefill_infight_queue) > 0:
+                self.process_disagg_prefill_infight_queue()
+
+            if batch is None and len(self.disagg_prefill_infight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
+    @torch.no_grad()
+    def event_loop_normal_disagg_decode(self):
+        """A normal scheduler loop for decode worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(
+                        batch.reqs, [False for _ in range(len(batch.reqs))]
+                    )
+                else:
+                    result = self.run_batch(batch)
+                    self.process_batch_result(batch, result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.attn_tp_rank == 0:
@@ -549,6 +719,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 except zmq.ZMQError:
                     break
                 recv_reqs.append(recv_req)
+
+            while True:
+                try:
+                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_rpc)
         else:
             recv_reqs = None
 
@@ -600,7 +777,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
             output = self._request_dispatcher(recv_req)
             if output is not None:
-                …
+                if isinstance(output, RpcReqOutput):
+                    if self.recv_from_rpc is not None:
+                        self.recv_from_rpc.send_pyobj(output)
+                else:
+                    self.send_to_tokenizer.send_pyobj(output)
 
     def handle_generate_request(
         self,
@@ -666,8 +847,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
             return
 
         # Handle multimodal inputs
-        if recv_req.…
-            image_inputs = …
+        if recv_req.mm_inputs is not None:
+            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -681,7 +862,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             )
             logger.error(error_msg)
            req.origin_input_ids = [0]
-            req.…
+            req.multimodal_inputs = None
             req.sampling_params.max_new_tokens = 0
             req.finished_reason = FINISH_ABORT(
                 error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -756,10 +937,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self._add_request_to_queue(req)
 
     def _add_request_to_queue(self, req: Req):
-        self.…
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_pending_queue.add(req)
 
-        …
-        …
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.add(req)
+
+        else:
+            self.waiting_queue.append(req)
+
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.extend(reqs)
+        else:
+            self.waiting_queue.extend(reqs)
 
     def handle_embedding_request(
         self,
@@ -775,7 +966,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
         # Handle multimodal inputs
         if recv_req.image_inputs is not None:
-            image_inputs = …
+            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -789,7 +980,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             )
             logger.error(error_msg)
             req.origin_input_ids = [0]
-            req.…
+            req.multimodal_inputs = None
             req.sampling_params.max_new_tokens = 0
             req.finished_reason = FINISH_ABORT(
                 error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -875,7 +1066,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
         spec_accept_length = 0
@@ -893,7 +1083,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"accept len: {spec_accept_length:.2f}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
 
@@ -1492,6 +1681,47 @@ class Scheduler(SchedulerOutputProcessorMixin):
             server_args=global_server_args_dict,
         )
 
+    def handle_rpc_request(self, recv_req: RpcReqInput):
+        # Handle RPC requests
+        logger.info(
+            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
+        )
+
+        success = True
+        exec = None
+        try:
+            func = getattr(self, recv_req.method)
+            func(recv_req.parameters)
+        except Exception as e:
+            success = False
+            exec = e
+            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
+
+        barrier()
+        return RpcReqOutput(success, "" if not exec else str(exec))
+
+    def save_remote_model(self, params):
+        url = params["url"]
+
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_remote_model(url)
+
+    def save_sharded_model(self, params):
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_sharded_model(
+            path=params["path"],
+            pattern=params["pattern"],
+            max_size=params["max_size"],
+        )
+
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []
@@ -1561,6 +1791,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return GetWeightsByNameReqOutput(parameter)
 
     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+        self.memory_saver_adapter.check_validity(
+            caller_name="release_memory_occupation"
+        )
         self.stashed_model_static_state = _export_static_state(
             self.tp_worker.worker.model_runner.model
         )
@@ -1569,6 +1802,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return ReleaseMemoryOccupationReqOutput()
 
     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
         self.memory_saver_adapter.resume()
         _import_static_state(
             self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
@@ -1668,6 +1902,17 @@ class Scheduler(SchedulerOutputProcessorMixin):
             ProfileReqOutput(success=True, message="Succeeded.")
         )
 
+    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
+        if recv_req == ExpertDistributionReq.START_RECORD:
+            expert_distribution_recorder.start_record()
+        elif recv_req == ExpertDistributionReq.STOP_RECORD:
+            expert_distribution_recorder.stop_record()
+        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+            expert_distribution_recorder.dump_record()
+        else:
+            raise ValueError("Unrecognized ExpertDistributionReq value")
+        return ExpertDistributionReqOutput()
+
     def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
         session_id = recv_req.session_id
@@ -1726,7 +1971,7 @@ def run_scheduler_process(
     prefix = f" DP{dp_rank} TP{tp_rank}"
 
     # Config the process
-    …
+    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()
@@ -1753,10 +1998,18 @@ def run_scheduler_process(
                 "max_req_input_len": scheduler.max_req_input_len,
             }
         )
-        …
-        …
-        …
-        scheduler.…
+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
+
+        if disaggregation_mode == DisaggregationMode.NULL:
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap()
+            else:
+                scheduler.event_loop_normal()
+        elif disaggregation_mode == DisaggregationMode.PREFILL:
+            scheduler.event_loop_normal_disagg_prefill()
+        elif disaggregation_mode == DisaggregationMode.DECODE:
+            scheduler.event_loop_normal_disagg_decode()
+
 
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"Scheduler hit an exception: {traceback}")
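
The hunks above add a third ZMQ channel to the scheduler: an RPC socket polled in recv_requests(), dispatched through handle_rpc_request(), and answered on the same socket in process_input_requests(). The sketch below illustrates that round trip from a client's point of view; the RpcReqInput/RpcReqOutput names, the DEALER socket type, and port_args.rpc_ipc_name come from this diff, while the RpcReqInput constructor call and the bind/connect arrangement are assumptions rather than sglang's actual client code.

```python
# Illustrative sketch only; not sglang's own RPC client.
import zmq

from sglang.srt.managers.io_struct import RpcReqInput, RpcReqOutput


def scheduler_rpc(rpc_ipc_name: str, method: str, parameters: dict) -> RpcReqOutput:
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.DEALER)
    sock.connect(rpc_ipc_name)  # the scheduler polls its own DEALER socket in recv_requests()
    # handle_rpc_request() looks up `method` on the Scheduler and calls it with `parameters`
    sock.send_pyobj(RpcReqInput(method=method, parameters=parameters))
    return sock.recv_pyobj()  # RpcReqOutput(success, message) sent back by process_input_requests()


# e.g. trigger the save_sharded_model handler added in this diff (argument values are made up):
# scheduler_rpc(port_args.rpc_ipc_name, "save_sharded_model",
#               {"path": "/tmp/ckpt", "pattern": None, "max_size": 5 * 2**30})
```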
sglang/srt/managers/session_controller.py
CHANGED
@@ -138,7 +138,7 @@ class Session:
                 token_ids_logprob=req.token_ids_logprob,
             )
             if last_req is not None:
-                new_req.…
+                new_req.multimodal_inputs = last_req.mm_inputs
             new_req.tokenizer = tokenizer
             if abort:
                 new_req.to_abort = True
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -16,7 +16,6 @@
 import asyncio
 import copy
 import dataclasses
-import json
 import logging
 import os
 import pickle
@@ -49,11 +48,9 @@ from fastapi import BackgroundTasks
 
 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.disaggregation.conn import KVBootstrapServer
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.image_processor import (
-    get_dummy_image_processor,
-    get_image_processor,
-)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -63,6 +60,8 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
+    ExpertDistributionReq,
+    ExpertDistributionReqOutput,
     FlushCacheReq,
     GenerateReqInput,
     GetInternalStateReq,
@@ -91,6 +90,11 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.multimodal_processor import (
+    get_dummy_processor,
+    get_mm_processor,
+    import_processors,
+)
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -168,27 +172,33 @@ class TokenizerManager:
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
 
-        …
-        …
+        if self.model_config.is_multimodal:
+            import_processors()
+            _processor = get_processor(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
+            )
 
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+            # We want to parallelize the image pre-processing so we create an executor for it
+            # We create mm_processor for any skip_tokenizer_init to make sure we still encode
+            # images even with skip_tokenizer_init=False.
+            self.mm_processor = get_mm_processor(
+                self.model_config.hf_config, server_args, _processor
+            )
+
+            if server_args.skip_tokenizer_init:
+                self.tokenizer = self.processor = None
+            else:
+                self.processor = _processor
                 self.tokenizer = self.processor.tokenizer
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        else:
+            self.mm_processor = get_dummy_processor()
 
-        …
-        self.…
-            self.model_config.hf_config, server_args, self.processor
-        )
+            if server_args.skip_tokenizer_init:
+                self.tokenizer = self.processor = None
         else:
             self.tokenizer = get_tokenizer(
                 server_args.tokenizer_path,
@@ -255,6 +265,9 @@ class TokenizerManager:
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.expert_distribution_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
 
         self._result_dispatcher = TypeBasedDispatcher(
             [
@@ -304,10 +317,24 @@ class TokenizerManager:
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
                 ),
+                (
+                    ExpertDistributionReqOutput,
+                    self.expert_distribution_communicator.handle_recv,
+                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
 
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        # for disaggregtion, start kv boostrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            self.bootstrap_server = KVBootstrapServer(
+                self.server_args.disaggregation_bootstrap_port
+            )
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -372,7 +399,7 @@ class TokenizerManager:
             )
         input_ids = self.tokenizer.encode(input_text)
 
-        image_inputs: Dict = await self.…
+        image_inputs: Dict = await self.mm_processor.process_mm_data_async(
             obj.image_data, input_text or input_ids, obj, self.max_req_input_len
         )
         if image_inputs and "input_ids" in image_inputs:
@@ -620,6 +647,15 @@ class TokenizerManager:
         req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
         self.send_to_scheduler.send_pyobj(req)
 
+    async def start_expert_distribution_record(self):
+        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
+
+    async def stop_expert_distribution_record(self):
+        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
+
+    async def dump_expert_distribution_record(self):
+        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
+
     async def update_weights_from_disk(
         self,
         obj: UpdateWeightFromDiskReqInput,
sglang/srt/managers/tp_worker.py
CHANGED
@@ -214,7 +214,7 @@ class TpModelWorker:
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         success, message = self.model_runner.update_weights_from_tensor(
             named_tensors=MultiprocessingSerializer.deserialize(
-                recv_req.serialized_named_tensors
+                recv_req.serialized_named_tensors[self.tp_rank]
             ),
             load_format=recv_req.load_format,
         )
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -69,7 +69,7 @@ class TpModelWorkerClient:
         self.future_token_ids_ct = 0
         self.future_token_ids_limit = self.max_running_requests * 3
         self.future_token_ids_map = torch.empty(
-            (self.max_running_requests * 5,), dtype=torch.…
+            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
         )
 
         # Launch threads
@@ -115,7 +115,7 @@ class TpModelWorkerClient:
         logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
         self.parent_process.send_signal(signal.SIGQUIT)
 
-    @…
+    @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
         batch_lists = [None] * 2
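
Both the scheduler and TpModelWorkerClient hot loops switch from @torch.no_grad() to @DynamicGradMode(). The real class lives in sglang/srt/utils.py (changed +182/-21 in this release) and is not reproduced here; the sketch below only illustrates the kind of decorator this appears to be, a process-wide switch between no_grad and inference_mode, and should be read as a hypothetical stand-in rather than sglang's implementation.

```python
# Hypothetical stand-in for the idea behind DynamicGradMode; not the sglang class.
import functools

import torch


class DynamicGradModeSketch:
    use_inference_mode = False  # flipped once at startup, e.g. from server args

    def __call__(self, fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            ctx = torch.inference_mode() if type(self).use_inference_mode else torch.no_grad()
            with ctx:
                return fn(*args, **kwargs)
        return wrapper


@DynamicGradModeSketch()
def decode_step(model, batch):
    return model(batch)  # runs under whichever grad mode is currently selected
```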
|