sglang 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +4 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/bench_latency.py +299 -0
  6. sglang/global_config.py +4 -1
  7. sglang/lang/compiler.py +2 -2
  8. sglang/lang/interpreter.py +1 -1
  9. sglang/lang/ir.py +15 -5
  10. sglang/launch_server.py +4 -1
  11. sglang/launch_server_llavavid.py +2 -1
  12. sglang/srt/constrained/__init__.py +13 -6
  13. sglang/srt/constrained/fsm_cache.py +6 -3
  14. sglang/srt/constrained/jump_forward.py +113 -25
  15. sglang/srt/conversation.py +2 -0
  16. sglang/srt/flush_cache.py +2 -0
  17. sglang/srt/hf_transformers_utils.py +64 -9
  18. sglang/srt/layers/fused_moe.py +186 -89
  19. sglang/srt/layers/logits_processor.py +53 -25
  20. sglang/srt/layers/radix_attention.py +34 -7
  21. sglang/srt/managers/controller/dp_worker.py +6 -3
  22. sglang/srt/managers/controller/infer_batch.py +142 -67
  23. sglang/srt/managers/controller/manager_multi.py +5 -5
  24. sglang/srt/managers/controller/manager_single.py +8 -3
  25. sglang/srt/managers/controller/model_runner.py +154 -54
  26. sglang/srt/managers/controller/radix_cache.py +4 -0
  27. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  28. sglang/srt/managers/controller/tp_worker.py +140 -135
  29. sglang/srt/managers/detokenizer_manager.py +15 -19
  30. sglang/srt/managers/io_struct.py +10 -4
  31. sglang/srt/managers/tokenizer_manager.py +14 -13
  32. sglang/srt/model_config.py +83 -4
  33. sglang/srt/models/chatglm.py +399 -0
  34. sglang/srt/models/commandr.py +2 -2
  35. sglang/srt/models/dbrx.py +1 -1
  36. sglang/srt/models/gemma.py +5 -1
  37. sglang/srt/models/grok.py +204 -137
  38. sglang/srt/models/llama2.py +11 -4
  39. sglang/srt/models/llama_classification.py +104 -0
  40. sglang/srt/models/llava.py +11 -8
  41. sglang/srt/models/llavavid.py +1 -1
  42. sglang/srt/models/mixtral.py +164 -115
  43. sglang/srt/models/mixtral_quant.py +0 -1
  44. sglang/srt/models/qwen.py +1 -1
  45. sglang/srt/models/qwen2.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/models/yivl.py +2 -2
  48. sglang/srt/openai_api_adapter.py +33 -23
  49. sglang/srt/openai_protocol.py +1 -1
  50. sglang/srt/server.py +60 -19
  51. sglang/srt/server_args.py +79 -44
  52. sglang/srt/utils.py +146 -37
  53. sglang/test/test_programs.py +28 -10
  54. sglang/utils.py +4 -3
  55. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
  56. sglang-0.1.18.dist-info/RECORD +78 -0
  57. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  58. sglang/srt/managers/router/infer_batch.py +0 -596
  59. sglang/srt/managers/router/manager.py +0 -82
  60. sglang/srt/managers/router/model_rpc.py +0 -818
  61. sglang/srt/managers/router/model_runner.py +0 -445
  62. sglang/srt/managers/router/radix_cache.py +0 -267
  63. sglang/srt/managers/router/scheduler.py +0 -59
  64. sglang-0.1.17.dist-info/RECORD +0 -81
  65. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  66. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/models/llava.py CHANGED
@@ -1,11 +1,17 @@
  """Inference-only LLaVa model compatible with HuggingFace weights."""

- from typing import List, Iterable, Optional, Tuple
+ from typing import Iterable, List, Optional, Tuple

  import numpy as np
  import torch
  from torch import nn
- from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
+ from transformers import (
+     CLIPVisionConfig,
+     CLIPVisionModel,
+     LlavaConfig,
+     MistralConfig,
+     Qwen2Config,
+ )
  from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
  from vllm.config import CacheConfig
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -19,8 +25,8 @@ from sglang.srt.mm_utils import (
      unpad_image_shape,
  )
  from sglang.srt.models.llama2 import LlamaForCausalLM
- from sglang.srt.models.qwen2 import Qwen2ForCausalLM
  from sglang.srt.models.mistral import MistralForCausalLM
+ from sglang.srt.models.qwen2 import Qwen2ForCausalLM


  class LlavaLlamaForCausalLM(nn.Module):
@@ -359,6 +365,7 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):

  first_call = True

+
  def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
      batch_size = pixel_values.shape[0]

@@ -388,8 +395,4 @@ def monkey_path_clip_vision_embed_forward():
      )


- EntryClass = [
-     LlavaLlamaForCausalLM,
-     LlavaQwenForCausalLM,
-     LlavaMistralForCausalLM
- ]
+ EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
sglang/srt/models/llavavid.py CHANGED
@@ -1,6 +1,6 @@
  """Inference-only LLaVa video model compatible with HuggingFace weights."""

- from typing import List, Iterable, Optional, Tuple
+ from typing import Iterable, List, Optional, Tuple

  import numpy as np
  import torch
sglang/srt/models/mixtral.py CHANGED
@@ -33,13 +33,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from vllm.model_executor.utils import set_weight_attrs
  from vllm.utils import print_warning_once

-
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.managers.controller.model_runner import InputMetadata


-
  class MixtralMoE(nn.Module):
      """A tensor-parallel MoE implementation for Mixtral that shards each expert
      across all ranks.
@@ -76,32 +74,46 @@ class MixtralMoE(nn.Module):
          self.params_dtype = params_dtype

          # Gate always runs at half / full precision for now.
-         self.gate = ReplicatedLinear(self.hidden_size,
-                                      self.num_total_experts,
-                                      bias=False,
-                                      params_dtype=self.params_dtype,
-                                      quant_config=None)
+         self.gate = ReplicatedLinear(
+             self.hidden_size,
+             self.num_total_experts,
+             bias=False,
+             params_dtype=self.params_dtype,
+             quant_config=None,
+         )

          if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
              params_dtype = torch.float8_e4m3fn

          self.w13_weight = nn.Parameter(
-             torch.empty(self.num_total_experts,
-                         2 * self.intermediate_size,
-                         self.hidden_size,
-                         dtype=params_dtype))
+             torch.empty(
+                 self.num_total_experts,
+                 2 * self.intermediate_size,
+                 self.hidden_size,
+                 dtype=params_dtype,
+             )
+         )
          self.w2_weight = nn.Parameter(
-             torch.empty(self.num_total_experts,
-                         self.hidden_size,
-                         self.intermediate_size,
-                         dtype=params_dtype))
-
-         set_weight_attrs(self.w13_weight, {
-             "weight_loader": self.weight_loader,
-         })
-         set_weight_attrs(self.w2_weight, {
-             "weight_loader": self.weight_loader,
-         })
+             torch.empty(
+                 self.num_total_experts,
+                 self.hidden_size,
+                 self.intermediate_size,
+                 dtype=params_dtype,
+             )
+         )
+
+         set_weight_attrs(
+             self.w13_weight,
+             {
+                 "weight_loader": self.weight_loader,
+             },
+         )
+         set_weight_attrs(
+             self.w2_weight,
+             {
+                 "weight_loader": self.weight_loader,
+             },
+         )

          # Used for fp8.
          self.w13_scale = None
@@ -111,46 +123,68 @@ class MixtralMoE(nn.Module):

          if self.use_fp8:
              # WEIGHT_SCALE (for fp8)
-             self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                      dtype=torch.float32),
-                                           requires_grad=False)
-             self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
+             self.w13_scale = nn.Parameter(
+                 torch.ones(self.num_total_experts, dtype=torch.float32),
+                 requires_grad=False,
+             )
+             self.w2_scale = nn.Parameter(
+                 torch.ones(self.num_total_experts, dtype=torch.float32),
+                 requires_grad=False,
+             )

              # If loading fp8 checkpoint, pass the weight loaders.
              # If loading an fp16 checkpoint, do not (we will quantize in
              # process_weights_after_loading()
              if quant_config.is_checkpoint_fp8_serialized:
-                 set_weight_attrs(self.w13_scale, {
-                     "weight_loader": self.weight_loader,
-                 })
-                 set_weight_attrs(self.w2_scale, {
-                     "weight_loader": self.weight_loader,
-                 })
+                 set_weight_attrs(
+                     self.w13_scale,
+                     {
+                         "weight_loader": self.weight_loader,
+                     },
+                 )
+                 set_weight_attrs(
+                     self.w2_scale,
+                     {
+                         "weight_loader": self.weight_loader,
+                     },
+                 )

              # ACT_SCALE (for fp8)
              if quant_config.activation_scheme == "static":
                  if not quant_config.is_checkpoint_fp8_serialized:
                      raise ValueError(
                          "Found static activation scheme for checkpoint that "
-                         "was not serialized fp8.")
-                 self.a13_scale = nn.Parameter(torch.zeros(
-                     self.num_total_experts, dtype=torch.float32),
-                                               requires_grad=False)
-                 self.a2_scale = nn.Parameter(torch.zeros(
-                     self.num_total_experts, dtype=torch.float32),
-                                              requires_grad=False)
-
-                 set_weight_attrs(self.a13_scale, {
-                     "weight_loader": self.weight_loader,
-                 })
-                 set_weight_attrs(self.a2_scale, {
-                     "weight_loader": self.weight_loader,
-                 })
-
-     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
-                       weight_name: str, expert_id: int):
+                         "was not serialized fp8."
+                     )
+                 self.a13_scale = nn.Parameter(
+                     torch.zeros(self.num_total_experts, dtype=torch.float32),
+                     requires_grad=False,
+                 )
+                 self.a2_scale = nn.Parameter(
+                     torch.zeros(self.num_total_experts, dtype=torch.float32),
+                     requires_grad=False,
+                 )
+
+                 set_weight_attrs(
+                     self.a13_scale,
+                     {
+                         "weight_loader": self.weight_loader,
+                     },
+                 )
+                 set_weight_attrs(
+                     self.a2_scale,
+                     {
+                         "weight_loader": self.weight_loader,
+                     },
+                 )
+
+     def weight_loader(
+         self,
+         param: nn.Parameter,
+         loaded_weight: torch.Tensor,
+         weight_name: str,
+         expert_id: int,
+     ):
          tp_rank = get_tensor_model_parallel_rank()
          param_data = param.data
          shard_size = self.intermediate_size
@@ -158,8 +192,9 @@ class MixtralMoE(nn.Module):
          if weight_name.endswith("w1.weight"):
              param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
          if weight_name.endswith("w3.weight"):
-             param_data[expert_id,
-                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+             param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
+                 shard, :
+             ]
          if weight_name.endswith("w2.weight"):
              param_data[expert_id, :, :] = loaded_weight[:, shard]
          if "act_scale" in weight_name or "weight_scale" in weight_name:
@@ -172,17 +207,17 @@ class MixtralMoE(nn.Module):

          # If checkpoint is fp16, quantize here.
          if not self.quant_config.is_checkpoint_fp8_serialized:
-             w13_weight = torch.empty_like(self.w13_weight.data,
-                                           dtype=torch.float8_e4m3fn)
-             w2_weight = torch.empty_like(self.w2_weight.data,
-                                          dtype=torch.float8_e4m3fn)
+             w13_weight = torch.empty_like(
+                 self.w13_weight.data, dtype=torch.float8_e4m3fn
+             )
+             w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
              for expert in range(self.num_total_experts):
-                 w13_weight[expert, :, :], self.w13_scale[
-                     expert] = ops.scaled_fp8_quant(
-                         self.w13_weight.data[expert, :, :])
-                 w2_weight[expert, :, :], self.w2_scale[
-                     expert] = ops.scaled_fp8_quant(
-                         self.w2_weight.data[expert, :, :])
+                 w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
+                     self.w13_weight.data[expert, :, :]
+                 )
+                 w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
+                     self.w2_weight.data[expert, :, :]
+                 )
              self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
              self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)

@@ -193,40 +228,40 @@ class MixtralMoE(nn.Module):
             if self.a13_scale is None or self.a2_scale is None:
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
-                     "activation scales are None.")
+                     "activation scales are None."
+                 )

-             if (not all_close_1d(self.a13_scale)
-                     or not all_close_1d(self.a2_scale)):
+             if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
                 print_warning_once(
                     "Found act_scales that are not equal for fp8 MoE layer. "
-                     "Using the maximum across experts for each layer. ")
+                     "Using the maximum across experts for each layer. "
+                 )

-             self.a13_scale = nn.Parameter(self.a13_scale.max(),
-                                           requires_grad=False)
-             self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                          requires_grad=False)
+             self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
+             self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          num_tokens, hidden_size = hidden_states.shape
          hidden_states = hidden_states.view(-1, self.hidden_size)
          # router_logits: (num_tokens, n_experts)
          router_logits, _ = self.gate(hidden_states)
-         final_hidden_states = fused_moe(hidden_states,
-                                         self.w13_weight,
-                                         self.w2_weight,
-                                         router_logits,
-                                         self.top_k,
-                                         renormalize=True,
-                                         inplace=True,
-                                         use_fp8=self.use_fp8,
-                                         w1_scale=self.w13_scale,
-                                         w2_scale=self.w2_scale,
-                                         a1_scale=self.a13_scale,
-                                         a2_scale=self.a2_scale)
+         final_hidden_states = fused_moe(
+             hidden_states,
+             self.w13_weight,
+             self.w2_weight,
+             router_logits,
+             self.top_k,
+             renormalize=True,
+             inplace=True,
+             use_fp8=self.use_fp8,
+             w1_scale=self.w13_scale,
+             w2_scale=self.w2_scale,
+             a1_scale=self.a13_scale,
+             a2_scale=self.a2_scale,
+         )

          if self.tp_size > 1:
-             final_hidden_states = tensor_model_parallel_all_reduce(
-                 final_hidden_states)
+             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

          return final_hidden_states.view(num_tokens, hidden_size)

@@ -335,7 +370,8 @@ class MixtralDecoderLayer(nn.Module):
              top_k=config.num_experts_per_tok,
              hidden_size=config.hidden_size,
              intermediate_size=config.intermediate_size,
-             quant_config=quant_config)
+             quant_config=quant_config,
+         )
          self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_attention_layernorm = RMSNorm(
              config.hidden_size, eps=config.rms_norm_eps
@@ -444,35 +480,48 @@ class MixtralForCausalLM(nn.Module):
              ("qkv_proj", "v_proj", "v"),
          ]

-         expert_params_mapping = [
-             # These are the weight scales for the experts
-             # (param_name, weight_name, expert_id)
-             ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-              f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
-             for expert_id in range(self.config.num_local_experts)
-             for weight_name in ["w1", "w2", "w3"]
-         ] + [
-             # These are the weights for the experts
-             # (param_name, weight_name, expert_id)
-             ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-              f"experts.{expert_id}.{weight_name}.weight", expert_id)
-             for expert_id in range(self.config.num_local_experts)
-             for weight_name in ["w1", "w2", "w3"]
-         ] + [
-             # These are the activation scales for the experts
-             # (param_name, weight_name, expert_id)
-             ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-              f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
-             for expert_id in range(self.config.num_local_experts)
-             for weight_name in ["w1", "w2", "w3"]
-         ]
+         expert_params_mapping = (
+             [
+                 # These are the weight scales for the experts
+                 # (param_name, weight_name, expert_id)
+                 (
+                     "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+                     f"experts.{expert_id}.{weight_name}.weight_scale",
+                     expert_id,
+                 )
+                 for expert_id in range(self.config.num_local_experts)
+                 for weight_name in ["w1", "w2", "w3"]
+             ]
+             + [
+                 # These are the weights for the experts
+                 # (param_name, weight_name, expert_id)
+                 (
+                     "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+                     f"experts.{expert_id}.{weight_name}.weight",
+                     expert_id,
+                 )
+                 for expert_id in range(self.config.num_local_experts)
+                 for weight_name in ["w1", "w2", "w3"]
+             ]
+             + [
+                 # These are the activation scales for the experts
+                 # (param_name, weight_name, expert_id)
+                 (
+                     "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+                     f"experts.{expert_id}.{weight_name}.act_scale",
+                     expert_id,
+                 )
+                 for expert_id in range(self.config.num_local_experts)
+                 for weight_name in ["w1", "w2", "w3"]
+             ]
+         )

          params_dict = dict(self.named_parameters())
          for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name:
                  continue

-             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+             for param_name, weight_name, shard_id in stacked_params_mapping:
                  if weight_name not in name:
                      continue
                  name = name.replace(weight_name, param_name)
@@ -490,18 +539,18 @@ class MixtralForCausalLM(nn.Module):
                      name = name.replace(weight_name, param_name)
                      param = params_dict[name]
                      weight_loader = param.weight_loader
-                     weight_loader(param,
-                                   loaded_weight,
-                                   weight_name,
-                                   expert_id=expert_id)
+                     weight_loader(
+                         param, loaded_weight, weight_name, expert_id=expert_id
+                     )
                      break
                  else:
                      # Skip loading extra bias for GPTQ models.
                      if name.endswith(".bias") and name not in params_dict:
                          continue
                      param = params_dict[name]
-                     weight_loader = getattr(param, "weight_loader",
-                                             default_weight_loader)
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
                      weight_loader(param, loaded_weight)


sglang/srt/models/mixtral_quant.py CHANGED
@@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.managers.controller.model_runner import InputMetadata
sglang/srt/models/qwen.py CHANGED
@@ -1,6 +1,6 @@
  # Adapted from
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
- from typing import Any, Dict, Optional, Iterable, Tuple
+ from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
  from torch import nn
sglang/srt/models/qwen2.py CHANGED
@@ -1,7 +1,7 @@
  # Adapted from llama2.py
  # Modify details for the adaptation of Qwen2 model.
  """Inference-only Qwen2 model compatible with HuggingFace weights."""
- from typing import Any, Dict, Optional, Tuple, Iterable
+ from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
  from torch import nn
sglang/srt/models/stablelm.py CHANGED
@@ -2,7 +2,7 @@
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
  """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
  model compatible with HuggingFace weights."""
- from typing import Optional, Tuple, Iterable
+ from typing import Iterable, Optional, Tuple

  import torch
  from torch import nn
sglang/srt/models/yivl.py CHANGED
@@ -1,14 +1,14 @@
  """Inference-only Yi-VL model."""

- from typing import Tuple, Iterable, Optional
+ from typing import Iterable, Optional, Tuple

  import torch
  import torch.nn as nn
  from transformers import CLIPVisionModel, LlavaConfig
  from vllm.config import CacheConfig
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.models.llava import (
      LlavaLlamaForCausalLM,
      monkey_path_clip_vision_embed_forward,
sglang/srt/openai_api_adapter.py CHANGED
@@ -6,7 +6,7 @@ import os
  from http import HTTPStatus

  from fastapi import Request
- from fastapi.responses import StreamingResponse, JSONResponse
+ from fastapi.responses import JSONResponse, StreamingResponse

  from sglang.srt.conversation import (
      Conversation,
@@ -40,21 +40,18 @@ chat_template_name = None
  def create_error_response(
      message: str,
      err_type: str = "BadRequestError",
-     status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
-     error = ErrorResponse(message=message,
-                           type=err_type,
-                           code=status_code.value)
-     return JSONResponse(content=error.model_dump(),
-                         status_code=error.code)
+     status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+ ):
+     error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+     return JSONResponse(content=error.model_dump(), status_code=error.code)


  def create_streaming_error_response(
      message: str,
      err_type: str = "BadRequestError",
-     status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
-     error = ErrorResponse(message=message,
-                           type=err_type,
-                           code=status_code.value)
+     status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+ ) -> str:
+     error = ErrorResponse(message=message, type=err_type, code=status_code.value)
      json_str = json.dumps({"error": error.model_dump()})
      return json_str

@@ -125,7 +122,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
          n_prev_token = 0
          try:
              async for content in tokenizer_manager.generate_request(
-                 adapted_request, raw_request):
+                 adapted_request, raw_request
+             ):
                  text = content["text"]
                  prompt_tokens = content["meta_info"]["prompt_tokens"]
                  completion_tokens = content["meta_info"]["completion_tokens"]
@@ -154,12 +152,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                          decode_token_logprobs=content["meta_info"][
                              "decode_token_logprobs"
                          ][n_prev_token:],
-                         decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
-                             n_prev_token:
-                         ],
+                         decode_top_logprobs=content["meta_info"][
+                             "decode_top_logprobs"
+                         ][n_prev_token:],
                      )

-                     n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
+                     n_prev_token = len(
+                         content["meta_info"]["decode_token_logprobs"]
+                     )
                  else:
                      logprobs = None

@@ -188,13 +188,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
              yield f"data: {error}\n\n"
              yield "data: [DONE]\n\n"

-         return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
-                                  background=tokenizer_manager.create_abort_task(adapted_request))
+         return StreamingResponse(
+             generate_stream_resp(),
+             media_type="text/event-stream",
+             background=tokenizer_manager.create_abort_task(adapted_request),
+         )

      # Non-streaming response.
      try:
          ret = await tokenizer_manager.generate_request(
-             adapted_request, raw_request).__anext__()
+             adapted_request, raw_request
+         ).__anext__()
      except ValueError as e:
          return create_error_response(str(e))

@@ -299,7 +303,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):

          stream_buffer = ""
          try:
-             async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
+             async for content in tokenizer_manager.generate_request(
+                 adapted_request, raw_request
+             ):
                  if is_first:
                      # First chunk with role
                      is_first = False
@@ -334,13 +340,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
              yield f"data: {error}\n\n"
              yield "data: [DONE]\n\n"

-         return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
-                                  background=tokenizer_manager.create_abort_task(adapted_request))
+         return StreamingResponse(
+             generate_stream_resp(),
+             media_type="text/event-stream",
+             background=tokenizer_manager.create_abort_task(adapted_request),
+         )

      # Non-streaming response.
      try:
          ret = await tokenizer_manager.generate_request(
-             adapted_request, raw_request).__anext__()
+             adapted_request, raw_request
+         ).__anext__()
      except ValueError as e:
          return create_error_response(str(e))

sglang/srt/openai_protocol.py CHANGED
@@ -1,4 +1,4 @@
- """pydantic models for OpenAI API protocol"""
+ """Pydantic models for OpenAI API protocol"""

  import time
  from typing import Dict, List, Optional, Union