vllm_npu-0.4.2-py3-none-any.whl
- vllm/__init__.py +23 -0
- vllm/_custom_ops.py +251 -0
- vllm/attention/__init__.py +13 -0
- vllm/attention/backends/__init__.py +0 -0
- vllm/attention/backends/abstract.py +127 -0
- vllm/attention/backends/flash_attn.py +271 -0
- vllm/attention/backends/flashinfer.py +220 -0
- vllm/attention/backends/rocm_flash_attn.py +374 -0
- vllm/attention/backends/torch_sdpa.py +250 -0
- vllm/attention/backends/xformers.py +393 -0
- vllm/attention/layer.py +56 -0
- vllm/attention/ops/__init__.py +0 -0
- vllm/attention/ops/paged_attn.py +216 -0
- vllm/attention/ops/prefix_prefill.py +792 -0
- vllm/attention/ops/triton_flash_attention.py +810 -0
- vllm/attention/selector.py +91 -0
- vllm/block.py +84 -0
- vllm/config.py +1225 -0
- vllm/core/__init__.py +0 -0
- vllm/core/block/__init__.py +0 -0
- vllm/core/block/block_table.py +295 -0
- vllm/core/block/common.py +199 -0
- vllm/core/block/cpu_gpu_block_allocator.py +228 -0
- vllm/core/block/interfaces.py +205 -0
- vllm/core/block/naive_block.py +318 -0
- vllm/core/block/prefix_caching_block.py +606 -0
- vllm/core/block_manager_v1.py +625 -0
- vllm/core/block_manager_v2.py +258 -0
- vllm/core/evictor_v1.py +105 -0
- vllm/core/evictor_v2.py +127 -0
- vllm/core/interfaces.py +113 -0
- vllm/core/policy.py +45 -0
- vllm/core/scheduler.py +1163 -0
- vllm/distributed/__init__.py +3 -0
- vllm/distributed/communication_op.py +237 -0
- vllm/distributed/device_communicators/__init__.py +0 -0
- vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
- vllm/distributed/device_communicators/pynccl.py +287 -0
- vllm/distributed/device_communicators/pynccl_utils.py +66 -0
- vllm/distributed/parallel_state.py +339 -0
- vllm/distributed/utils.py +136 -0
- vllm/engine/__init__.py +0 -0
- vllm/engine/arg_utils.py +649 -0
- vllm/engine/async_llm_engine.py +737 -0
- vllm/engine/llm_engine.py +784 -0
- vllm/engine/metrics.py +368 -0
- vllm/engine/output_processor/__init__.py +0 -0
- vllm/engine/output_processor/interfaces.py +76 -0
- vllm/engine/output_processor/multi_step.py +142 -0
- vllm/engine/output_processor/single_step.py +284 -0
- vllm/engine/output_processor/stop_checker.py +101 -0
- vllm/engine/output_processor/util.py +19 -0
- vllm/entrypoints/__init__.py +0 -0
- vllm/entrypoints/api_server.py +119 -0
- vllm/entrypoints/llm.py +259 -0
- vllm/entrypoints/openai/__init__.py +0 -0
- vllm/entrypoints/openai/api_server.py +186 -0
- vllm/entrypoints/openai/cli_args.py +115 -0
- vllm/entrypoints/openai/protocol.py +460 -0
- vllm/entrypoints/openai/serving_chat.py +392 -0
- vllm/entrypoints/openai/serving_completion.py +347 -0
- vllm/entrypoints/openai/serving_engine.py +234 -0
- vllm/envs.py +217 -0
- vllm/executor/__init__.py +0 -0
- vllm/executor/cpu_executor.py +152 -0
- vllm/executor/distributed_gpu_executor.py +115 -0
- vllm/executor/executor_base.py +115 -0
- vllm/executor/gpu_executor.py +150 -0
- vllm/executor/multiproc_worker_utils.py +263 -0
- vllm/executor/neuron_executor.py +91 -0
- vllm/executor/ray_gpu_executor.py +327 -0
- vllm/executor/ray_utils.py +119 -0
- vllm/logger.py +153 -0
- vllm/logging/__init__.py +5 -0
- vllm/logging/formatter.py +15 -0
- vllm/lora/__init__.py +0 -0
- vllm/lora/fully_sharded_layers.py +262 -0
- vllm/lora/layers.py +1181 -0
- vllm/lora/lora.py +167 -0
- vllm/lora/models.py +645 -0
- vllm/lora/punica.py +213 -0
- vllm/lora/request.py +32 -0
- vllm/lora/utils.py +98 -0
- vllm/lora/worker_manager.py +251 -0
- vllm/model_executor/__init__.py +7 -0
- vllm/model_executor/guided_decoding/__init__.py +25 -0
- vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
- vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
- vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
- vllm/model_executor/layers/__init__.py +0 -0
- vllm/model_executor/layers/activation.py +173 -0
- vllm/model_executor/layers/fused_moe/__init__.py +7 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
- vllm/model_executor/layers/layernorm.py +71 -0
- vllm/model_executor/layers/linear.py +709 -0
- vllm/model_executor/layers/logits_processor.py +115 -0
- vllm/model_executor/layers/ops/__init__.py +0 -0
- vllm/model_executor/layers/ops/rand.py +157 -0
- vllm/model_executor/layers/ops/sample.py +406 -0
- vllm/model_executor/layers/quantization/__init__.py +35 -0
- vllm/model_executor/layers/quantization/aqlm.py +376 -0
- vllm/model_executor/layers/quantization/awq.py +175 -0
- vllm/model_executor/layers/quantization/base_config.py +97 -0
- vllm/model_executor/layers/quantization/fp8.py +265 -0
- vllm/model_executor/layers/quantization/gptq.py +224 -0
- vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
- vllm/model_executor/layers/quantization/marlin.py +227 -0
- vllm/model_executor/layers/quantization/schema.py +84 -0
- vllm/model_executor/layers/quantization/squeezellm.py +137 -0
- vllm/model_executor/layers/rejection_sampler.py +405 -0
- vllm/model_executor/layers/rotary_embedding.py +525 -0
- vllm/model_executor/layers/sampler.py +1051 -0
- vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
- vllm/model_executor/model_loader/__init__.py +30 -0
- vllm/model_executor/model_loader/loader.py +362 -0
- vllm/model_executor/model_loader/neuron.py +136 -0
- vllm/model_executor/model_loader/tensorizer.py +368 -0
- vllm/model_executor/model_loader/utils.py +41 -0
- vllm/model_executor/model_loader/weight_utils.py +372 -0
- vllm/model_executor/models/__init__.py +119 -0
- vllm/model_executor/models/baichuan.py +410 -0
- vllm/model_executor/models/bloom.py +327 -0
- vllm/model_executor/models/chatglm.py +386 -0
- vllm/model_executor/models/commandr.py +373 -0
- vllm/model_executor/models/dbrx.py +413 -0
- vllm/model_executor/models/decilm.py +122 -0
- vllm/model_executor/models/deepseek.py +438 -0
- vllm/model_executor/models/falcon.py +444 -0
- vllm/model_executor/models/gemma.py +393 -0
- vllm/model_executor/models/gpt2.py +266 -0
- vllm/model_executor/models/gpt_bigcode.py +274 -0
- vllm/model_executor/models/gpt_j.py +281 -0
- vllm/model_executor/models/gpt_neox.py +295 -0
- vllm/model_executor/models/internlm2.py +323 -0
- vllm/model_executor/models/jais.py +333 -0
- vllm/model_executor/models/llama.py +442 -0
- vllm/model_executor/models/llava.py +239 -0
- vllm/model_executor/models/minicpm.py +531 -0
- vllm/model_executor/models/mixtral.py +583 -0
- vllm/model_executor/models/mixtral_quant.py +404 -0
- vllm/model_executor/models/mpt.py +295 -0
- vllm/model_executor/models/olmo.py +356 -0
- vllm/model_executor/models/opt.py +349 -0
- vllm/model_executor/models/orion.py +319 -0
- vllm/model_executor/models/phi.py +300 -0
- vllm/model_executor/models/qwen.py +284 -0
- vllm/model_executor/models/qwen2.py +367 -0
- vllm/model_executor/models/qwen2_moe.py +447 -0
- vllm/model_executor/models/stablelm.py +301 -0
- vllm/model_executor/models/starcoder2.py +302 -0
- vllm/model_executor/models/xverse.py +366 -0
- vllm/model_executor/sampling_metadata.py +588 -0
- vllm/model_executor/utils.py +35 -0
- vllm/outputs.py +150 -0
- vllm/py.typed +2 -0
- vllm/sampling_params.py +340 -0
- vllm/sequence.py +766 -0
- vllm/spec_decode/__init__.py +0 -0
- vllm/spec_decode/batch_expansion.py +397 -0
- vllm/spec_decode/interfaces.py +73 -0
- vllm/spec_decode/metrics.py +191 -0
- vllm/spec_decode/multi_step_worker.py +203 -0
- vllm/spec_decode/ngram_worker.py +176 -0
- vllm/spec_decode/spec_decode_worker.py +472 -0
- vllm/spec_decode/top1_proposer.py +200 -0
- vllm/spec_decode/util.py +228 -0
- vllm/test_utils.py +41 -0
- vllm/transformers_utils/__init__.py +0 -0
- vllm/transformers_utils/config.py +58 -0
- vllm/transformers_utils/configs/__init__.py +16 -0
- vllm/transformers_utils/configs/chatglm.py +68 -0
- vllm/transformers_utils/configs/dbrx.py +278 -0
- vllm/transformers_utils/configs/falcon.py +87 -0
- vllm/transformers_utils/configs/jais.py +236 -0
- vllm/transformers_utils/configs/mpt.py +178 -0
- vllm/transformers_utils/detokenizer.py +313 -0
- vllm/transformers_utils/tokenizer.py +149 -0
- vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
- vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
- vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
- vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
- vllm/transformers_utils/tokenizers/__init__.py +5 -0
- vllm/transformers_utils/tokenizers/baichuan.py +255 -0
- vllm/usage/__init__.py +0 -0
- vllm/usage/usage_lib.py +209 -0
- vllm/utils.py +677 -0
- vllm/worker/__init__.py +0 -0
- vllm/worker/cache_engine.py +105 -0
- vllm/worker/cpu_model_runner.py +346 -0
- vllm/worker/cpu_worker.py +321 -0
- vllm/worker/model_runner.py +1168 -0
- vllm/worker/neuron_model_runner.py +196 -0
- vllm/worker/neuron_worker.py +98 -0
- vllm/worker/worker.py +345 -0
- vllm/worker/worker_base.py +146 -0
- vllm_npu-0.4.2.dist-info/LICENSE +201 -0
- vllm_npu-0.4.2.dist-info/METADATA +173 -0
- vllm_npu-0.4.2.dist-info/RECORD +219 -0
- vllm_npu-0.4.2.dist-info/WHEEL +5 -0
- vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/model_executor/layers/logits_processor.py (new file, +115 lines)

@@ -0,0 +1,115 @@

```python
"""A layer that computes logits from hidden_states."""
from typing import Optional

import torch
import torch.nn as nn

from vllm.distributed import tensor_model_parallel_gather
from vllm.model_executor.sampling_metadata import SamplingMetadata


class LogitsProcessor(nn.Module):
    """Process logits and apply logits processors from sampling metadata.

    This layer does the following:
    1. Gather logits from model hidden_states.
    2. Scale logits if needed.
    3. Apply logits processors (if any).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None,
                 scale: Optional[float] = 1.0,
                 logits_as_input: bool = False) -> None:
        """
        Args:
            scale: A scaling factor to apply to the logits.
        """
        super().__init__()
        self.scale = scale
        self.vocab_size = vocab_size
        # Whether the input is logits (default is hidden states).
        self.logits_as_input = logits_as_input
        # Original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.logits_as_input:
            logits = hidden_states
        else:
            hidden_states = _prune_hidden_states(hidden_states,
                                                 sampling_metadata)

            # Get the logits for the next tokens.
            logits = self._get_logits(hidden_states, embedding, embedding_bias)

        if logits is not None:
            logits *= self.scale

            # Apply logits processors (if any).
            logits = _apply_logits_processors(logits, sampling_metadata)

        return logits

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits

    def extra_repr(self) -> str:
        s = f"vocab_size={self.vocab_size}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
        return s


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    return hidden_states.index_select(0,
                                      sampling_metadata.selected_token_indices)


def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    found_logits_processors = False
    logits_processed = 0
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
        logits_processors = sampling_params.logits_processors

        if logits_processors:
            found_logits_processors = True
            for seq_id, logits_row_idx in zip(seq_ids,
                                              seq_group.sample_indices):
                logits_row = logits[logits_row_idx]
                token_ids = seq_group.seq_data[seq_id].output_token_ids
                for logits_processor in logits_processors:
                    logits_row = logits_processor(token_ids, logits_row)
                logits[logits_row_idx] = logits_row

        logits_processed += len(seq_group.sample_indices) + len(
            seq_group.prompt_logprob_indices)

    if found_logits_processors:
        # Verifies that no rows in logits were missed unexpectedly.
        assert logits_processed == logits.shape[0]
    return logits
```
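For orientation, here is a minimal standalone sketch of the core computation above (vocabulary projection, removal of vocab padding, scaling), written without the tensor-parallel gather or the SamplingMetadata plumbing. The shapes and the `scale` value are made up for illustration and are not taken from this package.

```python
# Illustrative only: reproduces the math of LogitsProcessor._get_logits on a
# single device, with made-up sizes. No vllm import is needed for this sketch.
import torch

batch, hidden_size, padded_vocab, org_vocab, scale = 2, 16, 128, 100, 1.0

hidden_states = torch.randn(batch, hidden_size)
embedding = torch.randn(padded_vocab, hidden_size)   # lm_head weight (padded vocab)

logits = torch.matmul(hidden_states, embedding.t())  # [batch, padded_vocab]
logits = logits[:, :org_vocab]                       # drop vocab padding
logits = logits * scale                              # optional scaling

print(logits.shape)  # torch.Size([2, 100])
```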
File without changes
vllm/model_executor/layers/ops/rand.py (new file, +157 lines)

@@ -0,0 +1,157 @@

```python
from typing import Optional, Union

import torch
import triton
import triton.language as tl


def seeded_uniform(
    *size,
    seeds: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
    pin_memory: Optional[bool] = False,
) -> torch.Tensor:
    """Similar to torch.rand, but allows for seeds to be set per row.

    seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
    If it is 3d, the additional seeds needed will be derived automatically
    in a deterministic fashion:
    [
        row 0: [columns_with_seed_0], [columns_with_seed0^1], ...
    ]
    """
    n_dims = len(size)

    if n_dims > 3:
        raise ValueError("seeded_uniform only supports up to 3D tensors")

    if out is None:
        out = torch.empty(*size,
                          dtype=dtype,
                          device=device,
                          pin_memory=pin_memory)
    elif out.shape != size:
        raise ValueError("shape of out and size must be the same")

    if n_dims == 3:
        n_rows, n_3d, n_cols = out.shape
        stride_row = out.stride(0)
        stride_3d = out.stride(1)
    elif n_dims == 2:
        n_rows, n_cols = out.shape
        n_3d = 1
        stride_row = out.stride(0)
        stride_3d = 1
    else:
        n_cols = out.shape[0]
        n_rows = 1
        n_3d = 1
        stride_row = 1
        stride_3d = 1

    if seeds.ndim != 1:
        raise ValueError("seeds must be a 1D tensor")

    if seeds.numel() != n_rows:
        raise ValueError(
            "seeds must have the same number of elements as out has rows")

    # The philox PRNG Triton uses generates 4 random numbers at once.
    # Therefore, the most efficient use of it is to divide the
    # block size by 4, and then save the generated random numbers to
    # each of the 4 slices of the tensor.
    full_block_size = triton.next_power_of_2(n_cols)
    philox_block_size = max(full_block_size // 4, 1)
    n_slices = full_block_size // philox_block_size
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if philox_block_size >= 8192:
        num_warps = 32
    elif philox_block_size >= 4096:
        num_warps = 16
    elif philox_block_size >= 2048:
        num_warps = 8

    _seeded_uniform_triton[(n_rows, n_3d)](
        out,
        seeds,
        stride_row,
        stride_3d,
        seeds.stride(0),
        n_rows,
        n_3d,
        n_cols,
        n_slices=n_slices,
        num_warps=num_warps,
        block_size=philox_block_size,
    )
    return out


@triton.jit
def _seeded_uniform_triton(
    out_ptr: torch.Tensor,
    seed_ptr: torch.Tensor,
    out_row_stride: int,
    out_3d_stride: int,
    seed_row_stride: int,
    n_rows: int,
    n_3d: int,
    n_cols: int,
    n_slices: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Generate a random float32 number in [0, 1) for each element in the output
    tensor. The random numbers in a row are generated using the seed for that
    row.

    Args:
        out_ptr: The output tensor.
        seed_ptr: The per-row seeds to use for random number generation.
        out_row_stride: The stride between rows of the output tensor.
        out_3d_stride: The stride between 3D slices of the output tensor.
        seed_row_stride: The stride between rows of the seed tensor.
        n_rows: The number of rows in the output tensor.
        n_3d: The size of second dimension of the output tensor,
            if output tensor is 3D.
        n_cols: The number of columns in the output tensor.
        n_slices: The number of philox outputs to use.
    """
    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")

    # Get the row index.
    row_idx = tl.program_id(axis=0)
    three_d_idx = tl.program_id(axis=1)

    philox_offsets = tl.arange(0, block_size)
    # Get the seed for the current element.
    seed = tl.load(seed_ptr + row_idx * seed_row_stride)
    if three_d_idx > 0:
        seed ^= three_d_idx
    # Generate random numbers in [0, 1).
    out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)

    output_row_start_ptr = (out_ptr + row_idx * out_row_stride +
                            three_d_idx * out_3d_stride)
    out1_offsets = philox_offsets
    tl.store(output_row_start_ptr + out1_offsets,
             out1,
             mask=out1_offsets < n_cols)
    if n_slices > 1:
        out2_offsets = tl.arange(block_size, block_size * 2)
        tl.store(output_row_start_ptr + out2_offsets,
                 out2,
                 mask=out2_offsets < n_cols)
    if n_slices > 2:
        out3_offsets = tl.arange(block_size * 2, block_size * 3)
        tl.store(output_row_start_ptr + out3_offsets,
                 out3,
                 mask=out3_offsets < n_cols)
    if n_slices > 3:
        out4_offsets = tl.arange(block_size * 3, block_size * 4)
        tl.store(output_row_start_ptr + out4_offsets,
                 out4,
                 mask=out4_offsets < n_cols)
```
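A hedged usage sketch for `seeded_uniform`, assuming a CUDA device with Triton available (the kernel cannot run on CPU); the seed values and the output shape are made up. The point it demonstrates is that each row is driven only by its own seed, so repeating the call with the same seeds reproduces the same rows.

```python
# Illustrative only; requires a CUDA device, since seeded_uniform launches a
# Triton kernel.
import torch

from vllm.model_executor.layers.ops.rand import seeded_uniform

seeds = torch.tensor([1234, 5678], dtype=torch.long, device="cuda")  # one seed per row

a = seeded_uniform(2, 8, seeds=seeds, dtype=torch.float32, device="cuda")
b = seeded_uniform(2, 8, seeds=seeds, dtype=torch.float32, device="cuda")

assert torch.equal(a, b)                     # same per-row seeds -> identical rows
print(a.shape, a.min() >= 0, a.max() < 1)    # values lie in [0, 1)
```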
vllm/model_executor/layers/ops/sample.py (new file, +406 lines)

@@ -0,0 +1,406 @@

```python
import math
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from vllm.model_executor.layers.ops.rand import seeded_uniform

_EPS = 1e-6

# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS = 131072


def get_num_triton_sampler_splits(n_cols: int) -> int:
    """Get the number of splits to use for Triton sampling.

    Triton has a limit on the number of columns it can handle, so we need to
    split the tensor and call the kernel multiple times if it's too large.
    """
    return math.ceil(n_cols / MAX_TRITON_N_COLS)


def _multi_split_sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    n_splits: int,
    sampled_tokens_size: Tuple[int, int],
    sampled_logprobs_size: Tuple[int, int],
    sample_indices: torch.Tensor,
    logprobs: torch.Tensor,
    *,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
):
    """Sample tokens where vocab size is split into multiple parts
    (too large for Triton otherwise)."""
    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
    split_probs = probs.tensor_split(n_splits, 1)
    split_logprobs = logprobs.tensor_split(n_splits, 1)
    sampled_tokens_tmp = [
        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
        for _ in range(n_splits)
    ]
    sampled_logprobs_tmp = [
        torch.empty(sampled_logprobs_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    # We are purposefully using sampled_tokens_size as we need to always
    # save modified probs in this case.
    sampled_modified_probs_tmp = [
        torch.empty(sampled_tokens_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    for i in range(n_splits):
        n_samples = sample_indices.shape[0]
        n_cols = split_probs[i].shape[1]
        n_best = sampled_tokens_tmp[i].shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       n_best,
                                       n_cols,
                                       seeds=seeds[i].flatten(),
                                       device=split_probs[i].device,
                                       dtype=split_probs[i].dtype)
        # TODO(yard1): See if we can remove the contiguous() calls.
        # Will need kernel support.
        _sample(
            split_probs[i].contiguous(),
            split_logprobs[i].contiguous(),
            sample_indices,
            sampled_tokens_tmp[i],
            sampled_logprobs_tmp[i],
            sampled_modified_probs_tmp[i],
            seeds[i],
            uniform_noise,
            modify_greedy_probs=False,
            save_logprobs=save_logprobs,
            save_modified_probs=True,
        )
        if i > 0:
            # Add offset to sampled tokens
            sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1])
    sampled_tokens = torch.stack(sampled_tokens_tmp)
    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
    # Reduce the results from the splits.
    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
                                                dim=0,
                                                keepdim=True)
    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
    if save_logprobs:
        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
    else:
        sampled_logprobs = None
    sampled_modified_probs = sampled_modified_probs.squeeze(0)

    if modify_greedy_probs:
        # We need to modify the greedy probs for the sampled tokens.
        # We can't do this in the kernel as we need to know the
        # sampled tokens.
        probs.fill_(0.0)
        probs.scatter_(1, sampled_tokens, 1.0)

    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)


def sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    *,
    max_best_of: int = 1,
    sample_indices: Optional[torch.Tensor] = None,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
    _save_modified_probs: bool = False,  # pylint: disable=invalid-name
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Sample tokens from probs, with per-sequence seeds.

    Can sample from a subset of sequences through sample_indices.

    Args:
        probs: Probabilities to sample from.
            shape = [batch_size, vocab_size]
        seeds: Per-sequence seed values.
            shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)]
        max_best_of: Number of samples to generate per sequence.
            Sequence seed will be incremented by 1 each time.
        sample_indices: Indices of sequences to sample from.
            If not provided, will sample from all sequences.
            shape = [n]
        logprobs: Log-probabilities of the sampled tokens.
            Only used for saving the logprobs if save_logprobs is True.
            shape = [batch_size, vocab_size]
        modify_greedy_probs: Whether to modify the greedy probabilities
            for speculative sampling (sampled token = 1.0,
            everything else = 0.0).
        save_logprobs: Whether to save the log-probabilities of the
            sampled tokens to a tensor.
        _save_modified_probs: Whether to save the modified probabilities
            (including gumbel noise) of the sampled tokens to a tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
            This is exposed only for testing.

    Returns:
        sampled_tokens: shape = [n, max_best_of]
        sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
        sampled_modified_probs: shape = [n, max_best_of]
            if save_modified_probs else None
    """
    if sample_indices is None:
        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)

    sampled_tokens_size = (sample_indices.size(0), max_best_of)
    if save_logprobs:
        if logprobs is None:
            raise ValueError(
                "logprobs tensor must be provided if save_logprobs is True")
        sampled_logprobs_size = sampled_tokens_size
    else:
        # Empty tensors to invoke the kernel
        sampled_logprobs_size = (0, 0)
        logprobs = probs

    assert logprobs is not None
    if _save_modified_probs:
        sampled_modified_probs_size = sampled_tokens_size
    else:
        # Empty tensors to invoke the kernel
        sampled_modified_probs_size = (0, 0)

    # If the number of columns in probs is too large for Triton to handle,
    # we split the tensor and sample from each split separately, and then
    # do an argmax+gather to combine the results.
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if n_splits > 1:
        (sampled_tokens, sampled_logprobs,
         sampled_modified_probs) = _multi_split_sample(
             probs,
             seeds,
             n_splits,
             sampled_tokens_size,
             sampled_logprobs_size,
             sample_indices,
             logprobs=logprobs,
             modify_greedy_probs=modify_greedy_probs,
             save_logprobs=save_logprobs)
    else:
        sampled_tokens = torch.empty(sampled_tokens_size,
                                     dtype=torch.long,
                                     device=probs.device)
        sampled_logprobs = torch.empty(sampled_logprobs_size,
                                       dtype=probs.dtype,
                                       device=probs.device)
        sampled_modified_probs = torch.empty(sampled_modified_probs_size,
                                             dtype=probs.dtype,
                                             device=probs.device)
        n_samples = sample_indices.shape[0]
        n_cols = probs.shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       max_best_of,
                                       n_cols,
                                       seeds=seeds.flatten(),
                                       device=probs.device,
                                       dtype=probs.dtype)

        _sample(
            probs,
            logprobs,
            sample_indices,
            sampled_tokens,
            sampled_logprobs,
            sampled_modified_probs,
            seeds,
            uniform_noise,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            save_modified_probs=_save_modified_probs,
        )
    return (sampled_tokens, sampled_logprobs if save_logprobs else None,
            sampled_modified_probs if _save_modified_probs else None)


def _sample(probs: torch.Tensor,
            logprobs: torch.Tensor,
            sample_indices: torch.Tensor,
            output_samples: torch.Tensor,
            output_logprobs: torch.Tensor,
            output_modified_probs: torch.Tensor,
            seeds: torch.Tensor,
            uniform_noise: torch.Tensor,
            *,
            modify_greedy_probs: bool = False,
            save_logprobs: bool = True,
            save_modified_probs: bool = False) -> torch.Tensor:
    """Sample tokens from probs.

    Args:
        probs [batch_size, vocab_size]: probs to sample from.
        logprobs [batch_size, vocab_size]: logprobs (used when
            save_logprobs is True).
        sample_indices [n]: Indices of the samples to use for each row of probs.
        output_samples [n, n_best]: Output tensor to store samples in.
        output_logprobs [n, n_best]: Output tensor to store logprobs in.
        output_modified_probs [n, n_best]: Output tensor to store
            probs of chosen tokens in (modified with noise).
        seeds [n]: Seeds to use for sampling. If the seed is 0, we use
            greedy sampling. Note this is ONLY used for determining
            whether to use random sampling or not. The actual random
            noise should be passed as uniform_noise.
        uniform_noise [batch_size, n_best, vocab_size]: Uniform
            noise to use for random sampling (will be converted
            to exponential gumbel noise by the kernel).
        modify_greedy_probs: If True, we modify the probs tensor in-place
            to encode the sampling method used for each row. This is used
            in speculative decoding. Only applies in greedy decoding.
        save_logprobs: If True, we save the logprobs of the sampled tokens
            in the output_logprobs tensor.
        save_modified_probs: If True, we save the modified probs (with noise)
            of the sampled tokens in the output_modified_probs tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
    """
    n_samples = sample_indices.shape[0]
    n_cols = probs.shape[1]
    n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1

    # The block size is the smallest power of two greater than the number of
    # columns in probs.
    block_size = triton.next_power_of_2(n_cols)
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if block_size >= 8192:
        num_warps = 32
    elif block_size >= 4096:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8

    # Enqueue kernel. The 1D launch grid is simple: we have one kernel
    # instance per row of the probs matrix.
    _sample_triton[(n_samples, n_best)](
        sample_indices,
        output_samples,
        output_logprobs,
        output_modified_probs,
        probs,
        logprobs,
        seeds,
        uniform_noise,
        output_samples.stride(0),
        probs.stride(0),
        uniform_noise.stride(0),
        uniform_noise.stride(1) if n_best > 1 else 1,
        n_samples,
        n_cols,
        n_best,
        num_warps=num_warps,
        block_size=block_size,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        save_modified_probs=save_modified_probs,
    )
    return output_samples, output_logprobs, output_modified_probs


@triton.jit
def _uniform_to_exponential(uniform_noise):
    """Convert uniform samples to exponential samples."""
    # tl.rand returns values in [0, 1), so we clamp lower bound
    # to _EPS to avoid log(0) and thus division by 0 later
    lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
    uniform_noise = tl.maximum(uniform_noise, lb)
    # Use the inversion method to turn uniform samples
    # into exponential samples
    exponential_noise = -tl.log(uniform_noise)
    return exponential_noise


@triton.jit
def _sample_triton(
        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
        output_logprobs_ptr: torch.Tensor,
        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
        probs_row_stride: int, uniform_noise_row_stride: int,
        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
        n_best: int, block_size: tl.constexpr,
        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
        save_modified_probs: tl.constexpr):
    # The rows are independent, so we parallelize across those
    sample_idx = tl.program_id(0)
    best_idx = tl.program_id(1)

    # Load the row index from DRAM
    row_idx = tl.load(sample_indices_ptr + sample_idx)
    seed = tl.load(seeds_ptr + sample_idx)
    uses_random_sampling = seed != 0

    # The stride represents how much we need to increase the
    # pointer to advance 1 row
    row_start_ptr = probs_ptr + row_idx * probs_row_stride

    # The block size is the next power of two greater than n_cols,
    # so we can fit each row in a single block
    col_offsets = tl.arange(0, block_size)

    # Load the row into SRAM, using a mask since block_size may be > than n_cols
    row = tl.load(row_start_ptr + col_offsets,
                  mask=col_offsets < n_cols,
                  other=float("-inf"))

    if uses_random_sampling:
        uniform_noise_start_ptr = (uniform_noise_ptr +
                                   sample_idx * uniform_noise_row_stride +
                                   best_idx * uniform_noise_best_stride)
        uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
                                mask=col_offsets < n_cols,
                                other=0.5)
        exponential_noise = _uniform_to_exponential(uniform_noise)
        row /= exponential_noise

    sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
    # clamp sampled token to n_cols - 1
    # this should not be necessary, but we do it
    # just in case
    if sampled_token >= n_cols:
        sampled_token = n_cols - 1
    # Write back output to DRAM
    output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
                            best_idx)
    tl.store(output_row_start_ptr, sampled_token)

    if modify_greedy_probs:  # noqa
        if not uses_random_sampling:
            # Set the probability of the sampled token to 1, all other
            # tokens to zero. This is used in speculative decoding where
            # the sampling method must be encoded within the sampled
            # probability distributions.
            row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
            tl.store(row_start_ptr + col_offsets,
                     row,
                     mask=col_offsets < n_cols)

    if save_modified_probs:
        output_row_start_ptr = (output_modified_probs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_value)

    if save_logprobs:
        # Load the row into SRAM, using a mask since block_size
        # may be > than n_cols
        sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
                                  sampled_token)
        # Write back output to DRAM
        output_row_start_ptr = (output_logprobs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_logprob)
```