sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +13 -6
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +2 -4
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +40 -35
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +110 -74
  31. sglang/srt/managers/tokenizer_manager.py +24 -15
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +60 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +118 -141
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +6 -8
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +8 -43
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/{llama2.py → llama.py} +48 -26
  49. sglang/srt/models/llama_classification.py +14 -40
  50. sglang/srt/models/llama_embedding.py +7 -6
  51. sglang/srt/models/llava.py +38 -16
  52. sglang/srt/models/llavavid.py +7 -8
  53. sglang/srt/models/minicpm.py +1 -5
  54. sglang/srt/models/minicpm3.py +665 -0
  55. sglang/srt/models/mistral.py +2 -3
  56. sglang/srt/models/mixtral.py +6 -5
  57. sglang/srt/models/mixtral_quant.py +1 -5
  58. sglang/srt/models/qwen.py +1 -5
  59. sglang/srt/models/qwen2.py +1 -5
  60. sglang/srt/models/qwen2_moe.py +6 -5
  61. sglang/srt/models/stablelm.py +1 -5
  62. sglang/srt/models/xverse.py +375 -0
  63. sglang/srt/models/xverse_moe.py +445 -0
  64. sglang/srt/openai_api/adapter.py +65 -46
  65. sglang/srt/openai_api/protocol.py +11 -3
  66. sglang/srt/sampling/sampling_batch_info.py +67 -58
  67. sglang/srt/server.py +24 -14
  68. sglang/srt/server_args.py +130 -28
  69. sglang/srt/utils.py +12 -0
  70. sglang/test/few_shot_gsm8k.py +132 -0
  71. sglang/test/runners.py +114 -22
  72. sglang/test/test_programs.py +70 -0
  73. sglang/test/test_utils.py +89 -1
  74. sglang/utils.py +38 -4
  75. sglang/version.py +1 -1
  76. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
  77. sglang-0.3.1.dist-info/RECORD +129 -0
  78. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  79. sglang-0.2.15.dist-info/RECORD +0 -118
  80. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -19,7 +19,7 @@ limitations under the License.
 
 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 
@@ -29,20 +29,19 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-
-if TYPE_CHECKING:
-    from sglang.srt.layers.sampler import SampleOutput
-
+from sglang.srt.server_args import ServerArgs
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
 global_server_args_dict = {
-    "disable_flashinfer": False,
-    "disable_flashinfer_sampling": False,
-    "triton_attention_reduce_in_fp32": False,
-    "enable_mla": False,
+    "attention_backend": ServerArgs.attention_backend,
+    "sampling_backend": ServerArgs.sampling_backend,
+    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
+    "enable_mla": ServerArgs.enable_mla,
+    "torchao_config": ServerArgs.torchao_config,
 }
 
 
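The dictionary above no longer hard-codes its defaults; each value is read from the corresponding ServerArgs class attribute, so the module-level dict and the server's CLI defaults cannot drift apart. A minimal sketch of that pattern, using a hypothetical ExampleArgs dataclass rather than the real ServerArgs (whose default values may differ):

from dataclasses import dataclass


@dataclass
class ExampleArgs:  # hypothetical stand-in for ServerArgs
    attention_backend: str = "flashinfer"
    sampling_backend: str = "flashinfer"
    enable_mla: bool = False


# Dataclass field defaults are plain class attributes, so a module-level dict
# can mirror them without ever instantiating the args object.
example_defaults = {
    "attention_backend": ExampleArgs.attention_backend,
    "sampling_backend": ExampleArgs.sampling_backend,
    "enable_mla": ExampleArgs.enable_mla,
}
assert example_defaults["attention_backend"] == "flashinfer"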
@@ -53,8 +52,8 @@ class BaseFinishReason:
     def __init__(self, is_error: bool = False):
         self.is_error = is_error
 
-    def __str__(self):
-        raise NotImplementedError("Subclasses must implement this method")
+    def to_json(self):
+        raise NotImplementedError()
 
 
 class FINISH_MATCHED_TOKEN(BaseFinishReason):
@@ -62,40 +61,57 @@ class FINISH_MATCHED_TOKEN(BaseFinishReason):
         super().__init__()
         self.matched = matched
 
-    def __str__(self) -> str:
-        return f"FINISH_MATCHED_TOKEN: {self.matched}"
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }
 
 
-class FINISH_LENGTH(BaseFinishReason):
-    def __init__(self, length: int):
+class FINISH_MATCHED_STR(BaseFinishReason):
+    def __init__(self, matched: str):
         super().__init__()
-        self.length = length
+        self.matched = matched
 
-    def __str__(self) -> str:
-        return f"FINISH_LENGTH: {self.length}"
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }
 
 
-class FINISH_MATCHED_STR(BaseFinishReason):
-    def __init__(self, matched: str):
+class FINISH_LENGTH(BaseFinishReason):
+    def __init__(self, length: int):
         super().__init__()
-        self.matched = matched
+        self.length = length
 
-    def __str__(self) -> str:
-        return f"FINISH_MATCHED_STR: {self.matched}"
+    def to_json(self):
+        return {
+            "type": "length",  # to match OpenAI API's return value
+            "length": self.length,
+        }
 
 
 class FINISH_ABORT(BaseFinishReason):
     def __init__(self):
         super().__init__(is_error=True)
 
-    def __str__(self) -> str:
-        return "FINISH_ABORT"
+    def to_json(self):
+        return {
+            "type": "abort",
+        }
 
 
 class Req:
     """Store all inforamtion of a request."""
 
-    def __init__(self, rid, origin_input_text, origin_input_ids):
+    def __init__(
+        self,
+        rid: str,
+        origin_input_text: str,
+        origin_input_ids: Tuple[int],
+        lora_path: Optional[str] = None,
+    ):
         # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
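The finish-reason classes now serialize themselves through to_json() instead of __str__, and the "type" key is chosen to line up with the OpenAI API's finish_reason values. A small usage sketch; the helper below is hypothetical and not part of the package:

def finish_reason_to_openai(finish_reason):
    """Hypothetical helper: collapse a finish reason into an OpenAI-style string."""
    if finish_reason is None:
        return None  # request is still running
    return finish_reason.to_json()["type"]  # "stop", "length", or "abort"

# From the classes above:
#   FINISH_MATCHED_STR("</s>").to_json() -> {"type": "stop", "matched": "</s>"}
#   FINISH_LENGTH(256).to_json()         -> {"type": "length", "length": 256}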
@@ -103,10 +119,15 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        self.lora_path = lora_path
 
         # Memory info
         self.req_pool_idx = None
 
+        # Check finish
+        self.tokenizer = None
+        self.finished_reason = None
+
         # For incremental decoding
         # ----- | --------- read_ids -------|
         # ----- | surr_ids |
@@ -125,38 +146,43 @@
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
-        # For vision input
+        # For vision inputs
         self.pixel_values = None
         self.image_sizes = None
         self.image_offsets = None
         self.pad_value = None
+        self.modalities = None
 
         # Prefix info
-        self.extend_input_len = 0
         self.prefix_indices = []
+        self.extend_input_len = 0
         self.last_node = None
 
         # Sampling parameters
         self.sampling_params = None
         self.stream = False
 
-        # Check finish
-        self.tokenizer = None
-        self.finished_reason = None
-
-        # Logprobs
+        # Logprobs (arguments)
         self.return_logprob = False
-        self.embedding = None
         self.logprob_start_len = 0
         self.top_logprobs_num = 0
+
+        # Logprobs (return value)
         self.normalized_prompt_logprob = None
         self.input_token_logprobs = None
         self.input_top_logprobs = None
         self.output_token_logprobs = []
         self.output_top_logprobs = []
+
+        # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens
         # and should be updated for the decode logprobs
         self.last_update_decode_tokens = 0
+        # The relative logprob_start_len in an extend batch
+        self.extend_logprob_start_len = 0
+
+        # Embedding
+        self.embedding = None
 
         # Constrained decoding
         self.regex_fsm: RegexGuide = None
@@ -178,19 +204,22 @@
     def adjust_max_prefix_ids(self):
         self.fill_ids = self.origin_input_ids + self.output_ids
         input_len = len(self.fill_ids)
-        max_prefix_len = input_len
+
+        # FIXME: To work around some bugs in logprob computation, we need to ensure each
+        # request has at least one token. Later, we can relax this requirement and use `input_len`.
+        max_prefix_len = input_len - 1
 
         if self.sampling_params.max_new_tokens > 0:
             # Need at least one token to compute logits
             max_prefix_len = min(max_prefix_len, input_len - 1)
 
         if self.return_logprob:
-            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
             if self.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 max_prefix_len = min(max_prefix_len, input_len - 2)
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
 
+        max_prefix_len = max(max_prefix_len, 0)
         return self.fill_ids[:max_prefix_len]
 
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
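The reworked adjust_max_prefix_ids always leaves at least one token outside the prefix, applies the logprob_start_len cap only inside the return_logprob branch, and clamps the result at zero. A standalone paraphrase of the new clamping logic, operating on plain integers instead of a Req object (illustrative only):

def max_prefix_len(input_len, max_new_tokens, return_logprob,
                   logprob_start_len, normalized_logprob_done):
    n = input_len - 1                  # keep at least one token out of the prefix
    if max_new_tokens > 0:
        n = min(n, input_len - 1)      # need one token to compute logits
    if return_logprob:
        if not normalized_logprob_done:
            n = min(n, input_len - 2)  # need two tokens for the normalized logprob
        n = min(n, logprob_start_len)
    return max(n, 0)

assert max_prefix_len(8, 16, False, 0, False) == 7
assert max_prefix_len(8, 16, True, 0, False) == 0   # logprobs from position 0: no prefix reuse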
@@ -330,6 +359,8 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool
     tree_cache: BasePrefixCache
 
+    forward_mode: ForwardMode = None
+
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
@@ -340,14 +371,19 @@
 
     # For mixed chunekd prefill
     prefix_lens_cpu: List[int] = None
+    running_bs: int = None
 
     # For processing logprobs
     return_logprob: bool = False
    top_logprobs_nums: List[int] = None
 
+    # Stream
+    has_stream: bool = False
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
+        has_stream = any(req.stream for req in reqs)
 
         return cls(
             reqs=reqs,
@@ -355,18 +391,15 @@
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             return_logprob=return_logprob,
+            has_stream=has_stream,
         )
 
     def batch_size(self):
-        return len(self.reqs) if self.reqs is not None else 0
+        return len(self.reqs)
 
     def is_empty(self):
         return len(self.reqs) == 0
 
-    def has_stream(self) -> bool:
-        # Return whether batch has at least 1 streaming request
-        return any(r.stream for r in self.reqs)
-
     def alloc_req_slots(self, num_reqs):
         req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
@@ -393,6 +426,8 @@
         return out_cache_loc
 
     def prepare_for_extend(self, vocab_size: int):
+        self.forward_mode = ForwardMode.EXTEND
+
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -407,8 +442,8 @@
         for i, req in enumerate(reqs):
             req.req_pool_idx = req_pool_indices_cpu[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
-            ext_len = seq_len - pre_len
             seq_lens.append(seq_len)
+            assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
                 self.req_to_token_pool.req_to_token[req.req_pool_idx][
@@ -416,9 +451,19 @@
                 ] = req.prefix_indices
 
             self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
-                out_cache_loc[pt : pt + ext_len]
+                out_cache_loc[pt : pt + req.extend_input_len]
             )
-            pt += ext_len
+
+            # Compute the relative logprob_start_len in an extend batch
+            if req.logprob_start_len >= pre_len:
+                extend_logprob_start_len = min(
+                    req.logprob_start_len - pre_len, req.extend_input_len - 1
+                )
+            else:
+                extend_logprob_start_len = req.extend_input_len - 1
+
+            req.extend_logprob_start_len = extend_logprob_start_len
+            pt += req.extend_input_len
 
         # Set fields
         with torch.device("cuda"):
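prepare_for_extend now also records, per request, where logprob extraction starts relative to the tokens being prefilled in this batch. A paraphrase of the branch above with illustrative numbers:

def relative_logprob_start(logprob_start_len, pre_len, extend_input_len):
    # Mirrors the branch added above (illustrative paraphrase, not the sglang source).
    if logprob_start_len >= pre_len:
        return min(logprob_start_len - pre_len, extend_input_len - 1)
    return extend_input_len - 1

# 10 cached prefix tokens, 6 tokens prefilled now:
assert relative_logprob_start(12, pre_len=10, extend_input_len=6) == 2  # start falls inside the extend segment
assert relative_logprob_start(4, pre_len=10, extend_input_len=6) == 5   # start inside the cached prefix -> last prefilled position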
@@ -431,18 +476,13 @@
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
         self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
-
+        self.extend_lens_cpu = [r.extend_input_len for r in reqs]
+        self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
-        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
-        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
-        prefix_lens_cpu.extend(
-            [
-                len(r.origin_input_ids) + len(r.output_ids) - 1
-                for r in running_batch.reqs
-            ]
-        )
+        self.forward_mode = ForwardMode.MIXED
+        running_bs = running_batch.batch_size()
 
         for req in running_batch.reqs:
             req.fill_ids = req.origin_input_ids + req.output_ids
@@ -450,12 +490,22 @@
 
         input_ids = torch.cat([self.input_ids, running_batch.input_ids])
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
-        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+        extend_num_tokens = self.extend_num_tokens + running_bs
+
         self.merge(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
         self.extend_num_tokens = extend_num_tokens
-        self.prefix_lens_cpu = prefix_lens_cpu
+
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        self.prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
+        self.extend_lens_cpu.extend([1] * running_bs)
+        self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
 
     def check_decode_mem(self):
         bs = self.batch_size()
@@ -622,6 +672,8 @@
         return jump_forward_reqs
 
     def prepare_for_decode(self, input_ids=None):
+        self.forward_mode = ForwardMode.DECODE
+
         if input_ids is None:
             input_ids = [
                 r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
@@ -641,8 +693,6 @@
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc
 
-        self.sampling_info.update_regex_vocab_mask(self)
-
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests
@@ -662,6 +712,7 @@
         self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        self.has_stream = any(req.stream for req in self.reqs)
 
         self.sampling_info.filter(unfinished_indices, new_indices)
 
@@ -672,7 +723,6 @@
         self.sampling_info.merge(other.sampling_info)
 
         self.reqs.extend(other.reqs)
-
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
@@ -683,18 +733,4 @@
         self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-    def check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
-
-        return sample_output.batch_next_token_ids
+        self.has_stream = any(req.stream for req in self.reqs)

sglang/srt/managers/tokenizer_manager.py

@@ -18,6 +18,7 @@ limitations under the License.
 import asyncio
 import concurrent.futures
 import dataclasses
+import json
 import logging
 import multiprocessing as mp
 import os
@@ -77,7 +78,6 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args: dict = None,
     ):
         self.server_args = server_args
 
@@ -86,8 +86,8 @@
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
 
-        self.send_to_router = context.socket(zmq.PUSH)
-        self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
+        self.send_to_controller = context.socket(zmq.PUSH)
+        self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
 
         # Read model args
         self.model_path = server_args.model_path
@@ -95,7 +95,7 @@
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_override_args=model_override_args,
+            model_override_args=json.loads(server_args.json_model_override_args),
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
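Model overrides now travel as a JSON string on ServerArgs (json_model_override_args) and are parsed at this point, instead of being threaded through the constructor as a separate dict argument. A minimal sketch of the parsing step; the override keys shown are illustrative and not taken from the diff:

import json

json_model_override_args = '{"rope_scaling": {"type": "dynamic", "factor": 2.0}}'
model_override_args = json.loads(json_model_override_args)
assert model_override_args["rope_scaling"]["factor"] == 2.0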
@@ -188,6 +188,7 @@
             pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
                 obj.image_data if not_use_index else obj.image_data[index]
             )
+            modalities = obj.modalities
             return_logprob = (
                 obj.return_logprob if not_use_index else obj.return_logprob[index]
             )
@@ -196,8 +197,6 @@
                 if not_use_index
                 else obj.logprob_start_len[index]
             )
-            if return_logprob and logprob_start_len == -1:
-                logprob_start_len = len(input_ids) - 1
             top_logprobs_num = (
                 obj.top_logprobs_num
                 if not_use_index
@@ -243,14 +242,13 @@
             pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
                 obj.image_data[0]
             )
+            modalities = obj.modalities
             return_logprob = obj.return_logprob[0]
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]
 
         # Send to the controller
         if self.is_generation:
-            if return_logprob and logprob_start_len == -1:
-                logprob_start_len = len(input_ids) - 1
             tokenized_obj = TokenizedGenerateReqInput(
                 rid,
                 input_text,
@@ -263,6 +261,12 @@
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
+                modalities,
+                (
+                    obj.lora_path[index]
+                    if isinstance(obj.lora_path, list)
+                    else obj.lora_path
+                ),
             )
         else:  # is embedding
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -271,7 +275,7 @@
                 input_ids,
                 sampling_params,
             )
-        self.send_to_router.send_pyobj(tokenized_obj)
+        self.send_to_controller.send_pyobj(tokenized_obj)
 
         # Recv results
         event = asyncio.Event()
@@ -341,11 +345,10 @@
             sampling_params = self._get_sampling_params(obj.sampling_params[index])
 
             if self.is_generation:
-                if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
-                    obj.logprob_start_len[index] = len(input_ids) - 1
                 pixel_values, image_hashes, image_sizes = (
                     await self._get_pixel_values(obj.image_data[index])
                 )
+                modalities = obj.modalities
 
                 tokenized_obj = TokenizedGenerateReqInput(
                     rid,
@@ -359,6 +362,12 @@
                     obj.logprob_start_len[index],
                     obj.top_logprobs_num[index],
                     obj.stream,
+                    modalities,
+                    (
+                        obj.lora_path[index]
+                        if isinstance(obj.lora_path, list)
+                        else obj.lora_path
+                    ),
                 )
             else:
                 tokenized_obj = TokenizedEmbeddingReqInput(
@@ -367,7 +376,7 @@
                     input_ids,
                     sampling_params,
                 )
-            self.send_to_router.send_pyobj(tokenized_obj)
+            self.send_to_controller.send_pyobj(tokenized_obj)
 
             event = asyncio.Event()
             state = ReqState([], False, event)
@@ -500,14 +509,14 @@
 
     def flush_cache(self):
         req = FlushCacheReq()
-        self.send_to_router.send_pyobj(req)
+        self.send_to_controller.send_pyobj(req)
 
     def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
-        self.send_to_router.send_pyobj(req)
+        self.send_to_controller.send_pyobj(req)
 
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
@@ -524,7 +533,7 @@
         # wait for the previous generation requests to finish
         while len(self.rid_to_state) > 0:
             await asyncio.sleep(0)
-        self.send_to_router.send_pyobj(obj)
+        self.send_to_controller.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         result = await self.model_update_result
         if result.success: