sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +6 -25
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +104 -71
  31. sglang/srt/managers/tokenizer_manager.py +17 -8
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +58 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +117 -131
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +1 -5
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +1 -5
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/llama.py +51 -5
  49. sglang/srt/models/llama_classification.py +1 -20
  50. sglang/srt/models/llava.py +30 -5
  51. sglang/srt/models/llavavid.py +2 -2
  52. sglang/srt/models/minicpm.py +1 -5
  53. sglang/srt/models/minicpm3.py +665 -0
  54. sglang/srt/models/mixtral.py +6 -5
  55. sglang/srt/models/mixtral_quant.py +1 -5
  56. sglang/srt/models/qwen.py +1 -5
  57. sglang/srt/models/qwen2.py +1 -5
  58. sglang/srt/models/qwen2_moe.py +6 -5
  59. sglang/srt/models/stablelm.py +1 -5
  60. sglang/srt/models/xverse.py +375 -0
  61. sglang/srt/models/xverse_moe.py +445 -0
  62. sglang/srt/openai_api/adapter.py +65 -46
  63. sglang/srt/openai_api/protocol.py +11 -3
  64. sglang/srt/sampling/sampling_batch_info.py +57 -44
  65. sglang/srt/server.py +24 -14
  66. sglang/srt/server_args.py +130 -28
  67. sglang/srt/utils.py +12 -0
  68. sglang/test/few_shot_gsm8k.py +132 -0
  69. sglang/test/runners.py +114 -22
  70. sglang/test/test_programs.py +7 -5
  71. sglang/test/test_utils.py +85 -1
  72. sglang/utils.py +32 -37
  73. sglang/version.py +1 -1
  74. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
  75. sglang-0.3.1.dist-info/RECORD +129 -0
  76. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  77. sglang-0.3.0.dist-info/RECORD +0 -118
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  79. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/test/few_shot_gsm8k.py ADDED
@@ -0,0 +1,132 @@
+"""
+Run few-shot GSM-8K evaluation.
+
+Usage:
+python3 -m sglang.test.few_shot_gsm8k --num-questions 200
+"""
+
+import argparse
+import ast
+import re
+import time
+
+import numpy as np
+
+from sglang.api import set_default_backend
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
+
+INVALID = -9999999
+
+
+def get_one_example(lines, i, include_answer):
+    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
+    if include_answer:
+        ret += " " + lines[i]["answer"]
+    return ret
+
+
+def get_few_shot_examples(lines, k):
+    ret = ""
+    for i in range(k):
+        ret += get_one_example(lines, i, True) + "\n\n"
+    return ret
+
+
+def get_answer_value(answer_str):
+    answer_str = answer_str.replace(",", "")
+    numbers = re.findall(r"\d+", answer_str)
+    if len(numbers) < 1:
+        return INVALID
+    try:
+        return ast.literal_eval(numbers[-1])
+    except SyntaxError:
+        return INVALID
+
+
+def main(args):
+    # Select backend
+    set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
+
+    # Read data
+    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
+
+    # Construct prompts
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
+
+    questions = []
+    labels = []
+    for i in range(len(lines[:num_questions])):
+        questions.append(get_one_example(lines, i, False))
+        labels.append(get_answer_value(lines[i]["answer"]))
+    assert all(l != INVALID for l in labels)
+    arguments = [{"question": q} for q in questions]
+
+    #####################################
+    ######### SGL Program Begin #########
+    #####################################
+
+    import sglang as sgl
+
+    @sgl.function
+    def few_shot_gsm8k(s, question):
+        s += few_shot_examples + question
+        s += sgl.gen(
+            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
+        )
+
+    #####################################
+    ########## SGL Program End ##########
+    #####################################
+
+    # Run requests
+    tic = time.time()
+    states = few_shot_gsm8k.run_batch(
+        arguments,
+        temperature=0,
+        num_threads=args.parallel,
+        progress_bar=True,
+    )
+    latency = time.time() - tic
+
+    preds = []
+    for i in range(len(states)):
+        preds.append(get_answer_value(states[i]["answer"]))
+
+    # print(f"{preds=}")
+    # print(f"{labels=}")
+
+    # Compute accuracy
+    acc = np.mean(np.array(preds) == np.array(labels))
+    invalid = np.mean(np.array(preds) == INVALID)
+
+    # Compute speed
+    num_output_tokens = sum(
+        s.get_meta_info("answer")["completion_tokens"] for s in states
+    )
+    output_throughput = num_output_tokens / latency
+
+    # Print results
+    print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
+    print(f"Output throughput: {output_throughput:.3f} token/s")
+
+    # Dump results
+    dump_state_text("tmp_output_gsm8k.txt", states)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num-shots", type=int, default=5)
+    parser.add_argument("--data-path", type=str, default="test.jsonl")
+    parser.add_argument("--num-questions", type=int, default=200)
+    parser.add_argument("--parallel", type=int, default=128)
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=30000)
+    args = parser.parse_args()
+    main(args)
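
Note: the answer-extraction logic above is self-contained, so it can be sanity-checked without a running server. A minimal sketch, assuming sglang 0.3.1 is installed and using made-up sample strings:

from sglang.test.few_shot_gsm8k import INVALID, get_answer_value

# get_answer_value strips commas, then parses the last integer in the string
assert get_answer_value("72 - 30 = 42 left. The answer is 42.") == 42
assert get_answer_value("about 1,000 apples") == 1000  # commas removed before matching
assert get_answer_value("no digits here") == INVALID
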
sglang/test/runners.py CHANGED
@@ -21,6 +21,7 @@ from typing import List, Union
 
 import torch
 import torch.nn.functional as F
+from peft import PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from sglang.srt.server import Runtime
@@ -50,6 +51,13 @@ def get_dtype_str(torch_dtype):
         raise NotImplementedError()
 
 
+def get_top_logprobs(logits, k):
+    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+    del logits
+    logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
+    return logprobs
+
+
 @dataclass
 class ModelOutput:
     output_strs: List[str] = None
@@ -65,8 +73,10 @@ class HFRunner:
         model_path,
         torch_dtype,
         is_generation,
+        output_str_only=False,
     ):
         self.is_generation = is_generation
+        self.output_str_only = output_str_only
 
         self.in_queue = mp.Queue()
         self.out_queue = mp.Queue()
@@ -89,7 +99,7 @@ class HFRunner:
         )
 
         if self.is_generation:
-            self.model = AutoModelForCausalLM.from_pretrained(
+            self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
@@ -104,12 +114,16 @@ class HFRunner:
             )
 
         while True:
-            prompts, max_new_tokens = in_queue.get()
+            prompts, max_new_tokens, lora_paths = in_queue.get()
+            if lora_paths is not None:
+                assert len(prompts) == len(lora_paths)
+
            if prompts is not None:
                 if self.is_generation:
                     output_strs = []
-                    prefill_logprobs = []
-                    for p in prompts:
+                    top_input_logprobs = []
+                    top_output_logprobs = []
+                    for i, p in enumerate(prompts):
                         if isinstance(p, str):
                             input_ids = self.tokenizer.encode(
                                 p, return_tensors="pt"
@@ -117,40 +131,68 @@
                         else:
                             input_ids = torch.tensor([p], device="cuda")
 
-                        output_ids = self.model.generate(
-                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
+                        if lora_paths is not None and lora_paths[i] is not None:
+                            self.model = PeftModel.from_pretrained(
+                                self.base_model,
+                                lora_paths[i],
+                                torch_dtype=torch_dtype,
+                                is_trainable=False,
+                            )
+                        else:
+                            self.model = self.base_model
+
+                        outputs = self.model.generate(
+                            input_ids,
+                            do_sample=False,
+                            temperature=None,
+                            top_p=None,
+                            max_new_tokens=max_new_tokens,
+                            return_dict_in_generate=True,
+                            output_scores=(not self.output_str_only),
                         )
                         output_strs.append(
-                            self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
+                            self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
                         )
-
-                        logits = self.model.forward(input_ids).logits[0]
-                        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
-                        logprobs, top_indices = torch.topk(
-                            logprobs, k=NUM_TOP_LOGPROBS, dim=-1
-                        )
-                        # print("index", top_indices)
-                        prefill_logprobs.append(logprobs.tolist())
-                        del logits
-                        del logprobs
+                        if not self.output_str_only:
+                            # outputs.scores: (num_token, 1, vocab_size)
+                            top_output_logprobs.append(
+                                [
+                                    get_top_logprobs(
+                                        logits[0], NUM_TOP_LOGPROBS
+                                    ).tolist()
+                                    for logits in outputs.scores
+                                ]
+                            )
+                            del outputs
+
+                            input_logits = self.model.forward(input_ids).logits[0]
+                            top_input_logprobs.append(
+                                get_top_logprobs(
+                                    input_logits, NUM_TOP_LOGPROBS
+                                ).tolist()
+                            )
+                            del input_logits
 
                     out_queue.put(
                         ModelOutput(
-                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
+                            output_strs=output_strs,
+                            top_input_logprobs=top_input_logprobs,
+                            top_output_logprobs=top_output_logprobs,
                         )
                     )
 
                 else:
+                    assert not self.output_str_only
                     logits = self.model.encode(prompts).tolist()
-
                     out_queue.put(ModelOutput(embed_logits=logits))
 
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
+        lora_paths=None,
     ):
-        self.in_queue.put((prompts, max_new_tokens))
+        self.in_queue.put((prompts, max_new_tokens, lora_paths))
         return self.out_queue.get()
 
     def terminate(self):
@@ -173,6 +215,10 @@ class SRTRunner:
         is_generation,
         tp_size=1,
         port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+        lora_paths=None,
+        max_loras_per_batch=4,
+        disable_cuda_graph=False,
+        disable_radix_cache=False,
     ):
         self.is_generation = is_generation
         self.runtime = Runtime(
@@ -183,21 +229,28 @@
             mem_fraction_static=0.69,
             trust_remote_code=False,
             is_embedding=not self.is_generation,
+            lora_paths=lora_paths,
+            max_loras_per_batch=max_loras_per_batch,
+            disable_cuda_graph=disable_cuda_graph,
+            disable_radix_cache=disable_radix_cache,
         )
 
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
         max_new_tokens=8,
+        lora_paths=None,
    ):
        if self.is_generation:
            # the return value contains logprobs from prefill
            output_strs = []
            top_input_logprobs = []
+            top_output_logprobs = []
            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            for prompt in prompts:
+            for i, prompt in enumerate(prompts):
                response = self.runtime.generate(
                    prompt,
+                    lora_path=lora_paths[i] if lora_paths else None,
                    sampling_params=sampling_params,
                    return_logprob=True,
                    logprob_start_len=0,
@@ -219,9 +272,48 @@
                         ]
                     ]
                 )
+                top_output_logprobs.append(
+                    [
+                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                        for x in response["meta_info"]["output_top_logprobs"]
+                    ]
+                )
+
+            return ModelOutput(
+                output_strs=output_strs,
+                top_input_logprobs=top_input_logprobs,
+                top_output_logprobs=top_output_logprobs,
+            )
+        else:
+            response = self.runtime.encode(prompts)
+            response = json.loads(response)
+            logits = [x["embedding"] for x in response]
+            return ModelOutput(embed_logits=logits)
+
+    def batch_forward(
+        self,
+        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        max_new_tokens=8,
+        lora_paths=None,
+    ):
+        """
+        testing serving by sending all prompts once
+        only return output strings and no logprobs
+        """
+        if self.is_generation:
+            # the return value contains logprobs from prefill
+            output_strs = []
+            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+            response = self.runtime.generate(
+                prompts,
+                lora_path=lora_paths if lora_paths else None,
+                sampling_params=sampling_params,
+            )
+            response = json.loads(response)
+            output_strs = [r["text"] for r in response]
 
             return ModelOutput(
-                output_strs=output_strs, top_input_logprobs=top_input_logprobs
+                output_strs=output_strs,
             )
         else:
             response = self.runtime.encode(prompts)
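
Note: both test runners now thread LoRA adapters through their forward paths. A rough usage sketch, assuming the constructor signatures shown above; the model id and adapter path below are placeholders, not part of this diff:

import torch
from sglang.test.runners import HFRunner, SRTRunner

prompts = ["The capital of France is"]
adapters = ["/path/to/lora_adapter"]  # one entry per prompt; None falls back to the base model

# HuggingFace reference runner: the adapter is applied per prompt via PeftModel
hf = HFRunner("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, is_generation=True)
hf_out = hf.forward(prompts, max_new_tokens=8, lora_paths=adapters)
hf.terminate()

# SRT runner: adapters are passed to the Runtime at construction time
srt = SRTRunner(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    is_generation=True,
    lora_paths=adapters,
    disable_cuda_graph=True,
    disable_radix_cache=True,
)
srt_out = srt.forward(prompts, max_new_tokens=8, lora_paths=adapters)
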
sglang/test/test_programs.py CHANGED
@@ -7,7 +7,7 @@ import time
 import numpy as np
 
 import sglang as sgl
-from sglang.utils import fetch_and_cache_jsonl
+from sglang.utils import download_and_cache_file, read_jsonl
 
 
 def test_few_shot_qa():
@@ -456,10 +456,6 @@ def test_chat_completion_speculative():
 def test_hellaswag_select():
     """Benchmark the accuracy of sgl.select on the HellaSwag dataset."""
 
-    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
-    lines = fetch_and_cache_jsonl(url)
-
-    # Construct prompts
     def get_one_example(lines, i, include_answer):
         ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
         if include_answer:
@@ -472,6 +468,12 @@ def test_hellaswag_select():
             ret += get_one_example(lines, i, True) + "\n\n"
         return ret
 
+    # Read data
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
+
+    # Construct prompts
     num_questions = 200
     num_shots = 20
     few_shot_examples = get_few_shot_examples(lines, num_shots)
sglang/test/test_utils.py CHANGED
@@ -7,6 +7,7 @@ import subprocess
 import threading
 import time
 from functools import partial
+from types import SimpleNamespace
 from typing import Callable, List, Optional
 
 import numpy as np
@@ -14,6 +15,7 @@ import requests
 import torch
 import torch.nn.functional as F
 
+from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
@@ -28,7 +30,13 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Meta-Llama-3.1-70B-Instruc
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8"
 
-if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
+
+def is_in_ci():
+    """Return whether it is in CI runner."""
+    return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
+if is_in_ci():
     DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
     DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
 else:
@@ -501,3 +509,79 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
 
 def get_similarities(vec1, vec2):
     return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
+
+
+def run_bench_serving(model, num_prompts, request_rate, other_server_args):
+    # Launch the server
+    base_url = DEFAULT_URL_FOR_TEST
+    process = popen_launch_server(
+        model,
+        base_url,
+        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+        other_args=other_server_args,
+    )
+
+    # Run benchmark
+    args = SimpleNamespace(
+        backend="sglang",
+        base_url=base_url,
+        host=None,
+        port=None,
+        dataset_name="random",
+        dataset_path="",
+        model=None,
+        tokenizer=None,
+        num_prompts=num_prompts,
+        sharegpt_output_len=None,
+        random_input_len=4096,
+        random_output_len=2048,
+        random_range_ratio=0.0,
+        request_rate=request_rate,
+        multi=None,
+        seed=0,
+        output_file=None,
+        disable_tqdm=False,
+        disable_stream=False,
+        disable_ignore_eos=False,
+        extra_request_body=None,
+    )
+
+    try:
+        res = run_benchmark(args)
+    finally:
+        kill_child_process(process.pid)
+
+    assert res["completed"] == num_prompts
+    return res
+
+
+def run_bench_latency(model, other_args):
+    command = [
+        "python3",
+        "-m",
+        "sglang.bench_latency",
+        "--model-path",
+        model,
+        "--batch-size",
+        "1",
+        "--input",
+        "128",
+        "--output",
+        "8",
+        *other_args,
+    ]
+    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+    try:
+        stdout, stderr = process.communicate()
+        output = stdout.decode()
+        error = stderr.decode()
+        print(f"Output: {output}", flush=True)
+        print(f"Error: {error}", flush=True)
+
+        lastline = output.split("\n")[-3]
+        output_throughput = float(lastline.split(" ")[-2])
+    finally:
+        kill_child_process(process.pid)
+
+    return output_throughput
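
Note: run_bench_serving launches a server, drives sglang.bench_serving.run_benchmark against it, and always tears the server down in the finally block. A hedged sketch of how a test might call it; the model id and argument values here are illustrative, not part of this diff:

from sglang.test.test_utils import is_in_ci, run_bench_serving

if not is_in_ci():  # e.g. skip the heavyweight path outside of a nightly run
    res = run_bench_serving(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        num_prompts=10,
        request_rate=16,
        other_server_args=[],
    )
    assert res["completed"] == 10  # the helper already asserts this internally
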
sglang/utils.py CHANGED
@@ -12,7 +12,7 @@ import urllib.request
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps
-from typing import Union
+from typing import Optional, Union
 
 import numpy as np
 import requests
@@ -38,13 +38,11 @@ def is_same_type(values: list):
 
 def read_jsonl(filename: str):
     """Read a JSONL file."""
-    rets = []
     with open(filename) as fin:
        for line in fin:
            if line.startswith("#"):
                continue
-            rets.append(json.loads(line))
-    return rets
+            yield json.loads(line)
 
 
 def dump_state_text(filename: str, states: list, mode: str = "w"):
@@ -264,38 +262,35 @@ class LazyImport:
         return module(*args, **kwargs)
 
 
-def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"):
-    """Read and cache a jsonl file from a url."""
+def download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Read and cache a file from a url."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
 
     # Check if the cache file already exists
-    if os.path.exists(cache_file):
-        print("Loading data from cache...")
-        with open(cache_file, "r") as f:
-            data = [json.loads(line) for line in f]
-    else:
-        print("Downloading data from URL...")
-        # Stream the response to show the progress bar
-        response = requests.get(url, stream=True)
-        response.raise_for_status()  # Check for request errors
-
-        # Total size of the file in bytes
-        total_size = int(response.headers.get("content-length", 0))
-        chunk_size = 1024  # Download in chunks of 1KB
-
-        # Use tqdm to display the progress bar
-        with open(cache_file, "wb") as f, tqdm(
-            desc=cache_file,
-            total=total_size,
-            unit="B",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as bar:
-            for chunk in response.iter_content(chunk_size=chunk_size):
-                f.write(chunk)
-                bar.update(len(chunk))
-
-        # Convert the data to a list of dictionaries
-        with open(cache_file, "r") as f:
-            data = [json.loads(line) for line in f]
-
-    return data
+    if os.path.exists(filename):
+        return filename
+
+    print(f"Downloading from {url} to {filename}")
+
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors
+
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
+
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as f, tqdm(
+        desc=filename,
+        total=total_size,
+        unit="B",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            f.write(chunk)
+            bar.update(len(chunk))
+
+    return filename
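
Note: read_jsonl is now a generator and fetch_and_cache_jsonl is replaced by download_and_cache_file, which caches under /tmp/<basename> by default and returns the local path instead of parsed data. Callers that need random access wrap the generator in list(), as the updated tests above do. A small usage sketch following that pattern:

from sglang.utils import download_and_cache_file, read_jsonl

url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)  # downloads only on a cache miss, returns the path
lines = list(read_jsonl(filename))       # materialize the generator for indexed access
print(len(lines), lines[0]["question"][:60])
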
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.3.0"
+__version__ = "0.3.1"