sglang 0.4.4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -318,6 +318,10 @@ class Qwen25Detector(BaseFormatDetector):
         self.bot_token = "<tool_call>"
         self.eot_token = "</tool_call>"
 
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a Qwen 2.5 format tool call."""
+        return self.bot_token in text
+
     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
         """
         One-time parsing: Detects and parses tool calls in the provided text.
@@ -352,6 +356,10 @@ class MistralDetector(BaseFormatDetector):
         self.bot_token = "[TOOL_CALLS] ["
         self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
 
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a Mistral format tool call."""
+        return self.bot_token in text
+
     def _clean_text(self, text: str) -> str:
         """
        clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
@@ -397,12 +405,21 @@ class Llama32Detector(BaseFormatDetector):
         super().__init__()
         self.bot_token = "<|python_tag|>"
 
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a Llama 3.2 format tool call."""
+        # depending on the prompt format the Llama model may or may not
+        # prefix the output with the <|python_tag|> token
+        return "<|python_tag|>" in text or text.startswith("{")
+
     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
         """Parse function calls from text, handling multiple JSON objects."""
-        if "<|python_tag|>" not in text:
+        if "<|python_tag|>" not in text and not text.startswith("{"):
             return []
 
-        _, action_text = text.split("<|python_tag|>")
+        if "<|python_tag|>" in text:
+            _, action_text = text.split("<|python_tag|>")
+        else:
+            action_text = text
 
         # Split by semicolon and process each part
         json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
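The Llama 3.2 hunk above relaxes detection so that raw JSON output without the `<|python_tag|>` prefix is also treated as a tool call. A minimal standalone sketch of just that fallback logic (the helper name is illustrative, not the sglang API; the real `detect_and_parse` additionally validates the call against the provided tools):

```python
import json

def extract_llama32_calls(text: str) -> list:
    # Mirrors the patched branch: strip the optional <|python_tag|> prefix,
    # then split on ";" and parse each piece as a JSON tool call.
    if "<|python_tag|>" not in text and not text.startswith("{"):
        return []
    if "<|python_tag|>" in text:
        _, action_text = text.split("<|python_tag|>")
    else:
        action_text = text
    calls = []
    for part in (p.strip() for p in action_text.split(";")):
        if not part:
            continue
        try:
            calls.append(json.loads(part))
        except json.JSONDecodeError:
            continue
    return calls

# Both output shapes are now recognized:
print(extract_llama32_calls('<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}}'))
print(extract_llama32_calls('{"name": "get_weather", "arguments": {"city": "Paris"}}'))
```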
@@ -501,6 +518,20 @@ class FunctionCallParser:
         self.multi_format_parser = MultiFormatParser(detectors)
         self.tools = tools
 
+    def has_tool_call(self, text: str) -> bool:
+        """
+        Check if the given text contains a tool call in the format supported by this parser.
+        This delegates to the detector's implementation.
+
+        :param text: The text to check for tool calls
+        :return: True if the text contains a tool call, False otherwise
+        """
+        # Check all detectors in the multi_format_parser
+        for detector in self.multi_format_parser.detectors:
+            if detector.has_tool_call(text):
+                return True
+        return False
+
     def parse_non_stream(self, full_text: str):
         """
         Non-streaming call: one-time parsing
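`FunctionCallParser.has_tool_call` asks each registered detector whether its start token (or format) appears in the text. A toy sketch of that delegation pattern, with illustrative class names rather than the sglang API:

```python
# Toy detectors mirroring the bot_token checks added in the hunks above.
class ToyQwenDetector:
    bot_token = "<tool_call>"

    def has_tool_call(self, text: str) -> bool:
        return self.bot_token in text


class ToyMistralDetector:
    bot_token = "[TOOL_CALLS] ["

    def has_tool_call(self, text: str) -> bool:
        return self.bot_token in text


class ToyParser:
    def __init__(self, detectors):
        self.detectors = detectors

    def has_tool_call(self, text: str) -> bool:
        # True as soon as any registered detector recognizes its format.
        return any(d.has_tool_call(text) for d in self.detectors)


parser = ToyParser([ToyQwenDetector(), ToyMistralDetector()])
print(parser.has_tool_call('<tool_call>{"name": "f"}</tool_call>'))  # True
print(parser.has_tool_call("plain text answer"))                     # False
```

A server can use a check like this to decide up front whether a generation needs tool-call parsing at all, instead of attempting a parse and inspecting the result.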
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import functools
+import logging
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Union
 
 import torch
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
@@ -86,6 +90,27 @@ def get_attention_dp_size():
     return _DP_SIZE
 
 
+@contextmanager
+def disable_dp_size():
+    """Patch the tp group temporarily until this function ends.
+
+    This method is for draft workers of speculative decoding to run draft model
+    with different tp degree from that of target model workers.
+
+    Args:
+        tp_group (GroupCoordinator): the tp group coordinator
+    """
+    global _DP_SIZE
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+
+    old_dp_size = _DP_SIZE
+    _DP_SIZE = 1
+    try:
+        yield
+    finally:
+        _DP_SIZE = old_dp_size
+
+
 def get_dp_local_info(forward_batch: ForwardBatch):
     dp_rank = get_attention_dp_rank()
 
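`disable_dp_size` is a context manager that temporarily forces the module-level `_DP_SIZE` to 1 and restores the previous value on exit, even if the body raises. A minimal self-contained sketch of the same pattern (the global and function name below are stand-ins, not the sglang module):

```python
from contextlib import contextmanager

_DP_SIZE = 4  # stand-in for the module-level global patched above


@contextmanager
def disable_dp_size_sketch():
    # Temporarily pretend DP is disabled; always restore the old value.
    global _DP_SIZE
    assert _DP_SIZE is not None, "dp attention not initialized!"
    old_dp_size = _DP_SIZE
    _DP_SIZE = 1
    try:
        yield
    finally:
        _DP_SIZE = old_dp_size


print(_DP_SIZE)          # 4
with disable_dp_size_sketch():
    print(_DP_SIZE)      # 1 while e.g. a draft worker runs with its own tp degree
print(_DP_SIZE)          # 4 again, even if the block had raised
```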
@@ -159,7 +184,8 @@ def dp_gather(
         layer_id != "embedding" or get_attention_tp_rank() == 0
     ):
         assert (
-            global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
+            global_tokens.untyped_storage().data_ptr()
+            != local_tokens.untyped_storage().data_ptr()
         ), "aliasing between global_tokens and local_tokens not allowed"
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
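This hunk switches the aliasing assertion from `Tensor.storage()` (which returns a TypedStorage, deprecated in favor of `untyped_storage()` on recent PyTorch) to `Tensor.untyped_storage()`; both identify the underlying buffer, so comparing `data_ptr()` still detects whether two tensors share memory. A short illustration:

```python
import torch

a = torch.zeros(8)
b = a.view(2, 4)    # b is a view: it aliases a's memory
c = torch.zeros(8)  # separate allocation

# Equal storage data pointers mean the tensors share the same underlying buffer.
print(a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr())  # True
print(a.untyped_storage().data_ptr() == c.untyped_storage().data_ptr())  # False
```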
@@ -174,8 +200,9 @@ def dp_gather(
         torch.ops.sglang.inplace_all_reduce(
             global_tokens, group_name=get_tp_group().unique_name
         )
+
     else:
-        global_tokens = tensor_model_parallel_all_reduce(global_tokens)
+        global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
 
 
 def dp_scatter(
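The second change in this hunk replaces a plain rebinding with in-place slice assignment: `global_tokens = all_reduce(...)` only rebinds the local name, so a caller holding the original buffer never sees the reduced values, while `global_tokens[:] = ...` copies the result into that buffer. A minimal illustration of the difference:

```python
import torch


def rebind(buf: torch.Tensor) -> None:
    # Rebinds the local name only; the caller's tensor is untouched.
    buf = buf + 1


def write_in_place(buf: torch.Tensor) -> None:
    # Slice assignment copies into the storage the caller already holds.
    buf[:] = buf + 1


x = torch.zeros(4)
rebind(x)
print(x)            # tensor([0., 0., 0., 0.]) -- unchanged
write_in_place(x)
print(x)            # tensor([1., 1., 1., 1.]) -- caller sees the update
```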
@@ -186,6 +213,7 @@ def dp_scatter(
     # local_num_tokens is not necessarily the same as local_tokens.shape[0],
     # since local_tokens may be padded for cuda graph
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+
     local_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
@@ -0,0 +1,411 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+fused_softcap_autotune = triton.autotune(
+    configs=[
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 32768}, num_warps=32),
+    ],
+    key=["n_ele"],
+)
+
+
+@triton.jit
+def fused_softcap_kernel(
+    output_ptr,
+    input_ptr,
+    n_ele,
+    softcap_const: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_ele
+    x = tl.load(input_ptr + offsets, mask=mask)
+    fx = x.to(tl.float32)
+    fxs = fx / softcap_const
+    exped = tl.exp(2 * fxs)
+    top = exped - 1
+    bottom = exped + 1
+    output = top / bottom * softcap_const
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+fused_softcap_kernel_autotuned = fused_softcap_autotune(fused_softcap_kernel)
+
+
+def fused_softcap(x, softcap_const, autotune=False):
+    output = torch.empty_like(x, dtype=torch.float32)
+    n_elements = output.numel()
+    if autotune:
+        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+        fused_softcap_kernel_autotuned[grid](output, x, n_elements, softcap_const)
+    else:
+        fused_softcap_kernel[(triton.cdiv(n_elements, 128),)](
+            output, x, n_elements, softcap_const, BLOCK_SIZE=128, num_warps=8
+        )
+    return output
+
+
+# cast to float + softcap
+class Softcap:
+    def __init__(self, softcap_const: float):
+        self.softcap_const = softcap_const
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.is_cuda:
+            return self.forward_cuda(x)
+        else:
+            return self.forward_native(x)
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.tanh(x.float() / self.softcap_const) * self.softcap_const
+
+    def forward_cuda(self, x: torch.Tensor, autotune=False) -> torch.Tensor:
+        return fused_softcap(x, self.softcap_const, autotune=autotune)
+
+
+rmsnorm_autotune = triton.autotune(
+    configs=[
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4),
+    ],
+    key=["hidden_dim"],
+)
+
+
+@triton.jit
+def fused_dual_residual_rmsnorm_kernel(
+    output_ptr,
+    mid_ptr,
+    activ_ptr,
+    residual_ptr,
+    weight1_ptr,
+    weight2_ptr,
+    eps: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    input_start = pid * hidden_dim
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
+    a = a_.to(tl.float32)
+    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
+
+    r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0)
+    w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0)
+    w1 = w1_.to(tl.float32)
+
+    a2r = r + (a / rms * w1).to(r.dtype)
+    tl.store(
+        mid_ptr + input_start + offsets,
+        a2r,
+        mask=mask,
+    )
+
+    a2r = a2r.to(tl.float32)
+    rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps)
+
+    w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0)
+    w2 = w2_.to(tl.float32)
+
+    tl.store(
+        output_ptr + input_start + offsets,
+        a2r / rms2 * w2,  # implicitly casts to output dtype here
+        mask=mask,
+    )
+
+
+fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
+    fused_dual_residual_rmsnorm_kernel
+)
+
+
+def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
+    assert len(x.shape) == 2
+    assert x.shape == residual.shape and x.dtype == residual.dtype
+    output, mid = torch.empty_like(x), torch.empty_like(x)
+    bs, hidden_dim = x.shape
+    if autotune:
+        fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
+            output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
+        )
+    else:
+        config = {
+            "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+            "num_warps": max(
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            ),
+        }
+
+        fused_dual_residual_rmsnorm_kernel[(bs,)](
+            output,
+            mid,
+            x,
+            residual,
+            weight1,
+            weight2,
+            eps=eps,
+            hidden_dim=hidden_dim,
+            **config,
+        )
+
+    return output, mid
+
+
+@triton.jit
+def fused_rmsnorm_kernel(
+    output_ptr,
+    activ_ptr,
+    weight_ptr,
+    eps: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    input_start = pid * hidden_dim
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
+    a = a_.to(tl.float32)
+    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
+
+    w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0)
+    w1 = w1_.to(tl.float32)
+
+    a_rms = a / rms * w1
+
+    tl.store(
+        output_ptr + input_start + offsets,
+        a_rms,  # implicitly casts to output dtype here
+        mask=mask,
+    )
+
+
+def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
+    assert len(x.shape) == 2
+    if inplace:
+        output = x
+    else:
+        output = torch.empty_like(x)
+    bs, hidden_dim = x.shape
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+        ),
+    }
+
+    fused_rmsnorm_kernel[(bs,)](
+        output, x, weight, eps=eps, hidden_dim=hidden_dim, **config
+    )
+    return output
+
+
+class FusedDualResidualRMSNorm:
+    """
+    Fused implementation of
+        y = RMSNorm2(RMSNorm1(x) + residual))
+    """
+
+    def __init__(self, rmsnorm1, rmsnorm2) -> None:  # the one after rmsnorm1
+        self.rmsnorm1 = rmsnorm1
+        self.rmsnorm2 = rmsnorm2
+        self.variance_epsilon = self.rmsnorm1.variance_epsilon
+        assert self.rmsnorm1.variance_epsilon == self.rmsnorm2.variance_epsilon
+        assert self.rmsnorm1.weight.shape == self.rmsnorm2.weight.shape
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(
+        self, x: torch.Tensor, residual: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if x.is_cuda:
+            return self.forward_cuda(x, residual)
+        else:
+            return self.forward_flashinfer(x, residual)
+
+    def forward_cuda(
+        self, x: torch.Tensor, residual: torch.Tensor, autotune=False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return fused_dual_residual_rmsnorm(
+            x,
+            residual,
+            self.rmsnorm1.weight,
+            self.rmsnorm2.weight,
+            self.variance_epsilon,
+            autotune=autotune,
+        )
+
+    def forward_flashinfer(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        normed1 = self.rmsnorm1(x)
+        residual = normed1 + residual
+        return self.rmsnorm2(residual), residual
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        normed1 = self.rmsnorm1.forward_native(x)
+        residual = normed1 + residual
+        return self.rmsnorm2.forward_native(residual), residual
+
+
+# gelu on first half of vector
+@triton.jit
+def gelu_and_mul_kernel(
+    out_hidden_states_ptr,  # (bs, hidden_dim)
+    out_scales_ptr,  # (bs,)
+    hidden_states_ptr,  # (bs, hidden_dim * 2)
+    quant_max: tl.constexpr,
+    static_scale: tl.constexpr,
+    hidden_dim: tl.constexpr,  # the output hidden_dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    input_start = pid * hidden_dim * 2
+    output_start = pid * hidden_dim
+
+    input1_offs = tl.arange(0, BLOCK_SIZE)
+    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
+    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+    output_offs = tl.arange(0, BLOCK_SIZE)
+
+    x1 = tl.load(
+        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+    x3 = tl.load(
+        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+
+    # gelu
+    # cast down before mul to better match training?
+    gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1
+    out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+    if quant_max is not None:
+        raise NotImplementedError()
+
+    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+def gelu_and_mul_triton(
+    hidden_states,
+    scales=None,
+    quantize=None,  # dtype to quantize to
+    out=None,
+):
+    bs, in_hidden_dim = hidden_states.shape
+    hidden_dim = in_hidden_dim // 2
+
+    if out is None:
+        out_hidden_states = torch.empty(
+            (bs, hidden_dim),
+            dtype=quantize or hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    else:
+        assert out.shape == (bs, hidden_dim)
+        assert out.dtype == (quantize or hidden_states.dtype)
+        out_hidden_states = out
+    out_scales = None
+    static_scale = False
+    if quantize is not None:
+        if scales is None:
+            out_scales = torch.empty(
+                (bs,), dtype=torch.float32, device=hidden_states.device
+            )
+        else:
+            out_scales = scales
+            static_scale = True
+
+    config = {
+        # 8 ele per thread (not tuned)
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+        ),
+    }
+
+    gelu_and_mul_kernel[(bs,)](
+        out_hidden_states,
+        out_scales,
+        hidden_states,
+        quant_max=torch.finfo(quantize).max if quantize is not None else None,
+        static_scale=static_scale,
+        hidden_dim=hidden_dim,
+        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+        **config,
+    )
+
+    if quantize is not None:
+        return out_hidden_states, out_scales
+    else:
+        return out_hidden_states, None
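The new file adds Triton kernels for softcapping, fused dual-residual RMSNorm, and GELU-and-mul. Their math can be checked against a few lines of plain PyTorch; the reference functions below are a sketch for that purpose (the names are illustrative, and the fused kernels themselves require a CUDA device plus Triton):

```python
import torch


def softcap_reference(x: torch.Tensor, softcap_const: float) -> torch.Tensor:
    # Same math as Softcap.forward_native: tanh(x / c) * c.
    return torch.tanh(x.float() / softcap_const) * softcap_const


def softcap_via_exp(x: torch.Tensor, softcap_const: float) -> torch.Tensor:
    # The identity the Triton kernel uses: tanh(t) == (exp(2t) - 1) / (exp(2t) + 1).
    t = x.float() / softcap_const
    e = torch.exp(2 * t)
    return (e - 1) / (e + 1) * softcap_const


def dual_rmsnorm_reference(x, residual, w1, w2, eps):
    # Mirrors fused_dual_residual_rmsnorm: mid = residual + RMSNorm1(x),
    # out = RMSNorm2(mid), with the per-row RMS computed in float32.
    def rms(v):
        return torch.sqrt(v.float().pow(2).mean(-1, keepdim=True) + eps)

    mid = residual + (x.float() / rms(x) * w1.float()).to(residual.dtype)
    out = (mid.float() / rms(mid) * w2.float()).to(x.dtype)
    return out, mid


x = torch.randn(4, 1024) * 100
print(torch.allclose(softcap_reference(x, 30.0), softcap_via_exp(x, 30.0), atol=1e-5))  # True

h = torch.randn(4, 64, dtype=torch.float16)
w = torch.ones(64, dtype=torch.float16)
out, mid = dual_rmsnorm_reference(h, torch.zeros_like(h), w, w, eps=1e-6)
print(out.shape, mid.shape)  # torch.Size([4, 64]) torch.Size([4, 64])
```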
@@ -23,6 +23,7 @@ import triton.language as tl
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )