sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (110)
  1. sglang/bench_offline_throughput.py +55 -2
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +4 -3
  4. sglang/bench_serving.py +13 -0
  5. sglang/check_env.py +1 -1
  6. sglang/launch_server.py +3 -2
  7. sglang/srt/_custom_ops.py +118 -0
  8. sglang/srt/configs/device_config.py +17 -0
  9. sglang/srt/configs/load_config.py +84 -0
  10. sglang/srt/configs/model_config.py +161 -4
  11. sglang/srt/configs/qwen2vl.py +5 -8
  12. sglang/srt/constrained/outlines_backend.py +6 -1
  13. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  14. sglang/srt/distributed/__init__.py +3 -0
  15. sglang/srt/distributed/communication_op.py +34 -0
  16. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  17. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  19. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  20. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  21. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  22. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  24. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  25. sglang/srt/distributed/parallel_state.py +1275 -0
  26. sglang/srt/distributed/utils.py +223 -0
  27. sglang/srt/hf_transformers_utils.py +37 -1
  28. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  29. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  30. sglang/srt/layers/fused_moe_patch.py +20 -11
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/logits_processor.py +17 -3
  33. sglang/srt/layers/quantization/__init__.py +34 -0
  34. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  35. sglang/srt/lora/lora.py +1 -1
  36. sglang/srt/managers/data_parallel_controller.py +7 -11
  37. sglang/srt/managers/detokenizer_manager.py +7 -4
  38. sglang/srt/managers/image_processor.py +1 -1
  39. sglang/srt/managers/io_struct.py +48 -12
  40. sglang/srt/managers/schedule_batch.py +42 -36
  41. sglang/srt/managers/schedule_policy.py +7 -4
  42. sglang/srt/managers/scheduler.py +111 -46
  43. sglang/srt/managers/session_controller.py +0 -3
  44. sglang/srt/managers/tokenizer_manager.py +169 -100
  45. sglang/srt/managers/tp_worker.py +36 -3
  46. sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
  47. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  48. sglang/srt/model_executor/forward_batch_info.py +9 -4
  49. sglang/srt/model_executor/model_runner.py +136 -150
  50. sglang/srt/model_loader/__init__.py +34 -0
  51. sglang/srt/model_loader/loader.py +1139 -0
  52. sglang/srt/model_loader/utils.py +41 -0
  53. sglang/srt/model_loader/weight_utils.py +640 -0
  54. sglang/srt/models/baichuan.py +9 -10
  55. sglang/srt/models/chatglm.py +6 -15
  56. sglang/srt/models/commandr.py +2 -3
  57. sglang/srt/models/dbrx.py +2 -3
  58. sglang/srt/models/deepseek.py +4 -11
  59. sglang/srt/models/deepseek_v2.py +3 -11
  60. sglang/srt/models/exaone.py +2 -3
  61. sglang/srt/models/gemma.py +2 -6
  62. sglang/srt/models/gemma2.py +3 -14
  63. sglang/srt/models/gemma2_reward.py +0 -1
  64. sglang/srt/models/gpt2.py +5 -12
  65. sglang/srt/models/gpt_bigcode.py +6 -22
  66. sglang/srt/models/grok.py +14 -51
  67. sglang/srt/models/internlm2.py +2 -3
  68. sglang/srt/models/internlm2_reward.py +0 -1
  69. sglang/srt/models/llama.py +97 -27
  70. sglang/srt/models/llama_classification.py +1 -2
  71. sglang/srt/models/llama_embedding.py +1 -2
  72. sglang/srt/models/llama_reward.py +2 -3
  73. sglang/srt/models/llava.py +10 -12
  74. sglang/srt/models/llavavid.py +1 -2
  75. sglang/srt/models/minicpm.py +4 -7
  76. sglang/srt/models/minicpm3.py +6 -19
  77. sglang/srt/models/mixtral.py +12 -5
  78. sglang/srt/models/mixtral_quant.py +2 -3
  79. sglang/srt/models/mllama.py +3 -7
  80. sglang/srt/models/olmo.py +2 -8
  81. sglang/srt/models/olmo2.py +391 -0
  82. sglang/srt/models/olmoe.py +3 -5
  83. sglang/srt/models/phi3_small.py +8 -8
  84. sglang/srt/models/qwen.py +2 -3
  85. sglang/srt/models/qwen2.py +10 -9
  86. sglang/srt/models/qwen2_moe.py +4 -11
  87. sglang/srt/models/qwen2_vl.py +12 -9
  88. sglang/srt/models/registry.py +99 -0
  89. sglang/srt/models/stablelm.py +2 -3
  90. sglang/srt/models/torch_native_llama.py +6 -12
  91. sglang/srt/models/xverse.py +2 -4
  92. sglang/srt/models/xverse_moe.py +4 -11
  93. sglang/srt/models/yivl.py +2 -3
  94. sglang/srt/openai_api/adapter.py +10 -6
  95. sglang/srt/openai_api/protocol.py +1 -0
  96. sglang/srt/server.py +303 -204
  97. sglang/srt/server_args.py +65 -31
  98. sglang/srt/utils.py +253 -48
  99. sglang/test/test_utils.py +27 -7
  100. sglang/utils.py +2 -2
  101. sglang/version.py +1 -1
  102. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
  103. sglang-0.4.0.dist-info/RECORD +184 -0
  104. sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
  105. sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
  106. sglang/srt/layers/fused_moe_grok/layer.py +0 -630
  107. sglang-0.3.6.post2.dist-info/RECORD +0 -164
  108. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  109. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  110. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py

@@ -15,12 +15,14 @@
 import json
 import logging
 from enum import IntEnum, auto
-from typing import List, Optional
+from typing import List, Optional, Union
 
+import torch
 from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.utils import get_bool_env_var
+from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
 
@@ -33,17 +35,22 @@ class AttentionArch(IntEnum):
 class ModelConfig:
     def __init__(
         self,
-        path: str,
+        model_path: str,
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
         is_embedding: Optional[bool] = None,
+        dtype: str = "auto",
+        quantization: Optional[str] = None,
     ) -> None:
+        self.model_path = model_path
+        self.revision = revision
+        self.quantization = quantization
         # Parse args
         self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
-            path,
+            model_path,
             trust_remote_code=trust_remote_code,
             revision=revision,
             model_override_args=self.model_override_args,
@@ -56,6 +63,7 @@ class ModelConfig:
         )
         self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
         # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)
@@ -116,6 +124,8 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
 
+        self._verify_quantization()
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -174,6 +184,86 @@ class ModelConfig:
         # parallel size so each GPU has at least one KV head.
         return max(1, total_num_kv_heads // tensor_parallel_size)
 
+    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # compressed-tensors uses a "compression_config" key
+            quant_cfg = getattr(self.hf_config, "compression_config", None)
+        return quant_cfg
+
+    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+    def _verify_quantization(self) -> None:
+        supported_quantization = [*QUANTIZATION_METHODS]
+        rocm_supported_quantization = [
+            "awq",
+            "gptq",
+            "fp8",
+            "compressed_tensors",
+            "compressed-tensors",
+            "fbgemm_fp8",
+        ]
+        optimized_quantization_methods = [
+            "fp8",
+            "marlin",
+            "modelopt",
+            "gptq_marlin_24",
+            "gptq_marlin",
+            "awq_marlin",
+            "fbgemm_fp8",
+            "compressed_tensors",
+            "compressed-tensors",
+            "experts_int8",
+        ]
+        if self.quantization is not None:
+            self.quantization = self.quantization.lower()
+
+        # Parse quantization method from the HF model config, if available.
+        quant_cfg = self._parse_quant_hf_config()
+
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+
+            # Detect which checkpoint is it
+            for _, method in QUANTIZATION_METHODS.items():
+                quantization_override = method.override_quantization_method(
+                    quant_cfg, self.quantization
+                )
+                if quantization_override:
+                    quant_method = quantization_override
+                    self.quantization = quantization_override
+                    break
+
+            # Verify quantization configurations.
+            if self.quantization is None:
+                self.quantization = quant_method
+            elif self.quantization != quant_method:
+                raise ValueError(
+                    "Quantization method specified in the model config "
+                    f"({quant_method}) does not match the quantization "
+                    f"method specified in the `quantization` argument "
+                    f"({self.quantization})."
+                )
+
+        if self.quantization is not None:
+            if self.quantization not in supported_quantization:
+                raise ValueError(
+                    f"Unknown quantization method: {self.quantization}. Must "
+                    f"be one of {supported_quantization}."
+                )
+            if is_hip() and self.quantization not in rocm_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization is currently not "
+                    f"supported in ROCm."
+                )
+            if self.quantization not in optimized_quantization_methods:
+                logger.warning(
+                    "%s quantization is not fully "
+                    "optimized yet. The speed can be slower than "
+                    "non-quantized models.",
+                    self.quantization,
+                )
+
 
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
@@ -183,6 +273,9 @@ def get_hf_text_config(config: PretrainedConfig):
     if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
         # We support non-hf version of llava models, so we do not want to
         # read the wrong values from the unused default text_config.
+        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
+        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
+        setattr(config, "torch_dtype", torch.float16)
         return config
 
     if hasattr(config, "text_config"):
@@ -195,6 +288,70 @@ def get_hf_text_config(config: PretrainedConfig):
     return config
 
 
+# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.float16,
+    "float16": torch.float16,
+    "float": torch.float32,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+
+
+# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+def _get_and_verify_dtype(
+    config: PretrainedConfig,
+    dtype: Union[str, torch.dtype],
+) -> torch.dtype:
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
+    # because config.torch_dtype can be None.
+    config_dtype = getattr(config, "torch_dtype", None)
+    if config_dtype is None:
+        config_dtype = torch.float32
+
+    if isinstance(dtype, str):
+        dtype = dtype.lower()
+        if dtype == "auto":
+            if config_dtype == torch.float32:
+                if config.model_type == "gemma2":
+                    logger.info(
+                        "For Gemma 2, we downcast float32 to bfloat16 instead "
+                        "of float16 by default. Please specify `dtype` if you "
+                        "want to use float16."
+                    )
+                    torch_dtype = torch.bfloat16
+                else:
+                    # Following the common practice, we use float16 for float32
+                    # models.
+                    torch_dtype = torch.float16
+            else:
+                torch_dtype = config_dtype
+        else:
+            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+                raise ValueError(f"Unknown dtype: {dtype}")
+            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+    elif isinstance(dtype, torch.dtype):
+        torch_dtype = dtype
+    else:
+        raise ValueError(f"Unknown dtype: {dtype}")
+
+    # Verify the dtype.
+    if torch_dtype != config_dtype:
+        if torch_dtype == torch.float32:
+            # Upcasting to float32 is allowed.
+            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
+            pass
+        elif config_dtype == torch.float32:
+            # Downcasting from float32 to float16 or bfloat16 is allowed.
+            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
+            pass
+        else:
+            # Casting between float16 and bfloat16 is allowed with a warning.
+            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
+
+    return torch_dtype
+
+
 def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     # We have two ways to determine whether a model is a generative model.
     # 1. Check the model architectue
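With these changes, ModelConfig resolves the runtime dtype through _get_and_verify_dtype and validates the quantization method against QUANTIZATION_METHODS at construction time. A minimal sketch of how the new arguments might be exercised; the model path is an illustrative placeholder, and model_override_args is passed as a JSON string because the constructor feeds it to json.loads:

from sglang.srt.configs.model_config import ModelConfig

config = ModelConfig(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder; any local or Hub model path works the same way
    model_override_args="{}",            # JSON string, parsed with json.loads in __init__
    dtype="auto",                        # "auto" maps float32 checkpoints to float16 (bfloat16 for gemma2)
    quantization="fp8",                  # must be one of the QUANTIZATION_METHODS keys, lower-cased
)
print(config.dtype)                      # a torch.dtype, e.g. torch.float16 or torch.bfloat16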

sglang/srt/configs/qwen2vl.py

@@ -121,13 +121,10 @@ class Qwen2VLConfig(PretrainedConfig):
         self.attention_dropout = attention_dropout
         self.rope_scaling = rope_scaling
 
-        # NOTE: the following section from original transformers config
-        # for Qwen2-VL is commented out to address rope config loading issue
-        #
-        # if self.rope_scaling is not None and "type" in self.rope_scaling:
-        #     if self.rope_scaling["type"] == "mrope":
-        #         self.rope_scaling["type"] = "default"
-        #     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        # rope_config_validation(self)
+        # NOTE(HandH1998): This is necessary for configuring the `rope_type`` of qwen2vl models after removing dependencies on vllm.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            if self.rope_scaling["type"] == "mrope":
+                self.rope_scaling["type"] = "default"
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
 
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
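The restored block is a small normalization of the Hugging Face rope_scaling entry so that downstream rotary-embedding code sees a rope_type. Applied to a typical Qwen2-VL config value (the mrope_section numbers here are only illustrative), the same logic behaves like this sketch:

rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}  # illustrative config value

if rope_scaling is not None and "type" in rope_scaling:
    if rope_scaling["type"] == "mrope":
        rope_scaling["type"] = "default"
    rope_scaling["rope_type"] = rope_scaling["type"]

print(rope_scaling)
# {'type': 'default', 'mrope_section': [16, 24, 24], 'rope_type': 'default'}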

sglang/srt/constrained/outlines_backend.py

@@ -152,7 +152,12 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
             raise ValueError(f"Invalid key_type: {key_type}")
 
         try:
-            guide = RegexGuide(regex, self.outlines_tokenizer)
+            if hasattr(RegexGuide, "from_regex"):
+                # outlines >= 0.1.1
+                guide = RegexGuide.from_regex(regex, self.outlines_tokenizer)
+            else:
+                # outlines <= 0.0.46
+                guide = RegexGuide(regex, self.outlines_tokenizer)
         except interegular.patterns.InvalidSyntax as e:
             logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
             return None
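The RegexGuide constructor changed between outlines releases, so the backend now feature-detects from_regex instead of pinning a version. A quick way to see which branch applies in a given environment, assuming outlines is installed:

from importlib.metadata import version

installed = version("outlines")
print(installed)
# <= 0.0.46: RegexGuide(regex, tokenizer) is used
# >= 0.1.1:  RegexGuide.from_regex(regex, tokenizer) is used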

sglang/srt/constrained/outlines_jump_forward.py

@@ -23,7 +23,14 @@ from collections import defaultdict
 import interegular
 from interegular import InvalidSyntax
 from outlines.caching import cache as disk_cache
-from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
+
+try:
+    # outlines >= 0.1.0
+    from outlines_core.fsm.outlines_core_rs import FSMInfo
+    from outlines_core.fsm.regex import make_byte_level_fsm, make_deterministic_fsm
+except ImportError:
+    # outlines <= 0.0.46
+    from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
 
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 

sglang/srt/distributed/__init__.py (new file)

@@ -0,0 +1,3 @@
+from .communication_op import *
+from .parallel_state import *
+from .utils import *

sglang/srt/distributed/communication_op.py (new file)

@@ -0,0 +1,34 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py
+from typing import Any, Dict, Optional, Union
+
+import torch
+import torch.distributed
+
+from .parallel_state import get_tp_group
+
+
+def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
+    """All-reduce the input tensor across model parallel group."""
+    return get_tp_group().all_reduce(input_)
+
+
+def tensor_model_parallel_all_gather(
+    input_: torch.Tensor, dim: int = -1
+) -> torch.Tensor:
+    """All-gather the input tensor across model parallel group."""
+    return get_tp_group().all_gather(input_, dim)
+
+
+def tensor_model_parallel_gather(
+    input_: torch.Tensor, dst: int = 0, dim: int = -1
+) -> Optional[torch.Tensor]:
+    """Gather the input tensor across model parallel group."""
+    return get_tp_group().gather(input_, dst, dim)
+
+
+def broadcast_tensor_dict(
+    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
+):
+    if not torch.distributed.is_initialized():
+        return tensor_dict
+    return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
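These wrappers route collectives through the tensor-parallel group object returned by get_tp_group(), so tensor_model_parallel_all_reduce, all_gather, and gather all require the distributed groups from parallel_state to be initialized first. broadcast_tensor_dict, by contrast, degrades to a pass-through when torch.distributed is not initialized, which the short sketch below shows for a single-process run:

import torch

from sglang.srt.distributed import broadcast_tensor_dict

# Without torch.distributed initialized (e.g. a single-process run),
# broadcast_tensor_dict simply returns the input dict unchanged.
payload = {"input_ids": torch.tensor([1, 2, 3]), "step": 0}
print(broadcast_tensor_dict(payload))

# The tensor-parallel wrappers additionally need the process groups from
# parallel_state to be set up before they can be called.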

sglang/srt/distributed/device_communicators/cuda_wrapper.py (new file)

@@ -0,0 +1,182 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py
+"""This file is a pure Python wrapper for the cudart library.
+It avoids the need to compile a separate shared library, and is
+convenient for use when we just need to call a few functions.
+"""
+
+import ctypes
+import logging
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+# this line makes it possible to directly load `libcudart.so` using `ctypes`
+import torch  # noqa
+
+logger = logging.getLogger(__name__)
+
+# === export types and functions from cudart to Python ===
+# for the original cudart definition, please check
+# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
+
+cudaError_t = ctypes.c_int
+cudaMemcpyKind = ctypes.c_int
+
+
+class cudaIpcMemHandle_t(ctypes.Structure):
+    _fields_ = [("internal", ctypes.c_byte * 128)]
+
+
+@dataclass
+class Function:
+    name: str
+    restype: Any
+    argtypes: List[Any]
+
+
+def find_loaded_library(lib_name) -> Optional[str]:
+    """
+    According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
+    the file `/proc/self/maps` contains the memory maps of the process, which includes the
+    shared libraries loaded by the process. We can use this file to find the path of the
+    a loaded library.
+    """  # noqa
+    found = False
+    with open("/proc/self/maps") as f:
+        for line in f:
+            if lib_name in line:
+                found = True
+                break
+    if not found:
+        # the library is not loaded in the current process
+        return None
+    # if lib_name is libcudart, we need to match a line with:
+    # address /path/to/libcudart-hash.so.11.0
+    start = line.index("/")
+    path = line[start:].strip()
+    filename = path.split("/")[-1]
+    assert filename.rpartition(".so")[0].startswith(
+        lib_name
+    ), f"Unexpected filename: {filename} for library {lib_name}"
+    return path
+
+
+class CudaRTLibrary:
+    exported_functions = [
+        # cudaError_t cudaSetDevice ( int device )
+        Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
+        # cudaError_t cudaDeviceSynchronize ( void )
+        Function("cudaDeviceSynchronize", cudaError_t, []),
+        # cudaError_t cudaDeviceReset ( void )
+        Function("cudaDeviceReset", cudaError_t, []),
+        # const char* cudaGetErrorString ( cudaError_t error )
+        Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
+        # cudaError_t cudaMalloc ( void** devPtr, size_t size )
+        Function(
+            "cudaMalloc",
+            cudaError_t,
+            [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
+        ),
+        # cudaError_t cudaFree ( void* devPtr )
+        Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
+        # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
+        Function(
+            "cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
+        ),
+        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
+        Function(
+            "cudaMemcpy",
+            cudaError_t,
+            [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
+        ),
+        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
+        Function(
+            "cudaIpcGetMemHandle",
+            cudaError_t,
+            [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
+        ),
+        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
+        Function(
+            "cudaIpcOpenMemHandle",
+            cudaError_t,
+            [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],
+        ),
+    ]
+
+    # class attribute to store the mapping from the path to the library
+    # to avoid loading the same library multiple times
+    path_to_library_cache: Dict[str, Any] = {}
+
+    # class attribute to store the mapping from library path
+    # to the corresponding dictionary
+    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
+
+    def __init__(self, so_file: Optional[str] = None):
+        if so_file is None:
+            so_file = find_loaded_library("libcudart")
+            assert so_file is not None, "libcudart is not loaded in the current process"
+        if so_file not in CudaRTLibrary.path_to_library_cache:
+            lib = ctypes.CDLL(so_file)
+            CudaRTLibrary.path_to_library_cache[so_file] = lib
+        self.lib = CudaRTLibrary.path_to_library_cache[so_file]
+
+        if so_file not in CudaRTLibrary.path_to_dict_mapping:
+            _funcs = {}
+            for func in CudaRTLibrary.exported_functions:
+                f = getattr(self.lib, func.name)
+                f.restype = func.restype
+                f.argtypes = func.argtypes
+                _funcs[func.name] = f
+            CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
+        self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
+
+    def CUDART_CHECK(self, result: cudaError_t) -> None:
+        if result != 0:
+            error_str = self.cudaGetErrorString(result)
+            raise RuntimeError(f"CUDART error: {error_str}")
+
+    def cudaGetErrorString(self, error: cudaError_t) -> str:
+        return self.funcs["cudaGetErrorString"](error).decode("utf-8")
+
+    def cudaSetDevice(self, device: int) -> None:
+        self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
+
+    def cudaDeviceSynchronize(self) -> None:
+        self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
+
+    def cudaDeviceReset(self) -> None:
+        self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
+
+    def cudaMalloc(self, size: int) -> ctypes.c_void_p:
+        devPtr = ctypes.c_void_p()
+        self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
+        return devPtr
+
+    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
+        self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
+
+    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
+        self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
+
+    def cudaMemcpy(
+        self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int
+    ) -> None:
+        cudaMemcpyDefault = 4
+        kind = cudaMemcpyDefault
+        self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
+
+    def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
+        handle = cudaIpcMemHandle_t()
+        self.CUDART_CHECK(
+            self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr)
+        )
+        return handle
+
+    def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
+        cudaIpcMemLazyEnablePeerAccess = 1
+        devPtr = ctypes.c_void_p()
+        self.CUDART_CHECK(
+            self.funcs["cudaIpcOpenMemHandle"](
+                ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess
+            )
+        )
+        return devPtr
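CudaRTLibrary resolves the already-loaded libcudart through /proc/self/maps and exposes a handful of runtime calls (allocation, memset/memcpy, and the IPC handles used by the custom all-reduce). A minimal sketch of driving it directly, assuming a CUDA-capable machine where initializing torch's CUDA runtime has mapped libcudart into the process:

import torch

from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

torch.cuda.init()                # ensure libcudart.so is loaded so find_loaded_library() can see it

cudart = CudaRTLibrary()         # locates libcudart via /proc/self/maps
cudart.cudaSetDevice(0)

ptr = cudart.cudaMalloc(1024)    # 1 KiB of device memory
cudart.cudaMemset(ptr, 0, 1024)  # zero-fill it
handle = cudart.cudaIpcGetMemHandle(ptr)  # shareable IPC handle for another process
cudart.cudaDeviceSynchronize()
cudart.cudaFree(ptr)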