sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -1,58 +1,81 @@
+"""Common utilities."""
+
 import base64
+import fcntl
+import logging
+import multiprocessing
 import os
 import random
 import socket
-import sys
+import struct
 import time
-import traceback
+from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import List, Optional
 
 import numpy as np
+import psutil
 import requests
+import rpyc
 import torch
-import torch.distributed as dist
+import triton
+from fastapi.responses import JSONResponse
+from packaging import version as pkg_version
+from rpyc.utils.server import ThreadedServer
+from starlette.middleware.base import BaseHTTPMiddleware
 
-is_show_cost_time = False
+logger = logging.getLogger(__name__)
 
 
-def mark_cost_time(func_name):
-    def inner_func(func):
-        def time_func(*args, **kwargs):
-            if dist.get_rank() in [0, 1] and is_show_cost_time:
-                torch.cuda.synchronize()
-                start_time = time.time()
-                ans = func(*args, **kwargs)
-                torch.cuda.synchronize()
-                print(func_name, "cost time:", (time.time() - start_time) * 1000)
-                return ans
-            else:
-                torch.cuda.synchronize()
-                ans = func(*args, **kwargs)
-                torch.cuda.synchronize()
-                return ans
+show_time_cost = False
+time_infos = {}
 
-        return time_func
 
-    return inner_func
+def enable_show_time_cost():
+    global show_time_cost
+    show_time_cost = True
 
 
-time_mark = {}
+class TimeInfo:
+    def __init__(self, name, interval=0.1, color=0, indent=0):
+        self.name = name
+        self.interval = interval
+        self.color = color
+        self.indent = indent
+
+        self.acc_time = 0
+        self.last_acc_time = 0
+
+    def check(self):
+        if self.acc_time - self.last_acc_time > self.interval:
+            self.last_acc_time = self.acc_time
+            return True
+        return False
 
+    def pretty_print(self):
+        print(f"\x1b[{self.color}m", end="")
+        print("-" * self.indent * 2, end="")
+        print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m")
 
-def mark_start(key):
+
+def mark_start(name, interval=0.1, color=0, indent=0):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-    global time_mark
-    time_mark[key] = time.time()
-    return
+    if time_infos.get(name, None) is None:
+        time_infos[name] = TimeInfo(name, interval, color, indent)
+    time_infos[name].acc_time -= time.time()
 
 
-def mark_end(key, print_min_cost=0.0):
+def mark_end(name):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-    global time_mark
-    cost_time = (time.time() - time_mark[key]) * 1000
-    if cost_time > print_min_cost:
-        print(f"cost {key}:", cost_time)
+    time_infos[name].acc_time += time.time()
+    if time_infos[name].check():
+        time_infos[name].pretty_print()
 
 
 def calculate_time(show=False, min_cost_ms=0.0):
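Note: the removed mark_cost_time decorator depended on torch.distributed rank state; the replacement accumulates per-name wall-clock time in TimeInfo objects and only prints once the accumulated time exceeds interval. A minimal sketch of how the new helpers pair up; the timed matmul is an illustrative workload, not sglang code, and a CUDA device is assumed because both helpers call torch.cuda.synchronize():

# Sketch only: exercising enable_show_time_cost / mark_start / mark_end from this diff.
import torch

from sglang.srt.utils import enable_show_time_cost, mark_end, mark_start

enable_show_time_cost()  # without this, mark_start/mark_end return immediately

x = torch.randn(1024, 1024, device="cuda")
for _ in range(200):
    mark_start("matmul", interval=0.5, color=32, indent=1)
    _ = x @ x  # the timed region (illustrative workload)
    mark_end("matmul")  # prints roughly every 0.5 s of accumulated time under "matmul"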
@@ -74,83 +97,86 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
-def set_random_seed(seed: int) -> None:
-    random.seed(seed)
+def get_available_gpu_memory(gpu_id, distributed=False):
+    """
+    Get available memory for cuda:gpu_id device.
+    When distributed is True, the available memory is the minimum available memory of all GPUs.
+    """
+    num_gpus = torch.cuda.device_count()
+    assert gpu_id < num_gpus
 
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
+    if torch.cuda.current_device() != gpu_id:
+        print(
+            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+            "which may cause useless memory allocation for torch CUDA context.",
+        )
 
+    torch.cuda.empty_cache()
+    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
 
-def alloc_usable_network_port(num, used_list=()):
-    port_list = []
-    for port in range(10000, 65536):
-        if port in used_list:
-            continue
+    if distributed:
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
+            torch.device("cuda", gpu_id)
+        )
+        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
+        free_gpu_memory = tensor.item()
 
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            try:
-                s.bind(("", port))
-                port_list.append(port)
-            except socket.error:
-                pass
+    return free_gpu_memory / (1 << 30)
 
-        if len(port_list) == num:
-            return port_list
-    return None
+
+def set_random_seed(seed: int) -> None:
+    """Set the random seed for all libraries."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
 
 
-def check_port(port):
+def is_port_available(port):
+    """Return whether a port is available."""
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         try:
             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             s.bind(("", port))
+            s.listen(1)
             return True
         except socket.error:
             return False
 
 
-def handle_port_init(
+def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
     tp_size: int = 1,
+    dp_size: int = 1,
 ):
-    port = 30000 if port is None else port
-    additional_ports = [] if additional_ports is None else additional_ports
-    additional_ports = (
-        [additional_ports] if isinstance(additional_ports, int) else additional_ports
-    )
-    # first check on server port
-    if not check_port(port):
-        new_port = alloc_usable_network_port(1, used_list=[port])[0]
-        print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
-        port = new_port
-
-    # then we check on additional ports
-    additional_unique_ports = set(additional_ports) - {port}
-    # filter out ports that are already in use
-    can_use_ports = [port for port in additional_unique_ports if check_port(port)]
-
-    num_specified_ports = len(can_use_ports)
-    if num_specified_ports < 4 + tp_size:
-        addtional_can_use_ports = alloc_usable_network_port(
-            num=4 + tp_size - num_specified_ports, used_list=can_use_ports + [port]
-        )
-        can_use_ports.extend(addtional_can_use_ports)
+    """Allocate ports for all connections."""
+    if additional_ports:
+        ret_ports = [port] + additional_ports
+    else:
+        ret_ports = [port]
 
-    additional_ports = can_use_ports[: 4 + tp_size]
-    return port, additional_ports
+    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
+    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
 
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
+    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    while len(ret_ports) < num_ports_needed:
+        if cur_port not in ret_ports and is_port_available(cur_port):
+            ret_ports.append(cur_port)
+        cur_port += 1
 
-def get_exception_traceback():
-    etype, value, tb = sys.exc_info()
-    err_str = "".join(traceback.format_exception(etype, value, tb))
-    return err_str
+    if port is not None and ret_ports[0] != port:
+        logger.warn(
+            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
+        )
 
+    return ret_ports[0], ret_ports[1:num_ports_needed]
 
-def get_int_token_logit_bias(tokenizer, vocab_size):
-    from transformers import LlamaTokenizer, LlamaTokenizerFast
 
+def get_int_token_logit_bias(tokenizer, vocab_size):
+    """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
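The rewritten port allocation keeps the requested server port when it is free and then probes upward for the rest; per the in-code comment it needs 4 + dp_size * (1 + tp_size) ports in total (HTTP, tokenizer, controller, detokenizer, plus one NCCL port and tp_size ports per data-parallel worker). A rough usage sketch; the port numbers are illustrative and the function probes real sockets on the machine it runs on:

# Sketch only: calling the new allocate_init_ports added in this diff.
from sglang.srt.utils import allocate_init_ports

# With tp_size=2 and dp_size=1, num_ports_needed = 4 + 1 * (1 + 2) = 7: the first
# port is the HTTP server port, the remaining 6 go to the tokenizer, controller,
# detokenizer, and the NCCL/TP workers.
port, additional_ports = allocate_init_ports(
    port=30000,
    additional_ports=None,
    tp_size=2,
    dp_size=1,
)
print(port, additional_ports)  # e.g. 30000 and six nearby free ports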
@@ -164,14 +190,11 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
 
 def wrap_kernel_launcher(kernel):
     """A faster launcher for triton kernels."""
-    import torch.distributed as dist
-
-    if dist.is_initialized():
-        rank = dist.get_rank()
-    else:
-        rank = 0
+    if int(triton.__version__.split(".")[0]) >= 3:
+        return None
 
-    kernels = kernel.cache[rank].values()
+    gpu_id = torch.cuda.current_device()
+    kernels = kernel.cache[gpu_id].values()
     kernel = next(iter(kernels))
 
     # Different trition versions use different low-level names
@@ -231,20 +254,104 @@ def wrap_kernel_launcher(kernel):
 
 
 def is_multimodal_model(model):
-    if isinstance(model, str):
-        return "llava" in model or "yi-vl" in model
     from sglang.srt.model_config import ModelConfig
 
+    if isinstance(model, str):
+        model = model.lower()
+        return "llava" in model or "yi-vl" in model or "llava-next" in model
+
     if isinstance(model, ModelConfig):
         model_path = model.path.lower()
-        return "llava" in model_path or "yi-vl" in model_path
-    raise Exception("unrecognized type")
+        return (
+            "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+        )
+
+    raise ValueError("unrecognized type")
+
+
+def decode_video_base64(video_base64):
+    from PIL import Image
+
+    # Decode the base64 string
+    video_bytes = base64.b64decode(video_base64)
+
+    # Placeholder for the start indices of each PNG image
+    img_starts = []
+
+    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+    assert frame_format in [
+        "PNG",
+        "JPEG",
+    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
+
+    if frame_format == "PNG":
+        # Find each PNG start signature to isolate images
+        i = 0
+        while i < len(video_bytes) - 7:  # Adjusted for the length of the PNG signature
+            # Check if we found the start of a PNG file
+            if (
+                video_bytes[i] == 0x89
+                and video_bytes[i + 1] == 0x50
+                and video_bytes[i + 2] == 0x4E
+                and video_bytes[i + 3] == 0x47
+                and video_bytes[i + 4] == 0x0D
+                and video_bytes[i + 5] == 0x0A
+                and video_bytes[i + 6] == 0x1A
+                and video_bytes[i + 7] == 0x0A
+            ):
+                img_starts.append(i)
+                i += 8  # Skip the PNG signature
+            else:
+                i += 1
+    else:
+        # Find each JPEG start (0xFFD8) to isolate images
+        i = 0
+        while (
+            i < len(video_bytes) - 1
+        ):  # Adjusted for the length of the JPEG SOI signature
+            # Check if we found the start of a JPEG file
+            if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
+                img_starts.append(i)
+                # Move to the next byte to continue searching for the next image start
+                i += 2
+            else:
+                i += 1
+
+    frames = []
+    for start_idx in img_starts:
+        # Assuming each image is back-to-back, the end of one image is the start of another
+        # The last image goes until the end of the byte string
+        end_idx = (
+            img_starts[img_starts.index(start_idx) + 1]
+            if img_starts.index(start_idx) + 1 < len(img_starts)
+            else len(video_bytes)
+        )
+        img_bytes = video_bytes[start_idx:end_idx]
+
+        # Convert bytes to a PIL Image
+        img = Image.open(BytesIO(img_bytes))
+
+        # Convert PIL Image to a NumPy array
+        frame = np.array(img)
+
+        # Append the frame to the list of frames
+        frames.append(frame)
+
+    # Ensure there's at least one frame to avoid errors with np.stack
+    if frames:
+        return np.stack(frames, axis=0), img.size
+    else:
+        return np.array([]), (
+            0,
+            0,
+        )  # Return an empty array and size tuple if no frames were found
 
 
 def load_image(image_file):
     from PIL import Image
 
-    image = None
+    image = image_size = None
 
     if image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
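load_image now returns an (image, image_size) pair and, together with decode_video_base64, understands base64-encoded "video" payloads made of back-to-back PNG or JPEG frames, which are split on their signatures and stacked into a NumPy array. A small round-trip sketch; the two solid-color frames are made up for illustration:

# Sketch only: round-tripping two PNG frames through decode_video_base64.
import base64
from io import BytesIO

from PIL import Image

from sglang.srt.utils import decode_video_base64

frame_bytes = b""
for _ in range(2):
    buf = BytesIO()
    Image.new("RGB", (8, 8), color=(255, 0, 0)).save(buf, format="PNG")
    frame_bytes += buf.getvalue()  # frames are simply concatenated back-to-back

frames, size = decode_video_base64(base64.b64encode(frame_bytes).decode())
print(frames.shape, size)  # expected: (2, 8, 8, 3) (8, 8)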
@@ -255,7 +362,265 @@ def load_image(image_file):
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
         image = Image.open(BytesIO(base64.b64decode(image_file)))
+    elif image_file.startswith("video:"):
+        image_file = image_file.replace("video:", "")
+        image, image_size = decode_video_base64(image_file)
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
 
-    return image
+    return image, image_size
+
+
+def connect_rpyc_service(host, port):
+    repeat_count = 0
+    while repeat_count < 20:
+        try:
+            con = rpyc.connect(
+                host,
+                port,
+                config={
+                    "allow_public_attrs": True,
+                    "allow_pickle": True,
+                    "sync_request_timeout": 3600,
+                },
+            )
+            break
+        except ConnectionRefusedError as e:
+            time.sleep(1)
+        repeat_count += 1
+    if repeat_count == 20:
+        raise RuntimeError(f"Connect rpyc error: {e}")
+
+    return con.root
+
+
+def start_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600,
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def start_rpyc_service_process(service: rpyc.Service, port: int):
+    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
+    proc.start()
+    return proc
+
+
+def suppress_other_loggers():
+    from vllm.logger import logger as vllm_default_logger
+
+    vllm_default_logger.setLevel(logging.WARN)
+    logging.getLogger("vllm.config").setLevel(logging.ERROR)
+    logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
+        logging.WARN
+    )
+    logging.getLogger("vllm.selector").setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.WARN)
+
+
+def assert_pkg_version(pkg: str, min_version: str, message: str):
+    try:
+        installed_version = version(pkg)
+        if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
+            raise Exception(
+                f"{pkg} is installed with version {installed_version}, which "
+                f"is less than the minimum required version {min_version}. " + message
+            )
+    except PackageNotFoundError:
+        raise Exception(
+            f"{pkg} with minimum required version {min_version} is not installed. "
+            + message
+        )
+
+
+def kill_parent_process():
+    """Kill the parent process and all children of the parent process."""
+    current_process = psutil.Process()
+    parent_process = current_process.parent()
+    children = current_process.children(recursive=True)
+    for child in children:
+        if child.pid != current_process.pid:
+            os.kill(child.pid, 9)
+    os.kill(parent_process.pid, 9)
+
+
+def monkey_patch_vllm_p2p_access_check(gpu_id: int):
+    """
+    Monkey patch the slow p2p access check in vllm.
+    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
+    """
+
+    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+
+    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
+
+
+def monkey_patch_vllm_dummy_weight_loader():
+    """
+    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
+    """
+
+    from vllm.model_executor.model_loader.loader import (
+        CacheConfig,
+        DeviceConfig,
+        DummyModelLoader,
+        LoRAConfig,
+        ModelConfig,
+        MultiModalConfig,
+        ParallelConfig,
+        SchedulerConfig,
+        _initialize_model,
+        initialize_dummy_weights,
+        nn,
+        set_default_torch_dtype,
+    )
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+        multimodal_config: Optional[MultiModalConfig],
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        cache_config: CacheConfig,
+    ) -> nn.Module:
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                    lora_config,
+                    multimodal_config,
+                    cache_config,
+                )
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    quant_method.process_weights_after_loading(module)
+                # FIXME: Remove this after Mixtral is updated
+                # to use quant_method.
+                if hasattr(module, "process_weights_after_loading"):
+                    module.process_weights_after_loading()
+
+        # NOTE(woosuk): For accurate performance evaluation, we assign
+        # random values to the weights.
+        initialize_dummy_weights(model)
+        return model.eval()
+
+    setattr(DummyModelLoader, "load_model", load_model)
+
+
+API_KEY_HEADER_NAME = "X-API-Key"
+
+
+class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, api_key: str):
+        super().__init__(app)
+        self.api_key = api_key
+
+    async def dispatch(self, request, call_next):
+        # extract API key from the request headers
+        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
+        if not api_key_header or api_key_header != self.api_key:
+            return JSONResponse(
+                status_code=403,
+                content={"detail": "Invalid API Key"},
+            )
+        response = await call_next(request)
+        return response
+
+
+def get_ip_address(ifname):
+    """
+    Get the IP address of a network interface.
+
+    :param ifname: Name of the network interface (e.g., 'eth0')
+    :return: IP address of the network interface
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    ip_address = fcntl.ioctl(
+        s.fileno(),
+        0x8915,  # SIOCGIFADDR
+        struct.pack("256s", bytes(ifname[:15], "utf-8")),
+    )[20:24]
+    return socket.inet_ntoa(ip_address)
+
+
+def send_addrs_to_rank_0(model_port_args, server_args):
+    assert server_args.node_rank != 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get(
+        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
+    )
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+    ip_addr = [int(x) for x in ip_addr.split(".")]
+    addrs_tensor = torch.tensor(
+        ip_addr + model_port_args.model_tp_ports, dtype=torch.int
+    )
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(
+        backend="gloo",
+        init_method=init_method,
+        rank=server_args.node_rank,
+        world_size=server_args.nnodes,
+    )
+    dist.send(addrs_tensor, dst=0)
+    print(
+        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
+    )
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+def receive_addrs(model_port_args, server_args):
+    assert server_args.node_rank == 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get(
+        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
+    )
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(
+        backend="gloo",
+        init_method=init_method,
+        rank=server_args.node_rank,
+        world_size=server_args.nnodes,
+    )
+
+    for src_rank in range(1, server_args.nnodes):
+        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
+        dist.recv(tensor, src=src_rank)
+        ip = ".".join([str(x) for x in tensor[:4].tolist()])
+        ports = tensor[4:].tolist()
+        model_port_args.model_tp_ips[
+            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
+        ] = [ip] * num_tp_ports
+        model_port_args.model_tp_ports[
+            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
+        ] = ports
+        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
+
+    dist.barrier()
+    dist.destroy_process_group()
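The new APIKeyValidatorMiddleware rejects any request whose X-API-Key header does not match the configured key. A sketch of how a FastAPI app could mount it; the app and key below are illustrative, and presumably the sglang server wires this up itself when an API key is configured:

# Sketch only: mounting APIKeyValidatorMiddleware on a FastAPI app.
from fastapi import FastAPI

from sglang.srt.utils import APIKeyValidatorMiddleware

app = FastAPI()
app.add_middleware(APIKeyValidatorMiddleware, api_key="my-secret-key")


@app.get("/health")
async def health():
    return {"ok": True}

# Requests without "X-API-Key: my-secret-key" now get a 403 {"detail": "Invalid API Key"}.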