sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sglang/bench_latency.py +6 -4
  2. sglang/bench_serving.py +46 -22
  3. sglang/lang/compiler.py +2 -2
  4. sglang/lang/ir.py +3 -3
  5. sglang/srt/constrained/base_tool_cache.py +1 -1
  6. sglang/srt/constrained/fsm_cache.py +12 -2
  7. sglang/srt/layers/activation.py +33 -0
  8. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  9. sglang/srt/layers/extend_attention.py +6 -1
  10. sglang/srt/layers/layernorm.py +65 -0
  11. sglang/srt/layers/logits_processor.py +5 -0
  12. sglang/srt/layers/pooler.py +50 -0
  13. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  14. sglang/srt/layers/radix_attention.py +2 -2
  15. sglang/srt/managers/detokenizer_manager.py +31 -9
  16. sglang/srt/managers/io_struct.py +63 -0
  17. sglang/srt/managers/policy_scheduler.py +173 -25
  18. sglang/srt/managers/schedule_batch.py +110 -87
  19. sglang/srt/managers/tokenizer_manager.py +193 -111
  20. sglang/srt/managers/tp_worker.py +289 -352
  21. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  22. sglang/srt/mem_cache/chunk_cache.py +43 -20
  23. sglang/srt/mem_cache/memory_pool.py +2 -2
  24. sglang/srt/mem_cache/radix_cache.py +74 -40
  25. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  26. sglang/srt/model_executor/forward_batch_info.py +168 -105
  27. sglang/srt/model_executor/model_runner.py +24 -37
  28. sglang/srt/models/gemma2.py +0 -1
  29. sglang/srt/models/internlm2.py +2 -7
  30. sglang/srt/models/llama2.py +4 -4
  31. sglang/srt/models/llama_embedding.py +88 -0
  32. sglang/srt/models/qwen2_moe.py +0 -11
  33. sglang/srt/openai_api/adapter.py +155 -27
  34. sglang/srt/openai_api/protocol.py +37 -1
  35. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  36. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  37. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  39. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  40. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  41. sglang/srt/sampling_params.py +31 -4
  42. sglang/srt/server.py +69 -15
  43. sglang/srt/server_args.py +26 -19
  44. sglang/srt/utils.py +31 -13
  45. sglang/test/run_eval.py +10 -1
  46. sglang/test/runners.py +63 -63
  47. sglang/test/simple_eval_humaneval.py +2 -8
  48. sglang/test/simple_eval_mgsm.py +203 -0
  49. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  50. sglang/test/test_layernorm.py +60 -0
  51. sglang/test/test_programs.py +4 -2
  52. sglang/test/test_utils.py +20 -2
  53. sglang/utils.py +0 -1
  54. sglang/version.py +1 -1
  55. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
  56. sglang-0.2.12.dist-info/RECORD +112 -0
  57. sglang/srt/layers/linear.py +0 -884
  58. sglang/srt/layers/quantization/__init__.py +0 -64
  59. sglang/srt/layers/quantization/fp8.py +0 -677
  60. sglang-0.2.11.dist-info/RECORD +0 -102
  61. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/sampling_params.py CHANGED
@@ -23,13 +23,16 @@ _SAMPLING_EPS = 1e-6
 class SamplingParams:
     def __init__(
         self,
-        max_new_tokens: int = 16,
+        max_new_tokens: int = 128,
+        min_new_tokens: int = 0,
         stop: Optional[Union[str, List[str]]] = None,
+        stop_token_ids: Optional[List[int]] = [],
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
@@ -42,8 +45,11 @@ class SamplingParams:
         self.top_k = top_k
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
+        self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
+        self.stop_token_ids = {*stop_token_ids}
         self.max_new_tokens = max_new_tokens
+        self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
@@ -80,23 +86,44 @@ class SamplingParams:
             raise ValueError(
                 "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
             )
+        if not 0.0 <= self.repetition_penalty <= 2.0:
+            raise ValueError(
+                "repetition_penalty must be in (0, 2], got "
+                f"{self.repetition_penalty}."
+            )
+        if not 0 <= self.min_new_tokens:
+            raise ValueError(
+                f"min_new_tokens must be in (0, max_new_tokens], got "
+                f"{self.min_new_tokens}."
+            )
         if self.max_new_tokens is not None:
             if self.max_new_tokens < 0:
                 raise ValueError(
                     f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
                 )
+            if not self.min_new_tokens <= self.max_new_tokens:
+                raise ValueError(
+                    f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
+                    f"{self.min_new_tokens}."
+                )
 
     def normalize(self, tokenizer):
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            self.stop_str_max_len = 0
+            if self.stop_token_ids is None:
+                self.stop_str_max_len = 0
+            else:
+                self.stop_str_max_len = 1
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]
 
             stop_str_max_len = 0
             for stop_str in self.stop_strs:
-                stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
-                stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                if tokenizer is not None:
+                    stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                else:
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len
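To illustrate the new sampling knobs above (larger default max_new_tokens, plus min_new_tokens, stop_token_ids, and repetition_penalty, and a normalize() that tolerates a missing tokenizer), here is a minimal sketch. The values are illustrative, and it assumes the range checks shown in the hunk are invoked through the class's verify()-style method:

# Illustrative only: exercising the sampling parameters added in 0.2.12.
# verify() is assumed to be the method that runs the checks shown above.
from sglang.srt.sampling_params import SamplingParams

params = SamplingParams(
    max_new_tokens=128,        # new default (was 16)
    min_new_tokens=4,          # must stay <= max_new_tokens
    stop_token_ids=[128009],   # stored internally as a set
    repetition_penalty=1.1,    # checked against the (0, 2] range
)
params.verify()                    # raises ValueError if a check above fails
params.normalize(tokenizer=None)   # now tolerated when the tokenizer is skipped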
sglang/srt/server.py CHANGED
@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -60,6 +60,7 @@ from sglang.srt.openai_api.adapter import (
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
+    v1_embeddings,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -74,7 +75,8 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
-    set_torch_compile_config,
+    prepare_model,
+    prepare_tokenizer,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -98,6 +100,7 @@ async def health() -> Response:
 async def get_model_info():
     result = {
         "model_path": tokenizer_manager.model_path,
+        "is_generation": tokenizer_manager.is_generation,
     }
     return result
 
@@ -149,6 +152,21 @@ app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
 
 
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
+app.post("/encode")(encode_request)
+app.put("/encode")(encode_request)
+
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
@@ -159,6 +177,12 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
+@app.post("/v1/embeddings")
+async def openai_v1_embeddings(raw_request: Request):
+    response = await v1_embeddings(tokenizer_manager, raw_request)
+    return response
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
@@ -235,6 +259,10 @@ def launch_server(
     )
     logger.info(f"{server_args=}")
 
+    # Use model from www.modelscope.cn, first download the model.
+    server_args.model_path = prepare_model(server_args.model_path)
+    server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
+
     # Launch processes for multi-node tensor parallelism
     if server_args.nnodes > 1:
         if server_args.node_rank != 0:
@@ -347,10 +375,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
     maybe_set_triton_cache_manager()
 
-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        set_torch_compile_config()
-
     # Set global chat template
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
@@ -360,7 +384,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.3",
+            "0.1.4",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
@@ -385,6 +409,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     except (AssertionError, requests.exceptions.RequestException) as e:
         last_traceback = get_exception_traceback()
         pass
+    model_info = res.json()
 
     if not success:
         if pipe_finish_writer is not None:
@@ -393,17 +418,24 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         sys.exit(1)
 
     # Send a warmup request
+    request_name = "/generate" if model_info["is_generation"] else "/encode"
+    max_new_tokens = 8 if model_info["is_generation"] else 1
+    json_data = {
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": max_new_tokens,
+        },
+    }
+    if server_args.skip_tokenizer_init:
+        json_data["input_ids"] = [10, 11, 12]
+    else:
+        json_data["text"] = "The capital city of France is"
+
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
-                url + "/generate",
-                json={
-                    "text": "The capital city of France is",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": 8,
-                    },
-                },
+                url + request_name,
+                json=json_data,
                 headers=headers,
                 timeout=600,
             )
@@ -415,6 +447,15 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
         sys.exit(1)
 
+    # Print warnings here
+    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
+        logger.warning(
+            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
+            "This combination is an experimental feature and we noticed it can lead to "
+            "wrong generation results. If you want to use chunked prefill, it is recommended "
+            "not using `--disable-radix-cache`."
+        )
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("init ok")
@@ -534,5 +575,18 @@ class Runtime:
         )
         return json.dumps(response.json())
 
+    def encode(
+        self,
+        prompt: str,
+    ):
+        json_data = {
+            "text": prompt,
+        }
+        response = requests.post(
+            self.url + "/encode",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
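A minimal client sketch against the embedding entry points added above. The base URL and model name are placeholders, and the /v1/embeddings payload is assumed to follow the usual OpenAI embeddings request shape; only the /encode body ({"text": ...}) is confirmed by Runtime.encode() in this diff:

# Hypothetical client; base_url and the model name are assumptions.
import requests

base_url = "http://127.0.0.1:30000"

# Native endpoint: mirrors Runtime.encode(), which posts {"text": prompt}.
res = requests.post(base_url + "/encode", json={"text": "The capital city of France is"})
print(res.json())

# OpenAI-compatible route served by v1_embeddings (request shape assumed).
res = requests.post(
    base_url + "/v1/embeddings",
    json={"model": "e5-mistral-7b-instruct", "input": "The capital city of France is"},
)
print(res.json())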
sglang/srt/server_args.py CHANGED
@@ -27,6 +27,7 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
+    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
     trust_remote_code: bool = True
@@ -42,10 +43,11 @@ class ServerArgs:
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
-    max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     max_num_reqs: Optional[int] = None
     max_total_tokens: Optional[int] = None
+    chunked_prefill_size: int = -1
+    max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
 
@@ -62,15 +64,12 @@ class ServerArgs:
 
     # Other
     api_key: Optional[str] = None
-    file_storage_pth: str = "SGlang_storage"
+    file_storage_pth: str = "SGLang_storage"
 
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
-    # Chunked Prefill
-    chunked_prefill_size: Optional[int] = None
-
     # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
@@ -96,6 +95,10 @@ class ServerArgs:
         if self.served_model_name is None:
             self.served_model_name = self.model_path
 
+        if self.chunked_prefill_size <= 0:
+            # Disable chunked prefill
+            self.chunked_prefill_size = None
+
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
@@ -107,6 +110,7 @@ class ServerArgs:
                 self.mem_fraction_static = 0.87
             else:
                 self.mem_fraction_static = 0.88
+
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
@@ -151,6 +155,11 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -226,12 +235,6 @@ class ServerArgs:
             default=ServerArgs.mem_fraction_static,
             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
         )
-        parser.add_argument(
-            "--max-prefill-tokens",
-            type=int,
-            default=ServerArgs.max_prefill_tokens,
-            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
-        )
         parser.add_argument(
             "--max-running-requests",
             type=int,
@@ -250,6 +253,18 @@ class ServerArgs:
             default=ServerArgs.max_total_tokens,
             help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
         )
+        parser.add_argument(
+            "--chunked-prefill-size",
+            type=int,
+            default=ServerArgs.chunked_prefill_size,
+            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
+        )
+        parser.add_argument(
+            "--max-prefill-tokens",
+            type=int,
+            default=ServerArgs.max_prefill_tokens,
+            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
+        )
         parser.add_argument(
             "--schedule-policy",
             type=str,
@@ -347,14 +362,6 @@ class ServerArgs:
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")
 
-        # Chunked prefill
-        parser.add_argument(
-            "--chunked-prefill-size",
-            type=int,
-            default=ServerArgs.chunked_prefill_size,
-            help="The size of the chunked prefill.",
-        )
-
        # Optimization/debug options
        parser.add_argument(
            "--disable-flashinfer",
sglang/srt/utils.py CHANGED
@@ -197,6 +197,8 @@ def allocate_init_ports(
 def get_int_token_logit_bias(tokenizer, vocab_size):
     """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
+    if tokenizer == None:
+        return [-1e5] * vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
     for t_id in range(vocab_size):
@@ -223,6 +225,15 @@ def is_multimodal_model(model):
     raise ValueError("unrecognized type")
 
 
+def is_generation_model(model_architectures):
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+    ):
+        return False
+    return True
+
+
 def decode_video_base64(video_base64):
     from PIL import Image
 
@@ -622,19 +633,6 @@ def receive_addrs(model_port_args, server_args):
     dist.destroy_process_group()
 
 
-def set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -705,3 +703,23 @@ def add_api_key_middleware(app, api_key):
         if request.headers.get("Authorization") != "Bearer " + api_key:
             return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
         return await call_next(request)
+
+
+def prepare_model(model_path):
+    if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if not os.path.exists(model_path):
+            from modelscope import snapshot_download
+
+            return snapshot_download(model_path)
+    return model_path
+
+
+def prepare_tokenizer(tokenizer_path):
+    if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if not os.path.exists(tokenizer_path):
+            from modelscope import snapshot_download
+
+            return snapshot_download(
+                tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
+            )
+    return tokenizer_path
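The two helpers added at the bottom of utils.py gate ModelScope downloads behind the SGLANG_USE_MODELSCOPE environment variable; a small sketch of the intended flow, with a placeholder repo id:

# Sketch: with SGLANG_USE_MODELSCOPE set and a non-local path, the helpers
# download from ModelScope; otherwise the path is returned unchanged.
import os

from sglang.srt.utils import prepare_model, prepare_tokenizer

os.environ["SGLANG_USE_MODELSCOPE"] = "1"  # presence of the variable is what matters

model_path = prepare_model("qwen/Qwen2-7B-Instruct")          # placeholder repo id
tokenizer_path = prepare_tokenizer("qwen/Qwen2-7B-Instruct")  # skips *.bin / *.safetensors
print(model_path, tokenizer_path)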
sglang/test/run_eval.py CHANGED
@@ -16,6 +16,8 @@ from sglang.test.simple_eval_common import (
 
 
 def run_eval(args):
+    set_ulimit()
+
     if "OPENAI_API_KEY" not in os.environ:
         os.environ["OPENAI_API_KEY"] = "EMPTY"
 
@@ -39,6 +41,14 @@ def run_eval(args):
         eval_obj = MathEval(
             filename, equality_checker, args.num_examples, args.num_threads
         )
+    elif args.eval_name == "mgsm":
+        from sglang.test.simple_eval_mgsm import MGSMEval
+
+        eval_obj = MGSMEval(args.num_examples, args.num_threads)
+    elif args.eval_name == "mgsm_en":
+        from sglang.test.simple_eval_mgsm import MGSMEval
+
+        eval_obj = MGSMEval(args.num_examples, args.num_threads, languages=["en"])
     elif args.eval_name == "gpqa":
         from sglang.test.simple_eval_gpqa import GPQAEval
 
@@ -109,7 +119,6 @@ if __name__ == "__main__":
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
     parser.add_argument("--num-threads", type=int, default=512)
-    set_ulimit()
     args = parser.parse_args()
 
     run_eval(args)
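The new mgsm / mgsm_en choices hook into run_eval's --eval-name dispatch; a sketch of constructing the eval objects the same way run_eval does, with made-up example counts (the constructor arguments come from the branch added above):

# Sketch: driving the new MGSM eval directly; num_examples/num_threads are arbitrary.
from sglang.test.simple_eval_mgsm import MGSMEval

num_examples, num_threads = 8, 64
eval_all = MGSMEval(num_examples, num_threads)                   # all languages
eval_en = MGSMEval(num_examples, num_threads, languages=["en"])  # English only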
sglang/test/runners.py CHANGED
@@ -23,23 +23,19 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from sglang.srt.server import Runtime
+from sglang.srt.utils import is_generation_model
 
 DEFAULT_PROMPTS = [
-    "The capital of France is",
+    # the output of gemma-2-2b from SRT is unstable on the commented prompt
+    # "The capital of France is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
+    "AI is a field of computer science focused on",
 ]
 
 NUM_TOP_LOGPROBS = 5
 
 
-def is_embedding_model(model_path):
-    # FIXME incomplete list
-    if "e5-mistral-7b-instruct" in model_path.lower():
-        return True
-    return False
-
-
 def get_dtype_str(torch_dtype):
     if torch_dtype is torch.float16:
         return "float16"
@@ -49,10 +45,11 @@ def get_dtype_str(torch_dtype):
 
 @dataclass
 class ModelOutput:
-    output_strs: str = None
-    top_input_logprobs: torch.Tensor = None
-    top_output_logprobs: torch.Tensor = None
-    embed_logits: torch.Tensor = None
+    output_strs: List[str] = None
+    output_ids: List[int] = None
+    top_input_logprobs: List[torch.Tensor] = None
+    top_output_logprobs: List[torch.Tensor] = None
+    embed_logits: List[torch.Tensor] = None
 
 
 class HFRunner:
@@ -60,7 +57,7 @@ class HFRunner:
         self,
         model_path,
         torch_dtype=torch.float16,
-        is_embedding_model=None,
+        is_generation_model=None,
     ):
         self.in_queue = multiprocessing.Queue()
         self.out_queue = multiprocessing.Queue()
@@ -72,13 +69,13 @@ class HFRunner:
                 self.out_queue,
                 model_path,
                 torch_dtype,
-                is_embedding_model,
+                is_generation_model,
             ),
         )
         self.model_proc.start()
 
     def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
+        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
     ):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_path,
@@ -86,12 +83,12 @@ class HFRunner:
             trust_remote_code=True,
         )
 
-        self.is_embedding_model = (
-            is_embedding_model(model_path)
-            if is_embedding_model is None
-            else is_embedding_model
+        self.is_generation_model = (
+            is_generation_model(model_path)
+            if is_generation_model is None
+            else is_generation_model
        )
-        if not self.is_embedding_model:
+        if self.is_generation_model:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
@@ -103,13 +100,13 @@ class HFRunner:
 
             self.model = SentenceTransformer(
                 model_path,
-                device="cpu",
-            ).to(dtype=torch_dtype)
+                model_kwargs={"torch_dtype": torch_dtype},
+            )
 
         while True:
             prompts, max_new_tokens = in_queue.get()
             if prompts is not None:
-                if not self.is_embedding_model:
+                if self.is_generation_model:
                     output_strs = []
                     prefill_logprobs = []
                     for p in prompts:
@@ -123,7 +120,9 @@ class HFRunner:
                         output_ids = self.model.generate(
                             input_ids, do_sample=False, max_new_tokens=max_new_tokens
                        )
-                        output_strs.append(self.tokenizer.decode(output_ids[0]))
+                        output_strs.append(
+                            self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
+                        )
 
                         logits = self.model.forward(input_ids).logits[0]
                         logprobs = F.log_softmax(
@@ -144,7 +143,6 @@ class HFRunner:
                     )
 
                 else:
-                    assert isinstance(prompts, List[str])
                     logits = self.model.encode(prompts).tolist()
 
                     out_queue.put(ModelOutput(embed_logits=logits))
@@ -152,7 +150,7 @@ class HFRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=64,
+        max_new_tokens=8,
     ):
         self.in_queue.put((prompts, max_new_tokens))
         return self.out_queue.get()
@@ -175,16 +173,13 @@ class SRTRunner:
         model_path,
         tp_size=1,
         torch_dtype=torch.float16,
-        is_embedding_model=None,
+        is_generation_model=None,
     ):
-        self.is_embedding_model = (
-            is_embedding_model(model_path)
-            if is_embedding_model is None
-            else is_embedding_model
+        self.is_generation_model = (
+            is_generation_model(model_path)
+            if is_generation_model is None
+            else is_generation_model
         )
-        if self.is_embedding_model:
-            raise NotImplementedError()
-
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
@@ -194,40 +189,45 @@ class SRTRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=64,
+        max_new_tokens=8,
     ):
-        # the return value contains logprobs from prefill
-        output_strs = []
-        top_input_logprobs = []
-        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-        for prompt in prompts:
-            response = self.runtime.generate(
-                prompt,
-                sampling_params=sampling_params,
-                return_logprob=True,
-                top_logprobs_num=NUM_TOP_LOGPROBS,
-            )
-            response = json.loads(response)
-            output_strs.append(response["text"])
-            top_input_logprobs.append(
-                [
-                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
-                    for x in response["meta_info"]["input_top_logprobs"][1:]
-                ]
-                + [
+        if self.is_generation_model:
+            # the return value contains logprobs from prefill
+            output_strs = []
+            top_input_logprobs = []
+            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+            for prompt in prompts:
+                response = self.runtime.generate(
+                    prompt,
+                    sampling_params=sampling_params,
+                    return_logprob=True,
+                    top_logprobs_num=NUM_TOP_LOGPROBS,
+                )
+                response = json.loads(response)
+                output_strs.append(response["text"])
+                top_input_logprobs.append(
                     [
-                        tup[0]
-                        for tup in response["meta_info"]["output_top_logprobs"][0][
-                            :NUM_TOP_LOGPROBS
+                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                        for x in response["meta_info"]["input_top_logprobs"][1:]
+                    ]
+                    + [
+                        [
+                            tup[0]
+                            for tup in response["meta_info"]["output_top_logprobs"][0][
+                                :NUM_TOP_LOGPROBS
+                            ]
                         ]
                     ]
-                ]
-            )
-            # print(response["meta_info"]["output_top_logprobs"][0])
+                )
 
-        return ModelOutput(
-            output_strs=output_strs, top_input_logprobs=top_input_logprobs
-        )
+            return ModelOutput(
+                output_strs=output_strs, top_input_logprobs=top_input_logprobs
+            )
+        else:
+            response = self.runtime.encode(prompts)
+            response = json.loads(response)
+            logits = [x["embedding"] for x in response]
+            return ModelOutput(embed_logits=logits)
 
     def __enter__(self):
         return self
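Finally, a sketch of how the reworked test runners are meant to be driven for both generation and embedding models. The model paths are placeholders, and the attribute and argument names are taken only from the hunks above:

# Sketch: SRTRunner.forward() branches on is_generation_model, returning
# output_strs/top_input_logprobs for generation models and embed_logits
# (served via the new /encode endpoint) for embedding models.
from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner

# Generation model (placeholder path): greedy decode, 8 new tokens by default.
gen_runner = SRTRunner("meta-llama/Meta-Llama-3-8B-Instruct", is_generation_model=True)
out = gen_runner.forward(DEFAULT_PROMPTS)
print(out.output_strs)

# Embedding model (placeholder path): results come back as embed_logits.
emb_runner = SRTRunner("intfloat/e5-mistral-7b-instruct", is_generation_model=False)
emb = emb_runner.forward(["AI is a field of computer science focused on"])
print(len(emb.embed_logits))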