sglang 0.2.14__py3-none-any.whl → 0.2.14.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. sglang/launch_server_llavavid.py +26 -0
  2. sglang/srt/constrained/fsm_cache.py +11 -2
  3. sglang/srt/constrained/jump_forward.py +1 -0
  4. sglang/srt/hf_transformers_utils.py +0 -149
  5. sglang/srt/layers/activation.py +93 -11
  6. sglang/srt/layers/layernorm.py +47 -4
  7. sglang/srt/layers/logits_processor.py +4 -4
  8. sglang/srt/layers/sampler.py +15 -68
  9. sglang/srt/managers/io_struct.py +5 -4
  10. sglang/srt/managers/schedule_batch.py +20 -25
  11. sglang/srt/managers/tokenizer_manager.py +74 -61
  12. sglang/srt/managers/tp_worker.py +49 -43
  13. sglang/srt/model_executor/cuda_graph_runner.py +17 -31
  14. sglang/srt/model_executor/forward_batch_info.py +9 -26
  15. sglang/srt/model_executor/model_runner.py +20 -17
  16. sglang/srt/models/chatglm.py +13 -5
  17. sglang/srt/models/commandr.py +1 -5
  18. sglang/srt/models/dbrx.py +1 -5
  19. sglang/srt/models/deepseek.py +1 -5
  20. sglang/srt/models/deepseek_v2.py +1 -5
  21. sglang/srt/models/gemma.py +3 -7
  22. sglang/srt/models/gemma2.py +2 -56
  23. sglang/srt/models/gpt_bigcode.py +2 -6
  24. sglang/srt/models/grok.py +10 -8
  25. sglang/srt/models/internlm2.py +1 -5
  26. sglang/srt/models/llama2.py +6 -11
  27. sglang/srt/models/llama_classification.py +2 -6
  28. sglang/srt/models/llama_embedding.py +3 -4
  29. sglang/srt/models/llava.py +69 -91
  30. sglang/srt/models/llavavid.py +40 -86
  31. sglang/srt/models/minicpm.py +1 -5
  32. sglang/srt/models/mixtral.py +1 -5
  33. sglang/srt/models/mixtral_quant.py +1 -5
  34. sglang/srt/models/qwen.py +2 -5
  35. sglang/srt/models/qwen2.py +5 -10
  36. sglang/srt/models/qwen2_moe.py +21 -24
  37. sglang/srt/models/stablelm.py +1 -5
  38. sglang/srt/models/yivl.py +2 -7
  39. sglang/srt/openai_api/adapter.py +85 -4
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +1 -74
  42. sglang/srt/sampling/sampling_params.py +4 -0
  43. sglang/srt/server.py +11 -4
  44. sglang/srt/utils.py +18 -33
  45. sglang/test/runners.py +2 -2
  46. sglang/test/test_layernorm.py +53 -1
  47. sglang/version.py +1 -1
  48. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/METADATA +11 -5
  49. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/RECORD +52 -51
  50. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/WHEEL +1 -1
  51. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/LICENSE +0 -0
  52. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,3 @@
- from __future__ import annotations
-
  """
  Copyright 2023-2024 SGLang Team
  Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,7 +17,7 @@ limitations under the License.

  import logging
  from dataclasses import dataclass
- from typing import TYPE_CHECKING, List, Optional, Union
+ from typing import List, Optional, Union

  import torch

@@ -31,10 +29,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

- if TYPE_CHECKING:
- from sglang.srt.layers.sampler import SampleOutput
-
-
  INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

  # Put some global args for easy access
@@ -127,8 +121,8 @@ class Req:

  # For vision input
  self.pixel_values = None
- self.image_size = None
- self.image_offset = None
+ self.image_sizes = None
+ self.image_offsets = None
  self.pad_value = None

  # Prefix info
@@ -268,7 +262,14 @@ class Req:

  all_text = self.origin_input_text + self.decoded_text + jump_forward_str
  all_ids = self.tokenizer.encode(all_text)
+ if not all_ids:
+ logger.warning("Encoded all_text resulted in empty all_ids")
+ return False
+
  prompt_tokens = len(self.origin_input_ids_unpadded)
+ if prompt_tokens > len(all_ids):
+ logger.warning("prompt_tokens is larger than encoded all_ids")
+ return False

  if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
  # TODO(lsyin): fix token fusion
@@ -599,12 +600,12 @@ class ScheduleBatch:
  if req.pixel_values is not None:
  (
  req.origin_input_ids,
- req.image_offset,
+ req.image_offsets,
  ) = model_runner.model.pad_input_ids(
  req.origin_input_ids_unpadded,
  req.pad_value,
- req.pixel_values.shape,
- req.image_size,
+ req.pixel_values,
+ req.image_sizes,
  )

  jump_forward_reqs.append(req)
@@ -677,17 +678,11 @@ class ScheduleBatch:
  self.top_logprobs_nums.extend(other.top_logprobs_nums)
  self.return_logprob = any(req.return_logprob for req in self.reqs)

- def check_sample_results(self, sample_output: SampleOutput):
- if not torch.all(sample_output.success):
- probs = sample_output.probs
- batch_next_token_ids = sample_output.batch_next_token_ids
- logging.warning("Sampling failed, fallback to top_k=1 strategy")
- probs = probs.masked_fill(torch.isnan(probs), 0.0)
- argmax_ids = torch.argmax(probs, dim=-1)
- batch_next_token_ids = torch.where(
- sample_output.success, batch_next_token_ids, argmax_ids
- )
- sample_output.probs = probs
- sample_output.batch_next_token_ids = batch_next_token_ids
+ def sample(self, logits: torch.Tensor):
+ from sglang.srt.layers.sampler import Sampler
+
+ sampler = Sampler()
+
+ batch_next_token_ids = sampler(logits, self.sampling_info)

- return sample_output.batch_next_token_ids
+ return batch_next_token_ids
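Editor's note: the sampling fallback that check_sample_results applied (replace NaN probabilities, then fall back to the argmax id wherever sampling failed) is deleted here and presumably now lives inside sglang.srt.layers.sampler.Sampler, whose body is not shown in this diff; the TP worker simply calls batch.sample(output.next_token_logits) later in this release. A standalone torch restatement of the removed fallback, for orientation only (function and argument names are illustrative, not sglang's API):

    import torch

    def sample_with_fallback(probs: torch.Tensor, sampled: torch.Tensor, success: torch.Tensor) -> torch.Tensor:
        # Replace NaNs so argmax is well defined, then keep the sampled id
        # where sampling succeeded and the argmax id where it failed.
        probs = probs.masked_fill(torch.isnan(probs), 0.0)
        argmax_ids = torch.argmax(probs, dim=-1)
        return torch.where(success, sampled, argmax_ids)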
@@ -23,6 +23,7 @@ import multiprocessing as mp
  import os
  from typing import Dict, List, Optional, Tuple, Union

+ import fastapi
  import numpy as np
  import transformers
  import uvloop
@@ -96,21 +97,18 @@ class TokenizerManager:
  trust_remote_code=server_args.trust_remote_code,
  model_overide_args=model_overide_args,
  )
-
  self.is_generation = is_generation_model(
  self.hf_config.architectures, self.server_args.is_embedding
  )
-
- if server_args.context_length is not None:
- self.context_len = server_args.context_length
- else:
- self.context_len = get_context_length(self.hf_config)
+ self.context_len = server_args.context_length or get_context_length(
+ self.hf_config
+ )

  # Create tokenizer
  if server_args.skip_tokenizer_init:
  self.tokenizer = self.processor = None
  else:
- if is_multimodal_model(self.model_path):
+ if is_multimodal_model(self.hf_config.architectures):
  self.processor = get_processor(
  server_args.tokenizer_path,
  tokenizer_mode=server_args.tokenizer_mode,
@@ -118,6 +116,9 @@ class TokenizerManager:
  )
  self.tokenizer = self.processor.tokenizer
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ # We want to parallelize the image pre-processing so we
+ # create an executor for it
  self.executor = concurrent.futures.ProcessPoolExecutor(
  initializer=init_global_processor,
  mp_context=mp.get_context("fork"),
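Editor's note: the executor commented above offloads CPU-heavy image pre-processing to a process pool and awaits it with run_in_executor, the same pattern _process_single_image uses later in this diff. A minimal self-contained sketch of that pattern (the helper names here are illustrative, not sglang's):

    import asyncio
    import concurrent.futures

    def cpu_heavy_preprocess(data: bytes) -> int:
        # Stand-in for image decoding / resizing work done in a worker process.
        return len(data)

    async def preprocess(executor: concurrent.futures.Executor, data: bytes) -> int:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(executor, cpu_heavy_preprocess, data)

    if __name__ == "__main__":
        with concurrent.futures.ProcessPoolExecutor() as executor:
            print(asyncio.run(preprocess(executor, b"fake-image-bytes")))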
@@ -134,12 +135,14 @@ class TokenizerManager:
  self.to_create_loop = True
  self.rid_to_state: Dict[str, ReqState] = {}

- # for update model weights
+ # For update model weights
  self.model_update_lock = asyncio.Lock()
  self.model_update_result = None

  async def generate_request(
- self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
+ self,
+ obj: Union[GenerateReqInput, EmbeddingReqInput],
+ request: Optional[fastapi.Request] = None,
  ):
  if self.to_create_loop:
  self.create_handle_loop()
@@ -160,7 +163,7 @@ class TokenizerManager:
  async def _handle_single_request(
  self,
  obj: Union[GenerateReqInput, EmbeddingReqInput],
- request,
+ request: Optional[fastapi.Request] = None,
  index: Optional[int] = None,
  is_cache_for_prefill: Optional[bool] = False,
  ):
@@ -182,8 +185,8 @@ class TokenizerManager:
  )

  if self.is_generation:
- pixel_values, image_hash, image_size = await self._get_pixel_values(
- obj.image_data
+ pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
+ obj.image_data if not_use_index else obj.image_data[index]
  )
  return_logprob = (
  obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -195,7 +198,6 @@ class TokenizerManager:
  )
  if return_logprob and logprob_start_len == -1:
  logprob_start_len = len(input_ids) - 1
-
  top_logprobs_num = (
  obj.top_logprobs_num
  if not_use_index
@@ -238,13 +240,14 @@ class TokenizerManager:

  sampling_params = SamplingParams(**obj.sampling_params[0])
  sampling_params.max_new_tokens = 0
- pixel_values, image_hash, image_size = await self._get_pixel_values(
+ pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
  obj.image_data[0]
  )
  return_logprob = obj.return_logprob[0]
  logprob_start_len = obj.logprob_start_len[0]
  top_logprobs_num = obj.top_logprobs_num[0]

+ # Send to the controller
  if self.is_generation:
  if return_logprob and logprob_start_len == -1:
  logprob_start_len = len(input_ids) - 1
@@ -253,8 +256,8 @@ class TokenizerManager:
  input_text,
  input_ids,
  pixel_values,
- image_hash,
- image_size,
+ image_hashes,
+ image_sizes,
  sampling_params,
  return_logprob,
  logprob_start_len,
@@ -268,24 +271,24 @@ class TokenizerManager:
  input_ids,
  sampling_params,
  )
-
  self.send_to_router.send_pyobj(tokenized_obj)

+ # Recv results
  event = asyncio.Event()
  state = ReqState([], False, event)
  self.rid_to_state[rid] = state
  if not is_cache_for_prefill:
- async for response in self._wait_for_response(
- event, state, obj, rid, request
- ):
+ async for response in self._wait_for_response(state, obj, rid, request):
  yield response
  else:
  assert self.is_generation
- await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
+ await self._wait_for_cache_prefill_response(state, obj, rid, request)
  yield input_ids

  async def _handle_batch_request(
- self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
+ self,
+ obj: Union[GenerateReqInput, EmbeddingReqInput],
+ request: Optional[fastapi.Request] = None,
  ):
  batch_size = obj.batch_size
  if self.is_generation:
@@ -340,8 +343,8 @@ class TokenizerManager:
  if self.is_generation:
  if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
  obj.logprob_start_len[index] = len(input_ids) - 1
- pixel_values, image_hash, image_size = await self._get_pixel_values(
- obj.image_data[index]
+ pixel_values, image_hashes, image_sizes = (
+ await self._get_pixel_values(obj.image_data[index])
  )

  tokenized_obj = TokenizedGenerateReqInput(
@@ -349,8 +352,8 @@ class TokenizerManager:
  input_text,
  input_ids,
  pixel_values,
- image_hash,
- image_size,
+ image_hashes,
+ image_sizes,
  sampling_params,
  obj.return_logprob[index],
  obj.logprob_start_len[index],
@@ -372,7 +375,6 @@ class TokenizerManager:

  generators.append(
  self._wait_for_response(
- event,
  state,
  obj,
  rid,
@@ -388,6 +390,7 @@ class TokenizerManager:
  tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
  output_list = [None] * len(tasks)

+ # Recv results
  while tasks:
  done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

@@ -426,25 +429,18 @@ class TokenizerManager:
  sampling_params.verify()
  return sampling_params

- async def _get_pixel_values(self, image_data):
- if image_data is None:
- return None, None, None
- else:
- return await self._get_pixel_values_internal(image_data)
-
  async def _wait_for_response(
  self,
- event: asyncio.Event,
  state: ReqState,
  obj: Union[GenerateReqInput, EmbeddingReqInput],
  rid: str,
- request,
- index: int = None,
+ request: Optional[fastapi.Request] = None,
+ index: Optional[int] = None,
  response_index: int = 0,
  ):
  while True:
  try:
- await asyncio.wait_for(event.wait(), timeout=4)
+ await asyncio.wait_for(state.event.wait(), timeout=4)
  except asyncio.TimeoutError:
  if request is not None and await request.is_disconnected():
  for rid in [obj.rid] if obj.is_single else obj.rid:
@@ -478,16 +474,15 @@ class TokenizerManager:
  yield out
  break

- event.clear()
+ state.event.clear()
  yield out

  async def _wait_for_cache_prefill_response(
  self,
- event: asyncio.Event,
  state: ReqState,
  obj: GenerateReqInput,
  rid: str,
- request,
+ request: Optional[fastapi.Request] = None,
  ):
  while True:
  try:
@@ -514,7 +509,9 @@ class TokenizerManager:
  req = AbortReq(rid)
  self.send_to_router.send_pyobj(req)

- async def update_weights(self, obj: UpdateWeightReqInput, request):
+ async def update_weights(
+ self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
+ ):
  if self.to_create_loop:
  self.create_handle_loop()

@@ -659,12 +656,11 @@ class TokenizerManager:
  )
  return top_logprobs

- async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
- aspect_ratio = (
- getattr(self.hf_config, "image_aspect_ratio", None)
- if aspect_ratio is None
- else aspect_ratio
- )
+ async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
+ if not image_data:
+ return None, None, None
+
+ aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
  grid_pinpoints = (
  self.hf_config.image_grid_pinpoints
  if hasattr(self.hf_config, "image_grid_pinpoints")
@@ -673,35 +669,42 @@ class TokenizerManager:
  )

  if isinstance(image_data, list) and len(image_data) > 0:
- pixel_values, image_hash, image_size = [], [], []
+ # Multiple images
  if len(image_data) > 1:
  aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+ pixel_values, image_hashes, image_sizes = [], [], []
  for img_data in image_data:
  pixel_v, image_h, image_s = await self._process_single_image(
  img_data, aspect_ratio, grid_pinpoints
  )
  pixel_values.append(pixel_v)
- image_hash.append(image_h)
- image_size.append(image_s)
- pixel_values = np.stack(pixel_values, axis=0)
+ image_hashes.append(image_h)
+ image_sizes.append(image_s)
+
+ if isinstance(pixel_values[0], np.ndarray):
+ pixel_values = np.stack(pixel_values, axis=0)
  else:
+ # A single image
  pixel_values, image_hash, image_size = await self._process_single_image(
  image_data[0], aspect_ratio, grid_pinpoints
  )
- image_hash = [image_hash]
- image_size = [image_size]
+ image_hashes = [image_hash]
+ image_sizes = [image_size]
  elif isinstance(image_data, str):
+ # A single image
  pixel_values, image_hash, image_size = await self._process_single_image(
  image_data, aspect_ratio, grid_pinpoints
  )
- image_hash = [image_hash]
- image_size = [image_size]
+ image_hashes = [image_hash]
+ image_sizes = [image_size]
  else:
- pixel_values, image_hash, image_size = None, None, None
+ raise ValueError(f"Invalid image data: {image_data}")

- return pixel_values, image_hash, image_size
+ return pixel_values, image_hashes, image_sizes

- async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
+ async def _process_single_image(
+ self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+ ):
  if self.executor is not None:
  loop = asyncio.get_event_loop()
  return await loop.run_in_executor(
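Editor's note: after this change, _get_pixel_values always returns parallel per-image lists (image_hashes, image_sizes) rather than scalars, stacks pixel_values only when the per-image outputs are numpy arrays, and rejects unrecognized inputs instead of silently returning None. A condensed, non-async sketch of that contract, with process_one standing in for the real single-image helper (names are illustrative, not sglang's API):

    import numpy as np

    def get_pixel_values_sketch(image_data, process_one):
        if not image_data:
            return None, None, None
        if isinstance(image_data, list) and len(image_data) > 0:
            if len(image_data) > 1:
                # Multiple images: collect per-image lists, stack only ndarrays.
                pixel_values, image_hashes, image_sizes = [], [], []
                for img in image_data:
                    pv, h, s = process_one(img)
                    pixel_values.append(pv)
                    image_hashes.append(h)
                    image_sizes.append(s)
                if isinstance(pixel_values[0], np.ndarray):
                    pixel_values = np.stack(pixel_values, axis=0)
                return pixel_values, image_hashes, image_sizes
            pv, h, s = process_one(image_data[0])  # single image in a list
            return pv, [h], [s]
        if isinstance(image_data, str):
            pv, h, s = process_one(image_data)  # single image given as a string
            return pv, [h], [s]
        raise ValueError(f"Invalid image data: {image_data}")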
@@ -732,12 +735,16 @@ def init_global_processor(server_args: ServerArgs):


  def _process_single_image_task(
- image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+ image_data: Union[str, bytes],
+ image_aspect_ratio: Optional[str] = None,
+ image_grid_pinpoints: Optional[str] = None,
+ processor=None,
  ):
  try:
  processor = processor or global_processor
  image, image_size = load_image(image_data)
  if image_size is not None:
+ # It is a video with multiple images
  image_hash = hash(image_data)
  pixel_values = processor.image_processor(image)["pixel_values"]
  for _ in range(len(pixel_values)):
@@ -745,6 +752,7 @@ def _process_single_image_task(
  pixel_values = np.stack(pixel_values, axis=0)
  return pixel_values, image_hash, image_size
  else:
+ # It is an image
  image_hash = hash(image_data)
  if image_aspect_ratio == "pad":
  image = expand2square(
@@ -754,13 +762,18 @@ def _process_single_image_task(
  pixel_values = processor.image_processor(image.convert("RGB"))[
  "pixel_values"
  ][0]
- elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+ elif image_aspect_ratio == "anyres" or (
+ image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
+ ):
  pixel_values = process_anyres_image(
  image, processor.image_processor, image_grid_pinpoints
  )
  else:
  pixel_values = processor.image_processor(image)["pixel_values"][0]
- pixel_values = pixel_values.astype(np.float16)
+
+ if isinstance(pixel_values, np.ndarray):
+ pixel_values = pixel_values.astype(np.float16)
+
  return pixel_values, image_hash, image.size
  except Exception:
  logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
@@ -31,7 +31,7 @@ from sglang.global_config import global_config
  from sglang.srt.constrained.fsm_cache import FSMCache
  from sglang.srt.constrained.jump_forward import JumpForwardCache
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
- from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.layers.logits_processor import LogitProcessorOutput
  from sglang.srt.managers.io_struct import (
  AbortReq,
  BatchEmbeddingOut,
@@ -108,7 +108,7 @@ class ModelTpServer:
  if server_args.skip_tokenizer_init:
  self.tokenizer = self.processor = None
  else:
- if is_multimodal_model(server_args.model_path):
+ if is_multimodal_model(self.model_config.hf_config.architectures):
  self.processor = get_processor(
  server_args.tokenizer_path,
  tokenizer_mode=server_args.tokenizer_mode,
@@ -197,6 +197,16 @@ class ModelTpServer:
  "trust_remote_code": server_args.trust_remote_code,
  },
  skip_tokenizer_init=server_args.skip_tokenizer_init,
+ json_schema_mode=False,
+ )
+ self.json_fsm_cache = FSMCache(
+ server_args.tokenizer_path,
+ {
+ "tokenizer_mode": server_args.tokenizer_mode,
+ "trust_remote_code": server_args.trust_remote_code,
+ },
+ skip_tokenizer_init=server_args.skip_tokenizer_init,
+ json_schema_mode=True,
  )
  self.jump_forward_cache = JumpForwardCache()
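Editor's note: the new json_fsm_cache backs a json_schema field on SamplingParams (added in sglang/srt/sampling/sampling_params.py in this release), which compiles the schema into a regex FSM for constrained decoding, as the worker hunk below shows. A hedged usage sketch against the native /generate HTTP endpoint; the exact payload shape should be checked against the adapter/server code in this release, and the URL, port, and schema here are only examples:

    import json
    import requests

    schema = json.dumps({
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    })

    # Assumes a local sglang server already started with `python -m sglang.launch_server ...`
    response = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": "Give me a JSON object describing a person: ",
            "sampling_params": {"max_new_tokens": 64, "json_schema": schema},
        },
    )
    print(response.json())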

@@ -323,34 +333,42 @@ class ModelTpServer:
  if self.model_runner.is_generation:
  req.pixel_values = recv_req.pixel_values
  if req.pixel_values is not None:
- image_hash = (
- hash(tuple(recv_req.image_hash))
- if isinstance(recv_req.image_hash, list)
- else recv_req.image_hash
- )
+ # Use image hash as fake token_ids, which is then used
+ # for prefix matching
+ image_hash = hash(tuple(recv_req.image_hashes))
  req.pad_value = [
  (image_hash) % self.model_config.vocab_size,
  (image_hash >> 16) % self.model_config.vocab_size,
  (image_hash >> 32) % self.model_config.vocab_size,
  (image_hash >> 64) % self.model_config.vocab_size,
  ]
- req.image_size = recv_req.image_size
+ req.image_sizes = recv_req.image_sizes
  (
  req.origin_input_ids,
- req.image_offset,
+ req.image_offsets,
  ) = self.model_runner.model.pad_input_ids(
  req.origin_input_ids_unpadded,
  req.pad_value,
- req.pixel_values.shape,
- req.image_size,
+ req.pixel_values,
+ req.image_sizes,
  )
  req.return_logprob = recv_req.return_logprob
  req.logprob_start_len = recv_req.logprob_start_len
  req.top_logprobs_num = recv_req.top_logprobs_num
  req.stream = recv_req.stream

+ # Init regex fsm fron json
+ if req.sampling_params.json_schema is not None:
+ req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
+ req.sampling_params.json_schema
+ )
+ if not self.disable_regex_jump_forward:
+ req.jump_forward_map = self.jump_forward_cache.query(
+ computed_regex_string
+ )
+
  # Init regex fsm
- if req.sampling_params.regex is not None:
+ elif req.sampling_params.regex is not None:
  req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
  if not self.disable_regex_jump_forward:
  req.jump_forward_map = self.jump_forward_cache.query(
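Editor's note: the pad_value computation in the hunk above folds all per-image hashes into one value and slices it into four pseudo token ids, which act as fake tokens for radix-cache prefix matching. A standalone restatement of that arithmetic, with an assumed vocab size and example hashes:

    def image_pad_values(image_hashes, vocab_size: int = 32000):
        # Combine per-image hashes, then take four slices that stay
        # inside the vocabulary range.
        image_hash = hash(tuple(image_hashes))
        return [
            image_hash % vocab_size,
            (image_hash >> 16) % vocab_size,
            (image_hash >> 32) % vocab_size,
            (image_hash >> 64) % vocab_size,
        ]

    print(image_pad_values([hash(b"img-0"), hash(b"img-1")]))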
@@ -486,29 +504,21 @@ class ModelTpServer:
  if self.model_runner.is_generation:
  # Forward and sample the next tokens
  if batch.extend_num_tokens != 0:
- sample_output, logits_output = self.model_runner.forward(
- batch, ForwardMode.EXTEND
- )
- next_token_ids = batch.check_sample_results(sample_output)
+ output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+ next_token_ids = batch.sample(output.next_token_logits)
  batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
  next_token_ids
  )

  # Move logprobs to cpu
- if logits_output.next_token_logprobs is not None:
- logits_output.next_token_logprobs = (
- logits_output.next_token_logprobs[
- torch.arange(
- len(next_token_ids), device=next_token_ids.device
- ),
- next_token_ids,
- ].tolist()
- )
- logits_output.input_token_logprobs = (
- logits_output.input_token_logprobs.tolist()
- )
- logits_output.normalized_prompt_logprobs = (
- logits_output.normalized_prompt_logprobs.tolist()
+ if output.next_token_logprobs is not None:
+ output.next_token_logprobs = output.next_token_logprobs[
+ torch.arange(len(next_token_ids), device=next_token_ids.device),
+ next_token_ids,
+ ].tolist()
+ output.input_token_logprobs = output.input_token_logprobs.tolist()
+ output.normalized_prompt_logprobs = (
+ output.normalized_prompt_logprobs.tolist()
  )

  next_token_ids = next_token_ids.tolist()
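Editor's note: the logprob handling keeps the same indexing trick as before, picking for each row the logprob of the token that was actually sampled before moving it off the GPU. A minimal torch illustration of that row-wise gather (shapes and values are made up):

    import torch

    next_token_logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)  # [batch, vocab]
    next_token_ids = torch.tensor([3, 7, 0, 9])

    chosen = next_token_logprobs[
        torch.arange(len(next_token_ids), device=next_token_ids.device),
        next_token_ids,
    ]
    print(chosen.tolist())  # one logprob per request, ready to ship to CPU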
@@ -547,14 +557,12 @@ class ModelTpServer:
  self.req_to_token_pool.free(req.req_pool_idx)

  if req.return_logprob:
- self.add_logprob_return_values(
- i, req, pt, next_token_ids, logits_output
- )
+ self.add_logprob_return_values(i, req, pt, next_token_ids, output)
  pt += req.extend_input_len
  else:
  assert batch.extend_num_tokens != 0
- logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
- embeddings = logits_output.embeddings.tolist()
+ output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+ embeddings = output.embeddings.tolist()

  # Check finish conditions
  for i, req in enumerate(batch.reqs):
@@ -582,7 +590,7 @@ class ModelTpServer:
  req: Req,
  pt: int,
  next_token_ids: List[int],
- output: LogitsProcessorOutput,
+ output: LogitProcessorOutput,
  ):
  if req.normalized_prompt_logprob is None:
  req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -664,17 +672,15 @@ class ModelTpServer:
  batch.prepare_for_decode()

  # Forward and sample the next tokens
- sample_output, logits_output = self.model_runner.forward(
- batch, ForwardMode.DECODE
- )
- next_token_ids = batch.check_sample_results(sample_output)
+ output = self.model_runner.forward(batch, ForwardMode.DECODE)
+ next_token_ids = batch.sample(output.next_token_logits)
  batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
  next_token_ids
  )

  # Move logprobs to cpu
- if logits_output.next_token_logprobs is not None:
- next_token_logprobs = logits_output.next_token_logprobs[
+ if output.next_token_logprobs is not None:
+ next_token_logprobs = output.next_token_logprobs[
  torch.arange(len(next_token_ids), device=next_token_ids.device),
  next_token_ids,
  ].tolist()
@@ -700,7 +706,7 @@ class ModelTpServer:
  (next_token_logprobs[i], next_token_id)
  )
  if req.top_logprobs_num > 0:
- req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
+ req.output_top_logprobs.append(output.output_top_logprobs[i])

  self.handle_finished_requests(batch)