sglang-0.1.15-py3-none-any.whl → sglang-0.1.17-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/layers/token_attention.py

@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.managers.router.model_runner import global_server_args_dict
+from sglang.srt.managers.controller.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
@@ -16,6 +16,12 @@ else:
     REDUCE_TORCH_TYPE = torch.float16


+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel_stage1(
     Q,
@@ -35,6 +41,7 @@ def _fwd_kernel_stage1(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -77,6 +84,10 @@ def _fwd_kernel_stage1(
     ).to(REDUCE_TRITON_TYPE)
     att_value = tl.sum(q[None, :] * k, 1)
     att_value *= sm_scale
+
+    if logit_cap > 0:
+        att_value = logit_cap * tanh(att_value / logit_cap)
+
     off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
     tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)

@@ -165,6 +176,7 @@ def _token_att_m_fwd(
     B_Start_Loc,
     B_Seqlen,
     max_len_in_batch,
+    logit_cap,
 ):
     BLOCK = 32
     # shape constraints
@@ -223,6 +235,7 @@ def _token_att_m_fwd(
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
+        logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
     )
@@ -304,6 +317,7 @@ def token_attention_fwd(
     max_len_in_batch,
     other_kv_index,
     total_num_tokens,
+    logit_cap=-1,
     att_m=None,
 ):
     if att_m is None:
@@ -320,6 +334,7 @@ def token_attention_fwd(
         b_start_loc,
         b_seq_len,
         max_len_in_batch,
+        logit_cap,
     )
     _token_softmax_reducev_fwd(
         att_m,
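The logit_cap additions above implement attention logit soft-capping: a tanh helper is built from tl.sigmoid via the identity tanh(x) = 2·sigmoid(2x) − 1 (presumably because a tanh primitive was not available in the Triton version targeted here), and each score is rescaled to logit_cap · tanh(score / logit_cap), which bounds it smoothly to (−logit_cap, logit_cap). The default logit_cap=-1 keeps the path disabled; the feature likely supports the newly added grok.py model, since Grok-1 caps attention logits this way. A minimal PyTorch sketch of the same transform (illustrative only, not part of the diff):

import torch

def soft_cap(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # Mirrors the kernel's branch: a non-positive cap (the default -1) is a no-op.
    if logit_cap > 0:
        return logit_cap * torch.tanh(scores / logit_cap)
    return scores

scores = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap(scores, 30.0))  # ≈ tensor([-29.92, -4.95, 0.00, 4.95, 29.92])
print(soft_cap(scores, -1.0))  # unchanged

Unlike a hard clip, the tanh cap keeps the mapping differentiable and strictly monotonic, so score ordering is preserved while extreme values are squashed.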
sglang/srt/managers/controller/dp_worker.py (new file)

@@ -0,0 +1,110 @@
+"""A data parallel worker thread."""
+import asyncio
+import logging
+import queue
+import threading
+from typing import List, Callable
+
+import uvloop
+import zmq
+
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.tp_worker import ModelTpClient
+from sglang.srt.managers.io_struct import BatchTokenIDOut
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger("srt.controller")
+CHECKING_INTERVAL = 5
+
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+
+class DataParallelWorkerThread(threading.Thread):
+    def __init__(
+        self,
+        worker_id: int,
+        request_queue: queue.Queue,
+        detokenizer_port: int,
+        step_func: Callable,
+    ):
+        super(DataParallelWorkerThread, self).__init__()
+        self.worker_id = worker_id
+        self.request_queue = request_queue
+        self.liveness = True
+        self.request_dependency_delay = global_config.request_dependency_delay
+
+        context = zmq.asyncio.Context()
+        self.send_to_detokenizer = context.socket(zmq.PUSH)
+        self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")
+
+        self.step = step_func
+
+    async def loop_for_forward(self):
+        while self.liveness:
+            requests = []
+            while not self.request_queue.empty():
+                requests.append(self.request_queue.get())
+
+            out_pyobjs: List[BatchTokenIDOut] = []
+            try:
+                out_pyobjs = await self.step(requests)
+            except Exception:
+                for r in requests:
+                    self.request_queue.put(r)
+                logger.error(
+                    f"Worker thread {self.worker_id}: "
+                    f"failed to get back from Model Server\n"
+                    f"{get_exception_traceback()}"
+                )
+                self.liveness = False
+                # Crash the whole server when there are any errors.
+                # TODO(lianmin): make this an option.
+                kill_parent_process()
+                return
+
+            for obj in out_pyobjs:
+                self.send_to_detokenizer.send_pyobj(obj)
+
+            # async sleep for receiving the subsequent request and avoiding cache miss
+            if len(out_pyobjs) != 0:
+                has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
+                if has_finished:
+                    await asyncio.sleep(self.request_dependency_delay)
+            await asyncio.sleep(global_config.wait_for_new_request_delay)
+
+    async def monitoring(self):
+        while True:
+            await asyncio.sleep(CHECKING_INTERVAL)
+            # can plug in monitoring logic here
+
+    def run(self):
+        logger.info(f"DataParallelWorkerThread {self.worker_id} start")
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.create_task(self.monitoring())
+        loop.run_until_complete(self.loop_for_forward())
+
+
+def start_data_parallel_worker(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    model_overide_args,
+    gpu_ids: List[int],
+    worker_id: int,
+):
+    model_tp_client = ModelTpClient(
+        gpu_ids,
+        server_args,
+        port_args.model_port_args[worker_id],
+        model_overide_args,
+    )
+    worker_thread = DataParallelWorkerThread(
+        worker_id=worker_id,
+        request_queue=queue.Queue(),
+        detokenizer_port=port_args.detokenizer_port,
+        step_func=model_tp_client.step,
+    )
+    worker_thread.start()
+    return worker_thread
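The design in dp_worker.py gives each data parallel replica its own Python thread with a private asyncio event loop: the thread drains a thread-safe queue.Queue that the multi-worker controller (manager_multi.py in the file list) fills from outside, and pushes finished batches to the detokenizer over ZeroMQ. Below is a self-contained sketch of that thread-plus-private-loop pattern using only the standard library; all names are illustrative, not sglang's API.

import asyncio
import queue
import threading
import time

class WorkerThread(threading.Thread):
    """Owns a private event loop; drains a queue filled by another thread."""

    def __init__(self, worker_id: int):
        super().__init__(daemon=True)
        self.worker_id = worker_id
        self.request_queue: queue.Queue = queue.Queue()

    async def loop_for_forward(self):
        while True:
            # Drain everything queued so far, as DataParallelWorkerThread does.
            requests = []
            while not self.request_queue.empty():
                requests.append(self.request_queue.get())
            for r in requests:
                print(f"worker {self.worker_id} handled {r!r}")
            await asyncio.sleep(0.01)  # stands in for the request/cache delays

    def run(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(self.loop_for_forward())

workers = [WorkerThread(i) for i in range(2)]
for w in workers:
    w.start()
# Round-robin dispatch, one plausible policy for a data parallel controller.
for i, req in enumerate(["a", "b", "c", "d"]):
    workers[i % len(workers)].request_queue.put(req)
time.sleep(0.1)  # let the daemon loops drain before the process exits

Because queue.Queue is thread-safe, the controller never touches a worker's event loop directly; the only cross-thread channel is the queue, which is what makes the per-thread loops safe despite asyncio objects not being thread-safe themselves.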