sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. sglang/bench_serving.py +18 -1
  2. sglang/lang/interpreter.py +71 -1
  3. sglang/lang/ir.py +2 -0
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/chatglm.py +78 -0
  6. sglang/srt/configs/dbrx.py +279 -0
  7. sglang/srt/configs/model_config.py +1 -1
  8. sglang/srt/hf_transformers_utils.py +9 -14
  9. sglang/srt/layers/attention/__init__.py +8 -1
  10. sglang/srt/layers/attention/flashinfer_backend.py +4 -2
  11. sglang/srt/layers/linear.py +159 -55
  12. sglang/srt/layers/logits_processor.py +6 -6
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -3
  15. sglang/srt/layers/parameter.py +431 -0
  16. sglang/srt/layers/quantization/__init__.py +3 -2
  17. sglang/srt/layers/quantization/fp8.py +1 -1
  18. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  19. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  20. sglang/srt/managers/cache_controller.py +307 -0
  21. sglang/srt/managers/data_parallel_controller.py +2 -0
  22. sglang/srt/managers/schedule_batch.py +7 -1
  23. sglang/srt/managers/scheduler.py +10 -6
  24. sglang/srt/managers/session_controller.py +1 -1
  25. sglang/srt/managers/tokenizer_manager.py +6 -2
  26. sglang/srt/mem_cache/memory_pool.py +206 -1
  27. sglang/srt/metrics/collector.py +22 -30
  28. sglang/srt/model_executor/cuda_graph_runner.py +14 -7
  29. sglang/srt/model_executor/forward_batch_info.py +20 -15
  30. sglang/srt/model_executor/model_runner.py +10 -4
  31. sglang/srt/models/chatglm.py +1 -1
  32. sglang/srt/models/dbrx.py +1 -1
  33. sglang/srt/models/grok.py +25 -16
  34. sglang/srt/models/llama.py +9 -2
  35. sglang/srt/sampling/sampling_batch_info.py +1 -0
  36. sglang/srt/server.py +11 -8
  37. sglang/srt/server_args.py +12 -1
  38. sglang/srt/speculative/eagle_utils.py +93 -85
  39. sglang/srt/speculative/eagle_worker.py +47 -33
  40. sglang/srt/utils.py +32 -5
  41. sglang/test/test_programs.py +23 -1
  42. sglang/test/test_utils.py +36 -7
  43. sglang/version.py +1 -1
  44. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +6 -7
  45. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +48 -43
  46. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  47. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  48. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py (new file)
@@ -0,0 +1,307 @@
+ from __future__ import annotations
+
+ """
+ Copyright 2023-2025 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ import logging
+ import threading
+ from queue import PriorityQueue, Queue
+ from typing import Optional
+
+ import torch
+
+ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
+
+ logger = logging.getLogger(__name__)
+
+
+ class CacheOperation:
+
+     counter = 0
+
+     def __init__(
+         self,
+         host_indices: torch.Tensor,
+         device_indices: torch.Tensor,
+         node_id: int,
+         priority: Optional[int] = None,
+     ):
+         self.host_indices = host_indices
+         self.device_indices = device_indices
+         self.node_ids = [node_id]
+         self.data = None
+
+         self.id = CacheOperation.counter
+         CacheOperation.counter += 1
+         # default priority is the order of creation
+         self.priority = priority if priority is not None else self.id
+
+     def merge(self, other: "CacheOperation") -> None:
+         # multiple operations can be merged into a single operation for batch processing
+         self.host_indices = torch.cat([self.host_indices, other.host_indices])
+         self.device_indices = torch.cat([self.device_indices, other.device_indices])
+         self.priority = min(self.priority, other.priority)
+         self.node_ids.extend(other.node_ids)
+
+     def __lt__(self, other: "CacheOperation"):
+         return self.priority < other.priority
+
+
+ class TransferBuffer:
+     """
+     Overlapping buffer preparation and transfer operations to improve throughput.
+     """
+
+     def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None:
+         self.buffers = Queue(maxsize=buffer_count)
+         # todo: adjust the buffer size based on throughput profile of the system
+         self.max_buffer_size = max_buffer_size
+
+     def full(self) -> bool:
+         return self.buffers.full()
+
+     def empty(self) -> bool:
+         return self.buffers.empty()
+
+     def put(self, item, block=True) -> None:
+         self.buffers.put(item, block=block)
+
+     def get(self, block=True) -> Optional[CacheOperation]:
+         try:
+             return self.buffers.get(block=block)
+         except Exception as e:
+             logger.error(e)
+
+
+ class HiCacheController:
+
+     def __init__(
+         self,
+         mem_pool_device: MHATokenToKVPool,
+         mem_pool_host: MLATokenToKVPoolHost,
+         write_policy: str = "write_through_selective",
+     ):
+
+         self.mem_pool_device = mem_pool_device
+         self.mem_pool_host = mem_pool_host
+         self.write_policy = write_policy
+
+         if write_policy not in [
+             "write_through",
+             "write_through_selective",
+             "write_back",
+         ]:
+             raise ValueError(f"Invalid write policy: {write_policy}")
+
+         self.write_queue = PriorityQueue()
+         self.load_queue = PriorityQueue()
+
+         self.ack_write_queue = Queue()
+         self.ack_load_queue = Queue()
+
+         self.write_buffer = TransferBuffer()
+         self.load_buffer = TransferBuffer()
+
+         self.write_stream = torch.cuda.Stream()
+         self.load_stream = torch.cuda.Stream()
+
+         self.write_thread = threading.Thread(
+             target=self.write_thread_func_buffer, daemon=True
+         )
+         self.load_thread = threading.Thread(
+             target=self.load_thread_func_buffer, daemon=True
+         )
+         self.write_thread.start()
+         self.load_thread.start()
+
+     def write(
+         self,
+         device_indices: torch.Tensor,
+         priority: Optional[int] = None,
+         node_id: int = 0,
+     ) -> Optional[torch.Tensor]:
+         """
+         Back up KV caches from device memory to host memory.
+         """
+         host_indices = self.mem_pool_host.alloc(len(device_indices))
+         if host_indices is None:
+             return None
+         self.write_queue.put(
+             CacheOperation(host_indices, device_indices, node_id, priority)
+         )
+         self.mem_pool_host.protect_write(host_indices)
+         return host_indices
+
+     def load(
+         self,
+         host_indices: torch.Tensor,
+         priority: Optional[int] = None,
+         node_id: int = 0,
+     ) -> Optional[torch.Tensor]:
+         """
+         Load KV caches from host memory to device memory.
+         """
+         device_indices = self.mem_pool_device.alloc(len(host_indices))
+         if device_indices is None:
+             return None
+         self.load_queue.put(
+             CacheOperation(host_indices, device_indices, node_id, priority)
+         )
+         self.mem_pool_host.protect_load(host_indices)
+         return device_indices
+
+     def write_thread_func_direct(self):
+         """
+         Directly write through KV caches to host memory without buffering.
+         """
+         with torch.cuda.stream(self.write_stream):
+             while True:
+                 try:
+                     operation = self.write_queue.get(block=True)
+                     operation.data = self.mem_pool_device.get_flat_data(
+                         operation.device_indices
+                     )
+                     self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                     self.mem_pool_host.complete_io(operation.host_indices)
+                     for node_id in operation.node_ids:
+                         self.ack_write_queue.put(node_id)
+                 except Exception as e:
+                     logger.error(e)
+
+     def load_thread_func_direct(self):
+         """
+         Directly load KV caches from host memory to device memory without buffering.
+         """
+         with torch.cuda.stream(self.load_stream):
+             while True:
+                 try:
+                     operation = self.load_queue.get(block=True)
+                     operation.data = self.mem_pool_host.get_flat_data(
+                         operation.host_indices
+                     )
+                     self.mem_pool_device.transfer(
+                         operation.device_indices, operation.data
+                     )
+                     self.mem_pool_host.complete_io(operation.host_indices)
+                     for node_id in operation.node_ids:
+                         self.ack_load_queue.put(node_id)
+                 except Exception as e:
+                     logger.error(e)
+
+     def write_aux_func(self, no_wait=False):
+         """
+         Auxiliary function to prepare the buffer for write operations.
+         """
+         buffer = None
+         while True:
+             try:
+                 operation = self.write_queue.get(block=True)
+                 if buffer is None:
+                     buffer = operation
+                 else:
+                     buffer.merge(operation)
+                 if (
+                     no_wait
+                     or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                     or self.write_queue.empty()
+                     or self.write_buffer.empty()
+                 ):
+                     assert (
+                         buffer.device_indices.is_cuda
+                     ), "Device indices should be on GPU"
+                     buffer.data = self.mem_pool_device.get_flat_data(
+                         buffer.device_indices
+                     ).contiguous()
+                     self.write_buffer.put(buffer, block=True)
+                     buffer = None
+             except Exception as e:
+                 logger.error(e)
+
+     def load_aux_func(self):
+         """
+         Auxiliary function to prepare the buffer for load operations.
+         """
+         buffer = None
+         while True:
+             try:
+                 operation = self.load_queue.get(block=True)
+                 if buffer is None:
+                     buffer = operation
+                 else:
+                     buffer.merge(operation)
+                 if (
+                     len(buffer.host_indices) >= self.load_buffer.max_buffer_size
+                     or self.load_queue.empty()
+                     or self.load_buffer.empty()
+                 ):
+                     buffer.data = (
+                         self.mem_pool_host.get_flat_data(buffer.host_indices)
+                         .contiguous()
+                         .pin_memory()
+                     )
+                     self.load_buffer.put(buffer, block=True)
+                     buffer = None
+             except Exception as e:
+                 logger.error(e)
+
+     def write_thread_func_buffer(self):
+         aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
+         aux_thread.start()
+         with torch.cuda.stream(self.write_stream):
+             while True:
+                 operation = self.write_buffer.get()
+                 if operation is None:
+                     continue
+                 self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                 self.mem_pool_host.complete_io(operation.host_indices)
+                 for node_id in operation.node_ids:
+                     self.ack_write_queue.put(node_id)
+
+     def load_thread_func_buffer(self):
+         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
+         aux_thread.start()
+         with torch.cuda.stream(self.load_stream):
+             while True:
+                 operation = self.load_buffer.get()
+                 if operation is None:
+                     continue
+                 self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                 self.mem_pool_host.complete_io(operation.host_indices)
+                 for node_id in operation.node_ids:
+                     self.ack_load_queue.put(node_id)
+
+     def evict_device(
+         self, device_indices: torch.Tensor, host_indices: torch.Tensor
+     ) -> int:
+         if self.mem_pool_host.is_synced(host_indices):
+             self.mem_pool_device.free(device_indices)
+             self.mem_pool_host.update_backup(host_indices)
+             return len(device_indices)
+         else:
+             raise ValueError(
+                 f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+             )
+
+     def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
+         if not backup_only:
+             raise ValueError("Other eviction policies are not supported yet.")
+
+         if self.mem_pool_host.is_backup(host_indices):
+             self.mem_pool_host.free(host_indices)
+             return len(host_indices)
+         else:
+             raise ValueError(
+                 f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
+             )
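Note: the new `HiCacheController` API is asynchronous. `write()` reserves host slots and queues a device-to-host backup, `load()` reserves device slots and queues a host-to-device restore, and completion is signalled per `node_id` on `ack_write_queue` / `ack_load_queue`. A minimal usage sketch follows; it is not part of the diff, it needs a CUDA device, and the `MHATokenToKVPool` constructor arguments shown are assumptions that should be checked against `memory_pool.py`.

```python
import torch

from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost

# A small device pool for illustration; it is normally built by the model runner,
# and the constructor arguments here are assumptions to verify against memory_pool.py.
device_pool = MHATokenToKVPool(
    size=1024,
    dtype=torch.bfloat16,
    head_num=2,
    head_dim=128,
    layer_num=4,
    device="cuda",
    enable_memory_saver=False,
)
host_pool = MLATokenToKVPoolHost(device_pool, host_to_device_ratio=2.0)
controller = HiCacheController(
    device_pool, host_pool, write_policy="write_through_selective"
)

# Queue a backup of 128 device KV slots. The call returns the reserved host
# slots immediately, or None if the host pool has no free space left.
device_indices = torch.arange(128, dtype=torch.int64, device="cuda")
host_indices = controller.write(device_indices, node_id=7)
assert host_indices is not None, "host pool exhausted; evict_host() first"

# Completion is asynchronous: the background write thread pushes the node_id
# onto ack_write_queue once the device-to-host copy has finished.
acked_node_id = controller.ack_write_queue.get()  # returns 7 when the backup is done
```

Eviction then follows the same state machine: `evict_device()` frees device slots only once the host copy is SYNCED, and `evict_host()` frees host slots only when they hold pure BACKUP copies.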
sglang/srt/managers/data_parallel_controller.py
@@ -20,6 +20,7 @@ import threading
  from enum import Enum, auto

  import psutil
+ import setproctitle
  import zmq

  from sglang.srt.managers.io_struct import (
@@ -230,6 +231,7 @@ def run_data_parallel_controller_process(
      port_args: PortArgs,
      pipe_writer,
  ):
+     setproctitle.setproctitle("sglang::data_parallel_controller")
      configure_logger(server_args)
      parent_process = psutil.Process().parent()

sglang/srt/managers/schedule_batch.py
@@ -44,7 +44,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
- from sglang.srt.model_executor.forward_batch_info import ForwardMode
+ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.sampling.sampling_params import SamplingParams
  from sglang.srt.server_args import ServerArgs
@@ -1163,6 +1163,11 @@ class ScheduleBatch:
              input_embeds=self.input_embeds,
              spec_algorithm=self.spec_algorithm,
              spec_info=self.spec_info,
+             capture_hidden_mode=(
+                 getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
+                 if self.spec_info
+                 else CaptureHiddenMode.NULL
+             ),
          )

      def copy(self):
@@ -1237,6 +1242,7 @@
      # Speculative decoding
      spec_algorithm: SpeculativeAlgorithm = None
      spec_info: Optional[SpecInfo] = None
+     capture_hidden_mode: CaptureHiddenMode = None


  @triton.jit
sglang/srt/managers/scheduler.py
@@ -962,10 +962,13 @@ class Scheduler:
                      self.tp_worker.forward_batch_generation(model_worker_batch)
                  )
              else:
-                 logits_output, next_token_ids, model_worker_batch, spec_info = (
-                     self.draft_worker.forward_batch_speculative_generation(batch)
-                 )
-                 batch.spec_info = spec_info
+                 (
+                     logits_output,
+                     next_token_ids,
+                     model_worker_batch,
+                     num_accepted_tokens,
+                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                 self.num_generated_tokens += num_accepted_tokens
          elif batch.forward_mode.is_idle():
              model_worker_batch = batch.get_model_worker_batch()
              self.tp_worker.forward_batch_idle(model_worker_batch)
@@ -1513,8 +1516,9 @@
          return success, message

      def update_weights_from_distributed(
-         self, recv_req: UpdateWeightsFromDistributedReqInput
-     ):
+         self,
+         recv_req: UpdateWeightsFromDistributedReqInput,
+     ) -> Tuple[bool, str]:
          """Update the online model parameter."""
          success, message = self.tp_worker.update_weights_from_distributed(recv_req)
          if success:
sglang/srt/managers/session_controller.py
@@ -99,7 +99,7 @@ class Session:

          if last_req is not None:
              # trim bos token if it is an append
-             if req.input_ids[0] == tokenizer.bos_token_id:
+             if tokenizer is not None and req.input_ids[0] == tokenizer.bos_token_id:
                  req.input_ids = req.input_ids[1:]

              input_ids = (
sglang/srt/managers/tokenizer_manager.py
@@ -688,7 +688,7 @@ class TokenizerManager:
              if self.enable_metrics:
                  completion_tokens = (
                      recv_obj.completion_tokens[i]
-                     if recv_obj.completion_tokens
+                     if getattr(recv_obj, "completion_tokens", None)
                      else 0
                  )

@@ -716,7 +716,11 @@
                          time.time() - state.created_time
                      )
                      # Compute time_per_output_token for the non-streaming case
-                     if not state.obj.stream and completion_tokens >= 1:
+                     if (
+                         hasattr(state.obj, "stream")
+                         and not state.obj.stream
+                         and completion_tokens >= 1
+                     ):
                          self.metrics_collector.observe_time_per_output_token(
                              (time.time() - state.created_time)
                              / completion_tokens
sglang/srt/mem_cache/memory_pool.py
@@ -22,12 +22,16 @@ BaseTokenToKVPool maps a token location to its KV cache data.
  """

  import logging
+ import threading
+ from enum import IntEnum
+ from functools import wraps
  from typing import List, Tuple, Union

+ import psutil
  import torch

  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.utils import get_compiler_backend
+ from sglang.srt.utils import debug_timing, get_compiler_backend

  logger = logging.getLogger(__name__)

@@ -213,6 +217,26 @@ class MHATokenToKVPool(BaseTokenToKVPool):
          del self.k_buffer
          del self.v_buffer

+     # Todo: different memory layout
+     def get_flat_data(self, indices):
+         # prepare a large chunk of contiguous data for efficient transfer
+         flatten = torch.stack(
+             [
+                 torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
+                 torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
+             ]
+         )
+         return flatten
+
+     @debug_timing
+     def transfer(self, indices, flat_data):
+         # transfer prepared data from host to device
+         flat_data = flat_data.to(device=self.device, non_blocking=False)
+         k_data, v_data = flat_data[0], flat_data[1]
+         for i in range(self.layer_num):
+             self.k_buffer[i][indices] = k_data[i]
+             self.v_buffer[i][indices] = v_data[i]
+
      def get_key_buffer(self, layer_id: int):
          if self.store_dtype != self.dtype:
              return self.k_buffer[layer_id].view(self.dtype)
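Note: `get_flat_data` packs the selected K and V slots of every layer into one tensor of shape `[2, layer_num, len(indices), head_num, head_dim]`, which is the same layout as the host pool's `kv_buffer`, so a backup becomes a single slice assignment. A small CPU-only shape check (a sketch, not part of the diff, assuming per-layer K/V buffers shaped `[pool_size, head_num, head_dim]`):

```python
import torch

layer_num, size, head_num, head_dim = 4, 64, 2, 8
k_buffer = [torch.zeros(size, head_num, head_dim) for _ in range(layer_num)]
v_buffer = [torch.zeros(size, head_num, head_dim) for _ in range(layer_num)]
indices = torch.arange(16)

# Same stacking order as MHATokenToKVPool.get_flat_data: [K/V, layer, token, head, dim]
flat = torch.stack(
    [
        torch.stack([k_buffer[i][indices] for i in range(layer_num)]),
        torch.stack([v_buffer[i][indices] for i in range(layer_num)]),
    ]
)
assert flat.shape == (2, layer_num, len(indices), head_num, head_dim)

# Host side: kv_buffer has shape (2, layer_num, host_size, head_num, head_dim),
# so the backup in MLATokenToKVPoolHost.transfer() is one slice assignment.
host_size = 2 * size
kv_buffer = torch.empty(2, layer_num, host_size, head_num, head_dim)
host_indices = torch.arange(16)
kv_buffer[:, :, host_indices] = flat
```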
sglang/srt/mem_cache/memory_pool.py
@@ -361,3 +385,184 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
          self.k_buffer[layer_id][loc] = cache_k
          self.v_buffer[layer_id][loc] = cache_v
          self.label_buffer[layer_id][loc] = cache_label
+
+
+ class MemoryStateInt(IntEnum):
+     IDLE = 0
+     RESERVED = 1
+     PROTECTED = 2
+     SYNCED = 3
+     BACKUP = 4
+
+
+ def synchronized(func):
+     @wraps(func)
+     def wrapper(self, *args, **kwargs):
+         with self.lock:
+             return func(self, *args, **kwargs)
+
+     return wrapper
+
+
+ class MLATokenToKVPoolHost:
+
+     def __init__(
+         self,
+         device_pool: MHATokenToKVPool,
+         host_to_device_ratio: float = 2.0,
+         pin_memory: bool = False,  # no need to use pin memory with the double buffering
+         device: str = "cpu",
+     ):
+         assert (
+             host_to_device_ratio >= 1
+         ), "The host memory should be larger than the device memory with the current protocol"
+         # todo, other ways of configuring the size
+
+         self.device_pool = device_pool
+         self.host_to_device_ratio = host_to_device_ratio
+         self.pin_memory = pin_memory
+         self.device = device
+
+         self.size = int(device_pool.size * host_to_device_ratio)
+         self.dtype = device_pool.store_dtype
+         self.head_num = device_pool.head_num
+         self.head_dim = device_pool.head_dim
+         self.layer_num = device_pool.layer_num
+         self.size_per_token = (
+             self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+         )
+
+         # Verify there is enough available host memory.
+         host_mem = psutil.virtual_memory()
+         requested_bytes = self.size * self.size_per_token
+         # preserve at least 10GB for other usage
+         ten_gb = 10 * (1024**3)
+         if requested_bytes > host_mem.available - ten_gb:
+             raise ValueError(
+                 f"Not enough host memory available. Requesting "
+                 f"{requested_bytes / 1e9:.2f} GB but only have "
+                 f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                 f"size of the hierarchical cache."
+             )
+         else:
+             logger.info(
+                 f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
+             )
+
+         self.kv_buffer = torch.empty(
+             (2, self.layer_num, self.size, self.head_num, self.head_dim),
+             dtype=self.dtype,
+             device=self.device,
+             pin_memory=self.pin_memory,
+         )
+
+         # Initialize memory states and tracking structures.
+         self.mem_state = torch.zeros(
+             (self.size,), dtype=torch.uint8, device=self.device
+         )
+         self.free_slots = torch.arange(self.size, dtype=torch.int32)
+         self.can_use_mem_size = self.size
+
+         # A lock for synchronized operations on memory allocation and state transitions.
+         self.lock = threading.RLock()
+
+     def get_flat_data(self, indices):
+         return self.kv_buffer[:, :, indices]
+
+     @debug_timing
+     def transfer(self, indices, flat_data):
+         # backup prepared data from device to host
+         self.kv_buffer[:, :, indices] = flat_data.to(
+             device=self.device, non_blocking=False
+         )
+
+     @synchronized
+     def clear(self):
+         self.mem_state.fill_(0)
+         self.can_use_mem_size = self.size
+         self.free_slots = torch.arange(self.size, dtype=torch.int32)
+
+     @synchronized
+     def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+         assert len(indices) > 0, "The indices should not be empty"
+         states = self.mem_state[indices]
+         assert (
+             states == states[0]
+         ).all(), "The memory slots should have the same state {}".format(states)
+         return MemoryStateInt(states[0].item())
+
+     @synchronized
+     def alloc(self, need_size: int) -> torch.Tensor:
+         if need_size > self.can_use_mem_size:
+             return None
+
+         # todo: de-fragementation
+         select_index = self.free_slots[:need_size]
+         self.free_slots = self.free_slots[need_size:]
+
+         self.mem_state[select_index] = MemoryStateInt.RESERVED
+         self.can_use_mem_size -= need_size
+
+         return select_index
+
+     @synchronized
+     def is_reserved(self, indices: torch.Tensor) -> bool:
+         return self.get_state(indices) == MemoryStateInt.RESERVED
+
+     @synchronized
+     def is_protected(self, indices: torch.Tensor) -> bool:
+         return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+     @synchronized
+     def is_synced(self, indices: torch.Tensor) -> bool:
+         return self.get_state(indices) == MemoryStateInt.SYNCED
+
+     @synchronized
+     def is_backup(self, indices: torch.Tensor) -> bool:
+         return self.get_state(indices) == MemoryStateInt.BACKUP
+
+     @synchronized
+     def update_backup(self, indices: torch.Tensor):
+         assert self.is_synced(indices), (
+             f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+             f"Current state: {self.get_state(indices)}"
+         )
+         self.mem_state[indices] = MemoryStateInt.BACKUP
+
+     @synchronized
+     def update_synced(self, indices: torch.Tensor):
+         self.mem_state[indices] = MemoryStateInt.SYNCED
+
+     @synchronized
+     def protect_write(self, indices: torch.Tensor):
+         assert self.is_reserved(indices), (
+             f"The host memory slots should be RESERVED before write operations. "
+             f"Current state: {self.get_state(indices)}"
+         )
+         self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+     @synchronized
+     def protect_load(self, indices: torch.Tensor):
+         assert self.is_backup(indices), (
+             f"The host memory slots should be in BACKUP state before load operations. "
+             f"Current state: {self.get_state(indices)}"
+         )
+         self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+     @synchronized
+     def complete_io(self, indices: torch.Tensor):
+         assert self.is_protected(indices), (
+             f"The host memory slots should be PROTECTED during I/O operations. "
+             f"Current state: {self.get_state(indices)}"
+         )
+         self.mem_state[indices] = MemoryStateInt.SYNCED
+
+     def available_size(self):
+         return len(self.free_slots)
+
+     @synchronized
+     def free(self, indices: torch.Tensor) -> int:
+         self.mem_state[indices] = MemoryStateInt.IDLE
+         self.free_slots = torch.concat([self.free_slots, indices])
+         self.can_use_mem_size += len(indices)
+         return len(indices)
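Note: every host slot moves through the `MemoryStateInt` lifecycle enforced by the assertions above: IDLE → RESERVED (`alloc`) → PROTECTED (`protect_write`) → SYNCED (`complete_io`) → BACKUP (`update_backup`), and loads go BACKUP → PROTECTED (`protect_load`) → SYNCED. The walkthrough below is a CPU-only sketch, not part of the diff; the `SimpleNamespace` stand-in for the device pool carries only the attributes `__init__` reads and is not how the pool is built in the server.

```python
from types import SimpleNamespace

import torch

from sglang.srt.mem_cache.memory_pool import MLATokenToKVPoolHost

# Mock device pool exposing just the attributes MLATokenToKVPoolHost reads:
# size, store_dtype, head_num, head_dim, layer_num (illustration only).
mock_device_pool = SimpleNamespace(
    size=256, store_dtype=torch.float16, head_num=2, head_dim=64, layer_num=4
)
host_pool = MLATokenToKVPoolHost(mock_device_pool, host_to_device_ratio=2.0)

indices = host_pool.alloc(8)       # IDLE      -> RESERVED
host_pool.protect_write(indices)   # RESERVED  -> PROTECTED (backup in flight)
host_pool.complete_io(indices)     # PROTECTED -> SYNCED    (host copy matches device)
host_pool.update_backup(indices)   # SYNCED    -> BACKUP    (device copy may now be evicted)

host_pool.protect_load(indices)    # BACKUP    -> PROTECTED (reload in flight)
host_pool.complete_io(indices)     # PROTECTED -> SYNCED
host_pool.free(indices)            # back to IDLE; slots returned to free_slots
```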