sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/__init__.py +57 -2
  2. sglang/api.py +8 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +83 -2
  9. sglang/lang/interpreter.py +92 -35
  10. sglang/lang/ir.py +12 -9
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server_llavavid.py +31 -0
  13. sglang/srt/constrained/fsm_cache.py +1 -0
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/flush_cache.py +16 -0
  17. sglang/srt/hf_transformers_utils.py +10 -2
  18. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  19. sglang/srt/layers/extend_attention.py +1 -0
  20. sglang/srt/layers/logits_processor.py +114 -54
  21. sglang/srt/layers/radix_attention.py +2 -1
  22. sglang/srt/layers/token_attention.py +1 -0
  23. sglang/srt/managers/detokenizer_manager.py +5 -1
  24. sglang/srt/managers/io_struct.py +27 -3
  25. sglang/srt/managers/router/infer_batch.py +97 -48
  26. sglang/srt/managers/router/manager.py +11 -8
  27. sglang/srt/managers/router/model_rpc.py +169 -90
  28. sglang/srt/managers/router/model_runner.py +110 -166
  29. sglang/srt/managers/router/radix_cache.py +89 -51
  30. sglang/srt/managers/router/scheduler.py +17 -28
  31. sglang/srt/managers/tokenizer_manager.py +110 -33
  32. sglang/srt/memory_pool.py +5 -14
  33. sglang/srt/model_config.py +11 -0
  34. sglang/srt/models/commandr.py +372 -0
  35. sglang/srt/models/dbrx.py +412 -0
  36. sglang/srt/models/dbrx_config.py +281 -0
  37. sglang/srt/models/gemma.py +24 -25
  38. sglang/srt/models/llama2.py +25 -26
  39. sglang/srt/models/llava.py +8 -10
  40. sglang/srt/models/llavavid.py +307 -0
  41. sglang/srt/models/mixtral.py +29 -33
  42. sglang/srt/models/qwen.py +34 -25
  43. sglang/srt/models/qwen2.py +25 -26
  44. sglang/srt/models/stablelm.py +26 -26
  45. sglang/srt/models/yivl.py +3 -5
  46. sglang/srt/openai_api_adapter.py +356 -0
  47. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  48. sglang/srt/sampling_params.py +2 -0
  49. sglang/srt/server.py +91 -456
  50. sglang/srt/server_args.py +79 -49
  51. sglang/srt/utils.py +212 -47
  52. sglang/srt/weight_utils.py +417 -0
  53. sglang/test/test_programs.py +8 -7
  54. sglang/test/test_utils.py +195 -7
  55. sglang/utils.py +77 -26
  56. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
  57. sglang-0.1.16.dist-info/RECORD +72 -0
  58. sglang-0.1.14.dist-info/RECORD +0 -64
  59. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  60. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  61. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/lang/ir.py CHANGED
@@ -193,17 +193,11 @@ class SglFunction:
         backend = backend or global_config.default_backend
         return trace_program(self, kwargs, backend)
 
-    def pin(self, backend=None):
-        from sglang.lang.interpreter import pin_program
+    def cache(self, backend=None):
+        from sglang.lang.interpreter import cache_program
 
         backend = backend or global_config.default_backend
-        return pin_program(self, backend)
-
-    def unpin(self, backend=None):
-        from sglang.lang.interpreter import unpin_program
-
-        backend = backend or global_config.default_backend
-        return unpin_program(self, backend)
+        return cache_program(self, backend)
 
     def compile(self, *, backend=None):
         from sglang.lang.compiler import compile_func
@@ -336,6 +330,15 @@ class SglImage(SglExpr):
         return f"SglImage({self.path})"
 
 
+class SglVideo(SglExpr):
+    def __init__(self, path, num_frames):
+        self.path = path
+        self.num_frames = num_frames
+
+    def __repr__(self) -> str:
+        return f"SglVideo({self.path}, {self.num_frames})"
+
+
 class SglGen(SglExpr):
     def __init__(
         self,
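For reference, a minimal usage sketch of the renamed call (not part of the diff; the program and endpoint below are illustrative):

    import sglang as sgl

    @sgl.function
    def greet(s, name):
        s += "Hello, " + name

    backend = sgl.RuntimeEndpoint("http://localhost:30000")  # assumed local endpoint
    greet.cache(backend)  # replaces the former greet.pin(backend) / greet.unpin(backend)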
sglang/lang/tracer.py CHANGED
@@ -109,19 +109,21 @@ class TracerProgramState(ProgramState):
     ########### Public API ###########
     ##################################
 
-    def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
+    def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
+        assert size >= 1
+
         if self.only_trace_prefix:
             raise StopTracing()
 
-        fork_node = SglFork(number)
+        fork_node = SglFork(size)
         fork_node.prev_node = self.last_node
 
         states = [
             TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
-            for _ in range(number)
+            for _ in range(size)
         ]
 
-        for i in range(number):
+        for i in range(size):
             node = SglGetForkItem(i)
             node.prev_node = fork_node
             states[i].last_node = node
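A short sketch of the new fork signature, assuming a program state `s` inside an sglang function (illustrative, not from the diff):

    # `size` replaces the old `number` parameter and now defaults to 1
    branches = s.fork(2)
    branches[0] += "Write a haiku. "
    branches[1] += "Write a limerick. "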
sglang/launch_server_llavavid.py ADDED
@@ -0,0 +1,31 @@
+import argparse
+import multiprocessing as mp
+
+from sglang.srt.server import ServerArgs, launch_server
+
+if __name__ == "__main__":
+
+    model_overide_args = {}
+
+    model_overide_args["mm_spatial_pool_stride"] = 2
+    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
+    model_overide_args["num_frames"] = 16
+    model_overide_args["model_type"] = "llavavid"
+    if model_overide_args["num_frames"] == 32:
+        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_overide_args["max_sequence_length"] = 4096 * 2
+        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
+        model_overide_args["model_max_length"] = 4096 * 2
+
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    if "34b" in args.model_path.lower():
+        model_overide_args["image_token_index"] = 64002
+
+    server_args = ServerArgs.from_cli_args(args)
+
+    pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
+    launch_server(server_args, pipe_writer, model_overide_args)
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -7,6 +7,7 @@ class FSMCache(BaseCache):
         super().__init__(enable=enable)
 
         from importlib.metadata import version
+
         if version("outlines") >= "0.0.35":
             from transformers import AutoTokenizer
 
sglang/srt/constrained/jump_forward.py CHANGED
@@ -1,4 +1,5 @@
 import interegular
+
 from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
 from sglang.srt.constrained.base_cache import BaseCache
 
sglang/srt/conversation.py CHANGED
@@ -4,7 +4,7 @@ import dataclasses
 from enum import IntEnum, auto
 from typing import Dict, List, Optional, Tuple, Union
 
-from sglang.srt.managers.openai_protocol import ChatCompletionRequest
+from sglang.srt.openai_protocol import ChatCompletionRequest
 
 
 class SeparatorStyle(IntEnum):
@@ -400,7 +400,7 @@ register_conv_template(
     Conversation(
         name="chatml",
         system_template="<|im_start|>system\n{system_message}",
-        system_message="You are an AI assistant.",
+        system_message="You are a helpful assistant.",
         roles=("<|im_start|>user", "<|im_start|>assistant"),
         sep_style=SeparatorStyle.CHATML,
         sep="<|im_end|>",
sglang/srt/flush_cache.py ADDED
@@ -0,0 +1,16 @@
+"""
+Usage:
+python3 -m sglang.srt.flush_cache --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    args = parser.parse_args()
+
+    response = requests.get(args.url + "/flush_cache")
+    assert response.status_code == 200
sglang/srt/hf_transformers_utils.py CHANGED
@@ -6,7 +6,6 @@ import warnings
 from typing import List, Optional, Tuple, Union
 
 from huggingface_hub import snapshot_download
-from sglang.srt.utils import is_multimodal_model
 from transformers import (
     AutoConfig,
     AutoProcessor,
@@ -15,6 +14,8 @@ from transformers (
     PreTrainedTokenizerFast,
 )
 
+from sglang.srt.utils import is_multimodal_model
+
 
 def download_from_hf(model_path: str):
     if os.path.exists(model_path):
@@ -29,10 +30,17 @@ def get_config_json(model_path: str):
     return config
 
 
-def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
+def get_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    model_overide_args: Optional[dict] = None,
+):
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
     )
+    if model_overide_args:
+        config.update(model_overide_args)
     return config
 
 
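A hedged sketch of the extended get_config() call; the model path and override values below are illustrative, not from the diff:

    from sglang.srt.hf_transformers_utils import get_config

    config = get_config(
        "lmms-lab/LLaVA-NeXT-Video-7B",  # illustrative model path
        trust_remote_code=True,
        model_overide_args={"model_type": "llavavid", "num_frames": 16},
    )
    # keys in model_overide_args are merged into the HF config via config.update(...)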
sglang/srt/layers/context_flashattention_nopad.py CHANGED
@@ -3,6 +3,7 @@
 import torch
 import triton
 import triton.language as tl
+
 from sglang.srt.utils import wrap_kernel_launcher
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
sglang/srt/layers/extend_attention.py CHANGED
@@ -1,6 +1,7 @@
 import torch
 import triton
 import triton.language as tl
+
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.utils import wrap_kernel_launcher
 
sglang/srt/layers/logits_processor.py CHANGED
@@ -1,11 +1,12 @@
 import torch
-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
-from vllm.model_executor.parallel_utils.communication_op import (
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 
+from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
+
 
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
@@ -13,76 +14,136 @@ class LogitsProcessor(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
 
-    def forward(self, input_ids, hidden_states, weight, input_metadata):
-        last_index = None
+    def _get_normalized_prompt_logprobs(
+        self, prefill_token_logprobs, input_metadata: InputMetadata
+    ):
+        logprobs_cumsum = torch.cumsum(
+            prefill_token_logprobs, dim=0, dtype=torch.float32
+        )
 
-        # Compute the last index (the first decode token) of each requeast
-        # if we are in prefill or extend mode.
+        start = input_metadata.extend_start_loc.clone()
+        end = start + input_metadata.extend_seq_lens - 2
+        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        sum_logp = (
+            logprobs_cumsum[end]
+            - logprobs_cumsum[start]
+            + prefill_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (
+            (input_metadata.extend_seq_lens - 1).clamp(min=1)
+        )
+
+        return normalized_prompt_logprobs
+
+    def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            decode_top_logprobs = []
+            for i in range(all_logprobs.shape[0]):
+                k = input_metadata.top_logprobs_nums[i]
+                t = all_logprobs[i].topk(k)
+                v_cpu = t.values.tolist()
+                p_cpu = t.indices.tolist()
+                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            return None, decode_top_logprobs
+        else:
+            prefill_top_logprobs, decode_top_logprobs = [], []
+            pt = 0
+            # NOTE: the GPU-CPU overhead can be reduced
+            extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
+            for i in range(len(extend_seq_lens_cpu)):
+                if extend_seq_lens_cpu[i] == 0:
+                    prefill_top_logprobs.append([])
+                    decode_top_logprobs.append([])
+                    continue
+                k = input_metadata.top_logprobs_nums[i]
+                t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
+                vs_cpu = t.values.tolist()
+                ps_cpu = t.indices.tolist()
+                prefill_top_logprobs.append(
+                    [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
+                )
+                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                pt += extend_seq_lens_cpu[i]
+            return prefill_top_logprobs, decode_top_logprobs
+
+    def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
+        # Get last index for next token prediction, except for DECODE mode.
+        last_index = None
         if input_metadata.forward_mode != ForwardMode.DECODE:
             last_index = (
-                torch.cumsum(
-                    input_metadata.seq_lens - input_metadata.prefix_lens,
-                    dim=0,
-                    dtype=torch.long,
-                )
+                torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
 
+        # Get the last hidden states and last logits
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_hidden = hidden_states
+        else:
+            last_hidden = hidden_states[last_index]
+
+        last_logits = torch.matmul(last_hidden, weight.T)
+        if self.tp_size > 1:
+            last_logits = tensor_model_parallel_all_gather(last_logits)
+        last_logits = last_logits[:, : self.config.vocab_size]
+
+        # Return only last_logits if logprob is not requested
        if not input_metadata.return_logprob:
-            # When logprob is not requested, only compute the last logits.
-            if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_hidden = hidden_states
-            else:
-                last_hidden = hidden_states[last_index]
-            hidden_states = None
-
-            last_logits = torch.matmul(last_hidden, weight.T)
-            if self.tp_size > 1:
-                last_logits = tensor_model_parallel_all_gather(last_logits)
-            last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, (None, None, None)
+            hidden_states = None
+            return last_logits, (None, None, None, None, None)
         else:
             # When logprob is requested, compute the logits for all tokens.
-            logits = torch.matmul(hidden_states, weight.T)
-            if self.tp_size > 1:
-                logits = tensor_model_parallel_all_gather(logits)
-            logits = logits[:, : self.config.vocab_size]
-            all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                all_logits = last_logits
+            else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size]
+
+            all_logprobs = all_logits.float()
+            del all_logits
+            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+            if return_top_logprob:
+                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
+                    all_logprobs, input_metadata
+                )
+            else:
+                prefill_top_logprobs = decode_top_logprobs = None
 
             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logits = logits
                 last_logprobs = all_logprobs
-                prefill_logprobs = normalized_logprobs = None
+                return last_logits, (
+                    None,
+                    None,
+                    None,
+                    decode_top_logprobs,
+                    last_logprobs,
+                )
             else:
                 # Compute the logprobs for the last token of each request.
-                last_logits = logits[last_index]
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
-                prefill_logprobs = all_logprobs[
+                prefill_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
                     torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]
-                logprobs_cumsum = torch.cumsum(
-                    prefill_logprobs, dim=0, dtype=torch.float32
-                )
 
-                start = input_metadata.extend_start_loc.clone()
-                end = start + input_metadata.extend_seq_lens - 2
-                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                sum_logp = (
-                    logprobs_cumsum[end]
-                    - logprobs_cumsum[start]
-                    + prefill_logprobs[start]
+                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
+                    prefill_token_logprobs, input_metadata
                 )
-                normalized_logprobs = sum_logp / (
-                    (input_metadata.extend_seq_lens - 1).clamp(min=1)
+                return last_logits, (
+                    prefill_token_logprobs,
+                    normalized_prompt_logprobs,
+                    prefill_top_logprobs,
+                    decode_top_logprobs,
+                    last_logprobs,
                 )
 
-        return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
-
 
 if __name__ == "__main__":
     all_logprobs = torch.tensor(
@@ -93,23 +154,22 @@ if __name__ == "__main__":
     )
     seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
     input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-    logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")
 
-    logprobs = all_logprobs[
+    token_logprobs = all_logprobs[
         torch.arange(all_logprobs.shape[0], device="cuda"),
         torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
     ]
-    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
 
     len_cumsum = torch.cumsum(seq_lens, dim=0)
     start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
     end = start + seq_lens - 2
-    start.clamp_(min=0, max=logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]
 
     # assert logprobs == [2, _, 2, 4, _]
-    print("logprobs", logprobs)
+    print("token logprobs", token_logprobs)
    print("start", start)
    print("end", end)
    print("sum_logp", sum_logp)
sglang/srt/layers/radix_attention.py CHANGED
@@ -1,9 +1,10 @@
 import torch
+from torch import nn
+
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
-from torch import nn
 
 
 class RadixAttention(nn.Module):
sglang/srt/layers/token_attention.py CHANGED
@@ -4,6 +4,7 @@
 import torch
 import triton
 import triton.language as tl
+
 from sglang.srt.managers.router.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher
 
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -3,6 +3,7 @@ import asyncio
 import uvloop
 import zmq
 import zmq.asyncio
+
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -37,10 +38,13 @@ class DetokenizerManager:
             if isinstance(recv_obj, BatchTokenIDOut):
                 output_tokens = recv_obj.output_tokens
 
-                # TODO(lmzheng): handle skip_special_tokens per request
+                # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
                 output_strs = self.tokenizer.batch_decode(
                     output_tokens,
                     skip_special_tokens=recv_obj.skip_special_tokens[0],
+                    spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
+                        0
+                    ],
                 )
 
                 # Trim stop str
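The forwarded flag is a standard HuggingFace decode keyword; an illustrative call outside the manager (the variable names and values are assumptions):

    texts = tokenizer.batch_decode(
        output_tokens,
        skip_special_tokens=True,
        spaces_between_special_tokens=False,
    )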
sglang/srt/managers/io_struct.py CHANGED
@@ -8,7 +8,9 @@ from sglang.srt.sampling_params import SamplingParams
 @dataclass
 class GenerateReqInput:
     # The input prompt
-    text: Union[List[str], str]
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params
@@ -19,13 +21,26 @@ class GenerateReqInput:
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # The number of top logprobs to return
+    top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in logprobs
     return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
+    # TODO: make all parameters a Union[List[T], T] to allow for batched requests
 
     def post_init(self):
-        is_single = isinstance(self.text, str)
+
+        if self.text is None:
+            assert self.input_ids is not None, "Either text or input_ids should be provided"
+        else:
+            assert self.input_ids is None, "Either text or input_ids should be provided"
+
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+        self.is_single = is_single
 
         if is_single:
             if self.sampling_params is None:
@@ -36,8 +51,10 @@
                 self.return_logprob = False
             if self.logprob_start_len is None:
                 self.logprob_start_len = 0
+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = 0
         else:
-            num = len(self.text)
+            num = len(self.text) if self.text is not None else len(self.input_ids)
 
             if self.image_data is None:
                 self.image_data = [None] * num
@@ -64,6 +81,11 @@
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
 
+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = [0] * num
+            elif not isinstance(self.top_logprobs_num, list):
+                self.top_logprobs_num = [self.top_logprobs_num] * num
+
 
 @dataclass
 class TokenizedGenerateReqInput:
@@ -76,6 +98,7 @@ class TokenizedGenerateReqInput:
     sampling_params: SamplingParams
     return_logprob: bool
     logprob_start_len: int
+    top_logprobs_num: int
    stream: bool
@@ -86,6 +109,7 @@ class BatchTokenIDOut:
     output_and_jump_forward_strs: List[str]
     hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
+    spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished: List[bool]
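A hedged example of the extended request dataclass; the field and method names come from the hunks above, while the concrete values are illustrative:

    from sglang.srt.managers.io_struct import GenerateReqInput

    req = GenerateReqInput(
        input_ids=[1, 2, 3, 4],                  # token ids may now replace `text`
        sampling_params={"max_new_tokens": 16},
        return_logprob=True,
        top_logprobs_num=5,
    )
    req.post_init()
    assert req.is_single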