vllm_npu-0.4.2-py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/lora/models.py ADDED
@@ -0,0 +1,645 @@
import copy
import json
import math
import os
import re
from typing import Callable, Dict, List, Optional, Tuple, Type

import safetensors.torch
import torch
from torch import nn

from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                             parse_fine_tuned_lora_name, replace_submodule)
from vllm.utils import LRUCache, is_pin_memory_available

logger = init_logger(__name__)

_GLOBAL_LORA_ID = 0


def convert_mapping(
    mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
    max_loras: int, vocab_size: int, extra_vocab_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for sampler. For generation, this will be the
                same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for sampler with padding.
                Same as sampler_indices, but -1 is replaced with
                max_loras.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            indices_len: List of lengths of the above tensors.
    """
    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
    embedding_indices = index_mapping_indices.copy()
    lora_indices = index_mapping_indices.copy()
    prompt_mapping: List[int] = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    for i in range(len(index_mapping_indices)):
        # TODO index can be slow. optimize
        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
                    if index_mapping_indices[i] > 0 else -1)
        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
        index_mapping_indices[i] = i
        lora_indices[i] = lora_idx

    indices = torch.tensor(
        [index_mapping_indices, lora_indices, embedding_indices],
        dtype=torch.long,
        device="cuda")
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         device="cuda",
                                         dtype=torch.long)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size)
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    sampler_indices_padded = (
        torch.arange(
            0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
        (sampler_indices_padded * len(sampler_indices_padded)))
    indices_len = [
        base_indices.shape[-1], sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
    ]

    return (base_indices, sampler_indices, sampler_indices_padded,
            embeddings_indices, indices_len)


def get_lora_id():
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID += 1
    return _GLOBAL_LORA_ID


class LoRAModel:
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
    ) -> None:
        self.id = lora_model_id
        assert (lora_model_id >
                0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

    @property
    def extra_vocab_size(self) -> int:
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        rank: int,
        lora_alpha: int,
        tensors: Dict[str, torch.Tensor],
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors."""
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: Dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
            if module_name not in loras:
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights(module_name, rank,
                                                      lora_alpha, None, None,
                                                      lora_embeddings_tensor)
            if is_lora_a:
                loras[module_name].lora_a = tensor.to(device=device,
                                                      dtype=dtype).t()
                if pin_memory:
                    loras[module_name].lora_a = loras[
                        module_name].lora_a.pin_memory()
            else:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
                assert embedding_padding_modules is not None
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    loras[module_name].lora_b = torch.nn.functional.pad(
                        lora_b, (0, addition))
                if pin_memory:
                    loras[module_name].lora_b = loras[
                        module_name].lora_b.pin_memory()

        for lora in loras.values():
            lora.optimize()
        return cls(lora_model_id, rank, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: List[str],
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint."""
        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        with open(lora_config_path) as f:
            config = json.load(f)
        target_modules = config["target_modules"]
        unexpected_modules = []
        for module in target_modules:
            # Compatible with more modules, such as: layers.11.self_attn.k_proj
            part_name = module.split(".")[-1]
            if part_name not in expected_lora_modules:
                unexpected_modules.append(module)
        # loaded lora's target modules must be a subset of expected_lora_modules
        if unexpected_modules:
            raise ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected_modules}."
                f" Please verify that the loaded LoRA module is correct")
        if os.path.isfile(lora_tensor_path):
            tensors = safetensors.torch.load_file(lora_tensor_path)
        elif os.path.isfile(lora_bin_file_path):
            tensors = torch.load(lora_bin_file_path)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path)

        rank = config["r"]
        lora_alpha = config["lora_alpha"]
        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            rank=rank,
            lora_alpha=lora_alpha,
            tensors=tensors,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
        )


class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
        """
        self.lora_config = lora_config
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.base_indices = torch.empty(self.max_num_batched_tokens,
                                        dtype=torch.long,
                                        device="cuda")
        self.sampler_indices = torch.empty(self.max_num_batched_tokens,
                                           dtype=torch.long,
                                           device="cuda")
        self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
                                                  dtype=torch.long,
                                                  device="cuda")
        self.embeddings_indices = torch.empty(2,
                                              self.max_num_batched_tokens,
                                              dtype=torch.long,
                                              device="cuda")
        # 4 is the number of indices tensors defined above
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices
        self.indices_len: List[Optional[int]] = [None] * 4

        self.model: nn.Module = model
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
        self._registered_loras: Dict[int, LoRAModel] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._active_loras: Dict[int, None] = {}
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self

    @property
    def capacity(self) -> int:
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        return self.lora_config.max_loras

    def __len__(self) -> int:
        return len(self._registered_loras)

    def activate_lora(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass."""
        if lora_id in self._active_loras:
            return False
        first_free_slot = next(
            ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
             if lora_id is None), None)
        if first_free_slot is None:
            raise ValueError("No free lora slots")
        index, _ = first_free_slot
        self._active_loras[lora_id] = None
        lora_model = self._registered_loras[lora_id]
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if module_lora:
                module_lora.optimize()
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor)
            else:
                module.reset_lora(index)
        return True

    def _deactivate_lora(self, lora_id: int):
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

    def deactivate_lora(self, lora_id: int) -> bool:
        """Remove a LoRA from a GPU buffer."""
        if lora_id in self._active_loras:
            self._deactivate_lora(lora_id)
            self._active_loras.pop(lora_id)
            return True
        return False

    def _add_lora(self, lora: LoRAModel):
        self._create_merged_loras_inplace(lora)
        self._registered_loras[lora.id] = lora

    def add_lora(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager CPU cache."""
        if lora.id not in self._registered_loras:
            if len(self._registered_loras) >= self.capacity:
                raise RuntimeError("No free LoRA slots.")
            self._add_lora(lora)
            return True
        return False

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRAModel from the manager CPU cache."""
        # TODO: should we check active lora?
        self.deactivate_lora(lora_id)
        return bool(self._registered_loras.pop(lora_id, None))

    # TODO see if this can be vectorized
    def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
        (base_indices, sampler_indices, sampler_indices_padded,
         embeddings_indices,
         indices_len) = convert_mapping(mapping, self.lora_index_to_id,
                                        self.lora_slots + 1, self.vocab_size,
                                        self.lora_config.lora_extra_vocab_size)
        self.base_indices[:base_indices.shape[0]].copy_(base_indices)
        self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self.embeddings_indices[:embeddings_indices.
                                shape[0], :embeddings_indices.shape[1]].copy_(
                                    embeddings_indices)
        # Maintain the reference
        self.indices_len[:] = indices_len

    def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
        if self._last_mapping != lora_mapping:
            self._set_lora_mapping(lora_mapping)
        self._last_mapping = lora_mapping

    def list_loras(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_loras)

    def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
        return self._registered_loras.get(lora_id, None)

    def remove_all_loras(self):
        """Remove all LoRAModels from the manager."""
        self._registered_loras.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_loras.clear()

    def _create_lora_modules(self):
        for module_name, module in self.model.named_modules():
            if not self._match_target_modules(module_name):
                continue
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_moduled_lst, self.model.config))
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            new_module.set_mapping(self.base_indices, self.sampler_indices,
                                   self.sampler_indices_padded,
                                   self.embeddings_indices, self.indices_len)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if not self._match_target_modules(module_name) or not isinstance(
                    module, BaseLayerWithLoRA):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    input_dim = (module.base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(module.base_layer, "org_vocab_size")
                                 else module.base_layer.weight.shape[1])
                    output_dim = module.base_layer.embedding_dim if hasattr(
                        module.base_layer,
                        "embedding_dim") else module.base_layer.weight.shape[0]
                    embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                             hasattr(module.base_layer,
                                                     "embedding_dim") else
                                             module.base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked.shape[-1],
                        module.lora_b_stacked.shape[-2],
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                    )
                lora.optimize()
            else:
                parts = module_name.split(".")
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: List[Optional["LoRALayerWeights"]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: List[Optional[LoRALayerWeights]] = []
            has_replacement = False
            for r in new_module_names:
                lora = lora_model.get_lora(r)
                replacement_loras.append(lora)
                if lora:
                    has_replacement = True
            if not has_replacement:
                continue
            for i in range(len(replacement_loras)):
                if replacement_loras[i]:
                    continue
                replacement_loras[i] = None
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)


class LoRALRUCache(LRUCache[LoRAModel]):

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
                                                                   bool]):
        super().__init__(capacity)
        self.deactivate_lora_fn = deactivate_lora_fn

    def _on_remove(self, key: int, value: LoRAModel):
        logger.debug("Removing LoRA. int id: %d", key)
        self.deactivate_lora_fn(key)
        return super()._on_remove(key, value)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config)
        self._registered_loras: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_lora)
        self._active_loras: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_lora)

    def list_loras(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_loras.cache)

    def add_lora(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager."""
        if lora.id not in self._registered_loras:
            self._add_lora(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_loras.touch(lora.id)
            was_added = False
        return was_added

    def activate_lora(
        self,
        lora_id: int,
    ) -> bool:
        if lora_id not in self._active_loras and len(
                self._active_loras) >= self.lora_slots:
            self._active_loras.remove_oldest()
        result = super().activate_lora(lora_id)
        # We always touch to update the LRU cache order
        self._active_loras.touch(lora_id)
        return result

    def remove_oldest_lora(self) -> bool:
        if len(self._registered_loras) > 0:
            self._registered_loras.remove_oldest()
            return True
        return False


def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model."""
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        **kwargs)
    return lora_manager
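
For orientation, here is a minimal usage sketch of this module. It is not part of the wheel: the model object, adapter directory, expected module list, and the LoRAConfig arguments and batch limits shown below are illustrative assumptions rather than values taken from this diff.

    # Usage sketch only. Assumes `model` is a LoRA-capable vLLM model (it
    # exposes supported_lora_modules / packed_modules_mapping) and `lora_dir`
    # is a PEFT-style adapter directory; config values are placeholders.
    import torch

    from vllm.config import LoRAConfig
    from vllm.lora.layers import LoRAMapping
    from vllm.lora.models import (LoRAModel, LRUCacheLoRAModelManager,
                                  create_lora_manager)


    def attach_adapter(model, lora_dir: str, vocab_size: int):
        lora_config = LoRAConfig(max_lora_rank=16, max_loras=4,
                                 max_cpu_loras=8)
        manager = create_lora_manager(
            model=model,
            max_num_seqs=256,
            max_num_batched_tokens=4096,
            vocab_size=vocab_size,
            lora_config=lora_config,
            lora_manager_cls=LRUCacheLoRAModelManager)

        # Load adapter weights; the expected module list depends on the
        # adapter being loaded (q/k/v/o projections are just an example).
        lora = LoRAModel.from_local_checkpoint(
            lora_dir,
            expected_lora_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            dtype=torch.float16)

        manager.add_lora(lora)          # register in the CPU-side cache
        manager.activate_lora(lora.id)  # copy weights into a free device slot

        # Point every token row / request of the next batch at this adapter
        # id; set_lora_mapping() turns ids into slot indices via
        # convert_mapping().
        mapping = LoRAMapping(index_mapping=(lora.id, ) * 8,
                              prompt_mapping=(lora.id, ))
        manager.set_lora_mapping(mapping)
        return manager

Because LRUCacheLoRAModelManager backs both caches with LoRALRUCache, registering more than max_cpu_loras adapters or activating more than max_loras evicts the least recently used entry, whereas the base LoRAModelManager raises when its caches are full.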