vllm-npu 0.4.2__py3-none-any.whl
- vllm/__init__.py +23 -0
- vllm/_custom_ops.py +251 -0
- vllm/attention/__init__.py +13 -0
- vllm/attention/backends/__init__.py +0 -0
- vllm/attention/backends/abstract.py +127 -0
- vllm/attention/backends/flash_attn.py +271 -0
- vllm/attention/backends/flashinfer.py +220 -0
- vllm/attention/backends/rocm_flash_attn.py +374 -0
- vllm/attention/backends/torch_sdpa.py +250 -0
- vllm/attention/backends/xformers.py +393 -0
- vllm/attention/layer.py +56 -0
- vllm/attention/ops/__init__.py +0 -0
- vllm/attention/ops/paged_attn.py +216 -0
- vllm/attention/ops/prefix_prefill.py +792 -0
- vllm/attention/ops/triton_flash_attention.py +810 -0
- vllm/attention/selector.py +91 -0
- vllm/block.py +84 -0
- vllm/config.py +1225 -0
- vllm/core/__init__.py +0 -0
- vllm/core/block/__init__.py +0 -0
- vllm/core/block/block_table.py +295 -0
- vllm/core/block/common.py +199 -0
- vllm/core/block/cpu_gpu_block_allocator.py +228 -0
- vllm/core/block/interfaces.py +205 -0
- vllm/core/block/naive_block.py +318 -0
- vllm/core/block/prefix_caching_block.py +606 -0
- vllm/core/block_manager_v1.py +625 -0
- vllm/core/block_manager_v2.py +258 -0
- vllm/core/evictor_v1.py +105 -0
- vllm/core/evictor_v2.py +127 -0
- vllm/core/interfaces.py +113 -0
- vllm/core/policy.py +45 -0
- vllm/core/scheduler.py +1163 -0
- vllm/distributed/__init__.py +3 -0
- vllm/distributed/communication_op.py +237 -0
- vllm/distributed/device_communicators/__init__.py +0 -0
- vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
- vllm/distributed/device_communicators/pynccl.py +287 -0
- vllm/distributed/device_communicators/pynccl_utils.py +66 -0
- vllm/distributed/parallel_state.py +339 -0
- vllm/distributed/utils.py +136 -0
- vllm/engine/__init__.py +0 -0
- vllm/engine/arg_utils.py +649 -0
- vllm/engine/async_llm_engine.py +737 -0
- vllm/engine/llm_engine.py +784 -0
- vllm/engine/metrics.py +368 -0
- vllm/engine/output_processor/__init__.py +0 -0
- vllm/engine/output_processor/interfaces.py +76 -0
- vllm/engine/output_processor/multi_step.py +142 -0
- vllm/engine/output_processor/single_step.py +284 -0
- vllm/engine/output_processor/stop_checker.py +101 -0
- vllm/engine/output_processor/util.py +19 -0
- vllm/entrypoints/__init__.py +0 -0
- vllm/entrypoints/api_server.py +119 -0
- vllm/entrypoints/llm.py +259 -0
- vllm/entrypoints/openai/__init__.py +0 -0
- vllm/entrypoints/openai/api_server.py +186 -0
- vllm/entrypoints/openai/cli_args.py +115 -0
- vllm/entrypoints/openai/protocol.py +460 -0
- vllm/entrypoints/openai/serving_chat.py +392 -0
- vllm/entrypoints/openai/serving_completion.py +347 -0
- vllm/entrypoints/openai/serving_engine.py +234 -0
- vllm/envs.py +217 -0
- vllm/executor/__init__.py +0 -0
- vllm/executor/cpu_executor.py +152 -0
- vllm/executor/distributed_gpu_executor.py +115 -0
- vllm/executor/executor_base.py +115 -0
- vllm/executor/gpu_executor.py +150 -0
- vllm/executor/multiproc_worker_utils.py +263 -0
- vllm/executor/neuron_executor.py +91 -0
- vllm/executor/ray_gpu_executor.py +327 -0
- vllm/executor/ray_utils.py +119 -0
- vllm/logger.py +153 -0
- vllm/logging/__init__.py +5 -0
- vllm/logging/formatter.py +15 -0
- vllm/lora/__init__.py +0 -0
- vllm/lora/fully_sharded_layers.py +262 -0
- vllm/lora/layers.py +1181 -0
- vllm/lora/lora.py +167 -0
- vllm/lora/models.py +645 -0
- vllm/lora/punica.py +213 -0
- vllm/lora/request.py +32 -0
- vllm/lora/utils.py +98 -0
- vllm/lora/worker_manager.py +251 -0
- vllm/model_executor/__init__.py +7 -0
- vllm/model_executor/guided_decoding/__init__.py +25 -0
- vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
- vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
- vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
- vllm/model_executor/layers/__init__.py +0 -0
- vllm/model_executor/layers/activation.py +173 -0
- vllm/model_executor/layers/fused_moe/__init__.py +7 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
- vllm/model_executor/layers/layernorm.py +71 -0
- vllm/model_executor/layers/linear.py +709 -0
- vllm/model_executor/layers/logits_processor.py +115 -0
- vllm/model_executor/layers/ops/__init__.py +0 -0
- vllm/model_executor/layers/ops/rand.py +157 -0
- vllm/model_executor/layers/ops/sample.py +406 -0
- vllm/model_executor/layers/quantization/__init__.py +35 -0
- vllm/model_executor/layers/quantization/aqlm.py +376 -0
- vllm/model_executor/layers/quantization/awq.py +175 -0
- vllm/model_executor/layers/quantization/base_config.py +97 -0
- vllm/model_executor/layers/quantization/fp8.py +265 -0
- vllm/model_executor/layers/quantization/gptq.py +224 -0
- vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
- vllm/model_executor/layers/quantization/marlin.py +227 -0
- vllm/model_executor/layers/quantization/schema.py +84 -0
- vllm/model_executor/layers/quantization/squeezellm.py +137 -0
- vllm/model_executor/layers/rejection_sampler.py +405 -0
- vllm/model_executor/layers/rotary_embedding.py +525 -0
- vllm/model_executor/layers/sampler.py +1051 -0
- vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
- vllm/model_executor/model_loader/__init__.py +30 -0
- vllm/model_executor/model_loader/loader.py +362 -0
- vllm/model_executor/model_loader/neuron.py +136 -0
- vllm/model_executor/model_loader/tensorizer.py +368 -0
- vllm/model_executor/model_loader/utils.py +41 -0
- vllm/model_executor/model_loader/weight_utils.py +372 -0
- vllm/model_executor/models/__init__.py +119 -0
- vllm/model_executor/models/baichuan.py +410 -0
- vllm/model_executor/models/bloom.py +327 -0
- vllm/model_executor/models/chatglm.py +386 -0
- vllm/model_executor/models/commandr.py +373 -0
- vllm/model_executor/models/dbrx.py +413 -0
- vllm/model_executor/models/decilm.py +122 -0
- vllm/model_executor/models/deepseek.py +438 -0
- vllm/model_executor/models/falcon.py +444 -0
- vllm/model_executor/models/gemma.py +393 -0
- vllm/model_executor/models/gpt2.py +266 -0
- vllm/model_executor/models/gpt_bigcode.py +274 -0
- vllm/model_executor/models/gpt_j.py +281 -0
- vllm/model_executor/models/gpt_neox.py +295 -0
- vllm/model_executor/models/internlm2.py +323 -0
- vllm/model_executor/models/jais.py +333 -0
- vllm/model_executor/models/llama.py +442 -0
- vllm/model_executor/models/llava.py +239 -0
- vllm/model_executor/models/minicpm.py +531 -0
- vllm/model_executor/models/mixtral.py +583 -0
- vllm/model_executor/models/mixtral_quant.py +404 -0
- vllm/model_executor/models/mpt.py +295 -0
- vllm/model_executor/models/olmo.py +356 -0
- vllm/model_executor/models/opt.py +349 -0
- vllm/model_executor/models/orion.py +319 -0
- vllm/model_executor/models/phi.py +300 -0
- vllm/model_executor/models/qwen.py +284 -0
- vllm/model_executor/models/qwen2.py +367 -0
- vllm/model_executor/models/qwen2_moe.py +447 -0
- vllm/model_executor/models/stablelm.py +301 -0
- vllm/model_executor/models/starcoder2.py +302 -0
- vllm/model_executor/models/xverse.py +366 -0
- vllm/model_executor/sampling_metadata.py +588 -0
- vllm/model_executor/utils.py +35 -0
- vllm/outputs.py +150 -0
- vllm/py.typed +2 -0
- vllm/sampling_params.py +340 -0
- vllm/sequence.py +766 -0
- vllm/spec_decode/__init__.py +0 -0
- vllm/spec_decode/batch_expansion.py +397 -0
- vllm/spec_decode/interfaces.py +73 -0
- vllm/spec_decode/metrics.py +191 -0
- vllm/spec_decode/multi_step_worker.py +203 -0
- vllm/spec_decode/ngram_worker.py +176 -0
- vllm/spec_decode/spec_decode_worker.py +472 -0
- vllm/spec_decode/top1_proposer.py +200 -0
- vllm/spec_decode/util.py +228 -0
- vllm/test_utils.py +41 -0
- vllm/transformers_utils/__init__.py +0 -0
- vllm/transformers_utils/config.py +58 -0
- vllm/transformers_utils/configs/__init__.py +16 -0
- vllm/transformers_utils/configs/chatglm.py +68 -0
- vllm/transformers_utils/configs/dbrx.py +278 -0
- vllm/transformers_utils/configs/falcon.py +87 -0
- vllm/transformers_utils/configs/jais.py +236 -0
- vllm/transformers_utils/configs/mpt.py +178 -0
- vllm/transformers_utils/detokenizer.py +313 -0
- vllm/transformers_utils/tokenizer.py +149 -0
- vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
- vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
- vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
- vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
- vllm/transformers_utils/tokenizers/__init__.py +5 -0
- vllm/transformers_utils/tokenizers/baichuan.py +255 -0
- vllm/usage/__init__.py +0 -0
- vllm/usage/usage_lib.py +209 -0
- vllm/utils.py +677 -0
- vllm/worker/__init__.py +0 -0
- vllm/worker/cache_engine.py +105 -0
- vllm/worker/cpu_model_runner.py +346 -0
- vllm/worker/cpu_worker.py +321 -0
- vllm/worker/model_runner.py +1168 -0
- vllm/worker/neuron_model_runner.py +196 -0
- vllm/worker/neuron_worker.py +98 -0
- vllm/worker/worker.py +345 -0
- vllm/worker/worker_base.py +146 -0
- vllm_npu-0.4.2.dist-info/LICENSE +201 -0
- vllm_npu-0.4.2.dist-info/METADATA +173 -0
- vllm_npu-0.4.2.dist-info/RECORD +219 -0
- vllm_npu-0.4.2.dist-info/WHEEL +5 -0
- vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/lora/punica.py
ADDED
@@ -0,0 +1,213 @@
# Based on code from https://github.com/punica-ai/punica

from typing import Optional

import torch


def _raise_import_error(e):
    if torch.cuda.get_device_capability() < (8, 0):
        raise ImportError(
            "punica LoRA kernels require compute capability >= 8.0") from e
    else:
        raise ImportError(
            "punica LoRA kernels could not be imported. If you built vLLM "
            "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
            "was set.") from e


def bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
    scale: float,
):
    """
    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
            matrices.
        indicies: Shape: `[B]`. Indices of the weight matrices.
        layer_idx: Layer index of the weight matrices.
        scale: Scaling factor.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)


def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
                            w_t_all: torch.Tensor, indicies: torch.LongTensor,
                            layer_idx: int, scale: float, y_offset: int,
                            y_slice_size: int):
    """
    Same as `bgmv` but you can operate on slices of y.
    Pass whole y, define y_offset and y_slice_size.

    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
            all of the transposed LoRA matrices.
        indicies: Shape: `[B]`. Indices of the LoRA weights.
        layer_idx: Layer index of LoRA weights.
        scale: Scaling factor.
        y_offset: Offset to apply to the starting column of y.
        y_slice_size: Size of the y column slice.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)
    punica_kernels.dispatch_bgmv_low_level(
        y,
        x,
        w_t_all,
        indicies,
        layer_idx,
        scale,
        x.size(1),
        y_slice_size,
        y_offset,
    )


def add_lora(y: torch.Tensor,
             x: torch.Tensor,
             wa_t_all: torch.Tensor,
             wb_t_all: torch.Tensor,
             indicies: torch.LongTensor,
             layer_idx: int,
             scale: float,
             *,
             buffer: Optional[torch.Tensor] = None):
    """
    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
            LoRA A matrices.
        wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
            LoRA B matrices.
        indicies: Shape: `[B]`. Indices of the LoRA weights.
        layer_idx: Layer index of LoRA weights.
        scale: Scaling factor.
        buffer: Optional. Shape: `[B, R]`. Temporary buffer.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
        # numerical inaccuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
    punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
    punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
                                 scale)


def add_lora_slice(y: torch.Tensor,
                   x: torch.Tensor,
                   wa_t_all: torch.Tensor,
                   wb_t_all: torch.Tensor,
                   indicies: torch.LongTensor,
                   layer_idx: int,
                   scale: float,
                   y_offset: int,
                   y_slice_size: int,
                   *,
                   buffer: Optional[torch.Tensor] = None):
    """
    Same as `add_lora` but you can operate on slices of y.
    Pass whole y, define y_offset and y_slice_size.

    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
            LoRA A matrices.
        wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
            LoRA B matrices.
        indicies: Shape: `[B]`. Indices of the LoRA weights.
        layer_idx: Layer index of LoRA weights.
        scale: Scaling factor.
        y_offset: Offset to apply to the starting column of y.
        y_slice_size: Size of the y column slice.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
        # numerical inaccuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
    punica_kernels.dispatch_bgmv_low_level(
        buffer,
        x,
        wa_t_all,
        indicies,
        layer_idx,
        1.0,
        x.size(1),
        buffer.size(1),
        0,
    )
    punica_kernels.dispatch_bgmv_low_level(
        y,
        buffer,
        wb_t_all,
        indicies,
        layer_idx,
        scale,
        buffer.size(1),
        y_slice_size,
        y_offset,
    )
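The docstrings above define the kernel semantics purely as batched matrix products. As a rough reference only (not the CUDA kernel path, and assuming plain float32 tensors with the documented shapes; all tensor names and sizes below are illustrative), the same y[i] += x[i] @ A^T @ B^T * scale update can be written directly in PyTorch:

import torch

def add_lora_reference(y, x, wa_t_all, wb_t_all, indicies, layer_idx, scale):
    """Pure-PyTorch sketch of the documented `add_lora` semantics."""
    for i in range(x.size(0)):
        a_t = wa_t_all[indicies[i], layer_idx]            # [R, H1], transposed LoRA A
        b_t = wb_t_all[indicies[i], layer_idx]            # [H2, R], transposed LoRA B
        buf = x[i].unsqueeze(0) @ a_t.transpose(-1, -2)   # shrink: [1, H1] -> [1, R]
        # expand and accumulate in-place into the output row
        y[i] += (buf @ b_t.transpose(-1, -2) * scale).squeeze(0)

# Illustrative shapes: batch B=2, hidden sizes H1=H2=8, rank R=4, L=1 layer.
B, H1, H2, R, L = 2, 8, 8, 4, 1
y = torch.zeros(B, H2)
x = torch.randn(B, H1)
wa_t_all = torch.randn(1, L, R, H1)
wb_t_all = torch.randn(1, L, H2, R)
indicies = torch.zeros(B, dtype=torch.long)
add_lora_reference(y, x, wa_t_all, wb_t_all, indicies, layer_idx=0, scale=1.0)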
vllm/lora/request.py
ADDED
@@ -0,0 +1,32 @@
from dataclasses import dataclass


@dataclass
class LoRARequest:
    """
    Request for a LoRA adapter.

    Note that this class should only be used internally. For online
    serving, it is recommended to not allow users to use this class but
    instead provide another layer of abstraction to prevent users from
    accessing unauthorized LoRA adapters.

    lora_int_id must be globally unique for a given adapter.
    This is currently not enforced in vLLM.
    """

    lora_name: str
    lora_int_id: int
    lora_local_path: str

    def __post_init__(self):
        if self.lora_int_id < 1:
            raise ValueError(
                f"lora_int_id must be > 0, got {self.lora_int_id}")

    def __eq__(self, value: object) -> bool:
        return isinstance(
            value, LoRARequest) and self.lora_int_id == value.lora_int_id

    def __hash__(self) -> int:
        return self.lora_int_id
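Since equality and hashing are keyed only on lora_int_id, two requests with the same integer ID refer to the same adapter regardless of name or path. A minimal usage sketch, assuming this wheel is installed; the adapter name and local path below are placeholders:

from vllm.lora.request import LoRARequest

req_a = LoRARequest(lora_name="sql-adapter",
                    lora_int_id=1,
                    lora_local_path="/models/loras/sql-adapter")
req_b = LoRARequest(lora_name="renamed-adapter",
                    lora_int_id=1,
                    lora_local_path="/models/loras/sql-adapter")

assert req_a == req_b            # equality is based solely on lora_int_id
assert hash(req_a) == hash(req_b)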
vllm/lora/utils.py
ADDED
@@ -0,0 +1,98 @@
from typing import List, Optional, Set, Tuple, Type

from torch import nn
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              LogitsProcessorWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

logger = init_logger(__name__)

_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
    LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    for lora_cls in _all_lora_classes:
        # specifying kwargs so they can be easily accessed in decorator
        if lora_cls.can_replace_layer(source_layer=layer,
                                      lora_config=lora_config,
                                      packed_modules_list=packed_modules_list,
                                      model_config=model_config):
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device)
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret


def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Replace a submodule in a model with a new module."""
    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
    target_name = module_name.split(".")[-1]
    setattr(parent, target_name, new_module)
    return new_module


def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
    """Parse the name of lora weights.

    args:
        name: the name of the fine-tuned LoRA, e.g.
            base_model.model.dense1.weight
    return:
        Tuple(module_name, is_lora_a):
            module_name: the name of the module, e.g. model.dense1,
            is_lora_a: whether the tensor is lora_a or lora_b.
    """
    parts = name.split(".")
    assert parts[0] == "base_model"
    assert parts[1] == "model"
    if parts[-1] == "weight":
        assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
        return ".".join(parts[2:-2]), parts[-2] == "lora_A"

    if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
        return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"

    raise ValueError(f"{name} is unsupported format")
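As a quick illustration of the weight-name parser above (assuming a working vllm installation; the layer path in the example name is made up for illustration), a PEFT-style tensor name is stripped of its "base_model.model." prefix and its "lora_A/lora_B.weight" suffix:

from vllm.lora.utils import parse_fine_tuned_lora_name

name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
module_name, is_lora_a = parse_fine_tuned_lora_name(name)
assert module_name == "model.layers.0.self_attn.q_proj"
assert is_lora_a is True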
vllm/lora/worker_manager.py
ADDED
@@ -0,0 +1,251 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Set, Type

import torch

from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import (LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest

logger = init_logger(__name__)


class AbstractWorkerLoRAManager(ABC):
    """Abstract class for managing LoRA models on the worker side."""

    def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
                 vocab_size: int, lora_config: LoRAConfig,
                 device: torch.device):
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.device = device
        self.lora_config = lora_config

    @abstractproperty
    def is_enabled(self) -> bool:
        ...

    @abstractmethod
    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        ...

    @abstractmethod
    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        ...

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        ...

    @abstractmethod
    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        ...

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        ...

    @abstractmethod
    def remove_all_loras(self):
        ...

    @abstractmethod
    def list_loras(self) -> Set[int]:
        ...


class WorkerLoRAManager(AbstractWorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
    ):
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        # Lazily initialized by create_lora_manager.
        self._lora_manager: LoRAModelManager
        super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
                         lora_config, device)

    @property
    def is_enabled(self) -> bool:
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            lora_manager_cls=self._lora_manager_cls,
        )
        self._lora_manager = lora_manager
        return lora_manager.model

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        self._apply_loras(lora_requests)
        self._lora_manager.set_lora_mapping(lora_mapping)

    def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
        loras_that_exist = self.list_loras()
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._lora_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._lora_manager.lora_slots}).")

        new_loras = set(loras_map)
        loras_to_add = new_loras - loras_that_exist
        loras_to_remove = loras_that_exist - new_loras

        for lora_id in loras_to_remove:
            self.remove_lora(lora_id)

        for lora_id in loras_to_add:
            self.add_lora(loras_map[lora_id])

    def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
        try:
            model = self._lora_manager.model
            supported_lora_modules = model.supported_lora_modules
            packed_modules_mapping = model.packed_modules_mapping
            expected_lora_modules = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)
            lora = self._lora_model_cls.from_local_checkpoint(
                lora_request.lora_local_path,
                expected_lora_modules,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
            )
        except Exception as e:
            raise RuntimeError(
                f"Loading lora {lora_request.lora_local_path} failed") from e
        if lora.rank > self.lora_config.max_lora_rank:
            raise ValueError(
                f"LoRA rank {lora.rank} is greater than max_lora_rank "
                f"{self.lora_config.max_lora_rank}.")
        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        if lora_request.lora_int_id in self.list_loras():
            return False
        return self._lora_manager.add_lora(
            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
                                                 rank, self.embedding_modules))

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id in self.list_loras():
            return False
        lora = self._load_lora(lora_request)
        loaded = self._lora_manager.add_lora(lora)
        self._lora_manager.activate_lora(lora.id)
        return loaded

    def remove_lora(self, lora_id: int) -> bool:
        return self._lora_manager.remove_lora(lora_id)

    def remove_all_loras(self):
        self._lora_manager.remove_all_loras()

    def list_loras(self) -> Set[int]:
        return set(self._lora_manager.list_loras())


class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU Cache. Every request, the requested LoRAs will be loaded
    (unless they are already loaded) and least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    _lora_manager_cls: Type[
        LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        lora_manager = create_lora_manager(
            model,
            lora_manager_cls=self._lora_manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            max_num_batched_tokens=self.max_num_batched_tokens,
        )
        self._lora_manager = lora_manager
        return lora_manager.model

    def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._lora_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._lora_manager.lora_slots}).")
        for lora in loras_map.values():
            self.add_lora(lora)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if lora_request.lora_int_id not in self.list_loras():
            # Remove before we load the new lora to save memory
            if len(self._lora_manager) + 1 > self._lora_manager.capacity:
                assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
                self._lora_manager.remove_oldest_lora()
            lora = self._load_lora(lora_request)
            loaded = self._lora_manager.add_lora(lora)
        else:
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = self._lora_manager.get_lora(
                lora_request.lora_int_id) is not None
        self._lora_manager.activate_lora(lora_request.lora_int_id)
        return loaded
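The non-LRU `_apply_loras` above is a plain set reconciliation: every requested adapter that is not resident gets loaded, every resident adapter that is not requested gets dropped, capped by the number of GPU slots. A standalone sketch of just that reconciliation step, detached from the vLLM classes (function and variable names are illustrative):

def reconcile(resident: set, requested: set, slots: int):
    """Return (to_add, to_remove) so that afterwards resident == requested."""
    if len(requested) > slots:
        raise RuntimeError(
            f"{len(requested)} LoRAs requested, only {slots} slots")
    return requested - resident, resident - requested

to_add, to_remove = reconcile(resident={1, 2}, requested={2, 3}, slots=8)
assert to_add == {3} and to_remove == {1}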
vllm/model_executor/guided_decoding/__init__.py
ADDED
@@ -0,0 +1,25 @@
from typing import Optional, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
    get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_decoding import (
    get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor


async def get_guided_decoding_logits_processor(
        guided_decoding_backend: str, request: Union[CompletionRequest,
                                                     ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessor]:
    if guided_decoding_backend == 'outlines':
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        return await get_lm_format_enforcer_guided_decoding_logits_processor(
            request, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")
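A rough sketch of how a caller might dispatch through this selector; the model name and tokenizer are placeholders, and the 'outlines' and 'lm-format-enforcer' paths additionally require those optional packages to be installed:

import asyncio
from transformers import AutoTokenizer
from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)

async def main():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
    request = CompletionRequest(model="gpt2",
                                prompt="Pick one:",
                                guided_choice=["yes", "no"])
    # Build a logits processor that restricts sampling to the allowed strings.
    processor = await get_guided_decoding_logits_processor(
        "outlines", request, tokenizer)
    print(type(processor))

asyncio.run(main())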
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
ADDED
@@ -0,0 +1,70 @@
from functools import lru_cache
from json import loads as json_loads
from typing import Optional, Union

from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
                              RegexParser, StringParser,
                              TokenEnforcerTokenizerData, UnionParser)
from lmformatenforcer.integrations.vllm import (
    build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
    get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor


async def get_lm_format_enforcer_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessor]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    We cache logit processors by (guide, tokenizer), and on cache hit
    we make a shallow copy to reuse the same underlying FSM.
    """

    tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
        tokenizer)
    character_level_parser: CharacterLevelParser
    if request.guided_json:
        schema = _normalize_json_schema_object(request.guided_json)
        character_level_parser = JsonSchemaParser(schema)
    elif request.guided_choice:
        character_level_parser = UnionParser(
            [StringParser(choice) for choice in request.guided_choice])
    elif request.guided_regex:
        character_level_parser = RegexParser(request.guided_regex)
    elif request.guided_grammar:
        # CFG grammar not supported by LMFE, revert to outlines
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    elif (request.response_format is not None
          and request.response_format.type == "json_object"):
        character_level_parser = JsonSchemaParser(
            None)  # None means any json object
    else:
        return None

    logits_processor = build_vllm_logits_processor(tokenizer_data,
                                                   character_level_parser)
    return logits_processor


def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
    if isinstance(schema, str):
        return json_loads(schema)
    if isinstance(schema, dict):
        return schema
    if isinstance(schema, BaseModel):
        return schema.model_json_schema()
    raise AssertionError(f"Unsupported schema type {schema}")


@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
        tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
    return build_vllm_token_enforcer_tokenizer_data(tokenizer)
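As the `_normalize_json_schema_object` helper above shows, `guided_json` may arrive as a JSON string, a dict, or a pydantic model, and all three normalize to the same schema dict. A small illustration calling that private helper directly (the `Person` model is made up, and pydantic v2 is assumed, as in the module itself):

from pydantic import BaseModel
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
    _normalize_json_schema_object)

class Person(BaseModel):  # hypothetical schema for illustration
    name: str
    age: int

schema_from_dict = _normalize_json_schema_object({"type": "object"})
schema_from_str = _normalize_json_schema_object('{"type": "object"}')
schema_from_model = _normalize_json_schema_object(Person(name="a", age=1))

assert schema_from_dict == schema_from_str == {"type": "object"}
assert schema_from_model["properties"].keys() == {"name", "age"}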