tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0

tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210

@@ -1,4 +1,18 @@
-
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -10,7 +24,7 @@ from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEConfig, FusedMoEQuantConfig,
+    FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase
@@ -28,44 +42,22 @@ from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                        get_tpu_quant_method)
-from tpu_inference.layers.
+from tpu_inference.layers.common.quantization import (
+    dequantize_tensor_from_mxfp4_packed, quantize_tensor)
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func
 from tpu_inference.layers.vllm.linear_common import \
     reorder_concatenated_tensor_for_sharding
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
+from tpu_inference.utils import get_mesh_shape_product
 
-
+REQUANTIZED_BLOCK_SIZE = 512
 
 P = PartitionSpec
-logger = init_logger(__name__)
-
-
-# TODO(kyuyeunk): Move these functions into a common utility file.
-def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
-    assert u8_packed_e2m1.dtype == jnp.uint8
-    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
-    # bitcast creates one more dimension that splits 8 bits into two e2m1.
-    # we flatten them with the last dim.
-    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
-
-
-def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
-    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
-    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
-    ones = jnp.ones_like(u8, dtype=jnp.float32)
-    return jnp.ldexp(ones, exponents)
 
-
-def dequantize_block_weight(weight: jax.Array,
-                            scale: jax.Array,
-                            block_size: int,
-                            out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
-    orig_shape = weight.shape
-    weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
-    weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
-        scale, -1)
-    return weight_dequantized.reshape(orig_shape).astype(out_dtype)
+logger = init_logger(__name__)
 
 
 @register_quantization_config(get_tpu_quant_method(MXFP4))
@@ -87,17 +79,14 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
                     fused_mapping=self.packed_modules_mapping,
             ):
                 return VllmUnquantizedLinearMethod(linear_config)
-            # TODO: Add support for MXFP4 Linear Method.
-            # MXFP4 LinearMethod is available in AMD-Quark, refer to that
-            # implementation if you are interested in enabling MXFP4 here.
            logger.warning_once(
                "MXFP4 linear layer is not implemented - falling back to "
                "UnquantizedLinearMethod.")
            return VllmUnquantizedLinearMethod(linear_config)
        elif isinstance(layer, FusedMoE):
-
+            moe_config = self.get_moe_config(layer)
+            return VllmMxfp4MoEMethod(moe_config, self.mesh)
        elif isinstance(layer, Attention):
-            # TODO: Add support for MXFP4 Attention.
            logger.warning_once("MXFP4 attention layer is not implemented. "
                                "Skipping quantization for this layer.")
            return None
@@ -116,225 +105,306 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         self.mxfp4_backend = Mxfp4Backend.TRITON
 
         self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL
+        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
         self.ep_axis_name = ep_axis_name
         # TODO: Use autotune table once we have it.
         self.block_size = {
-            "bt":
+            "bt": 256,
             "bf": 1024,
-            "bd1":
-            "bd2":
-            "btc":
+            "bd1": 1024,
+            "bd2": 1024,
+            "btc": 256,
             "bfc": 1024,
-            "bd1c":
-            "bd2c":
+            "bd1c": 1024,
+            "bd2c": 1024,
         }
 
     def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
-
-
-
-            layer.w13_bias,
-            layer.w2_bias,
+        return mxfp4_w4a16_moe_quant_config(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_bias=layer.w13_bias,
+            w2_bias=layer.w2_bias,
        )
 
    def process_weights_after_loading(self, layer: torch.nn.Module):
        assert isinstance(layer, FusedMoE)
        assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."
 
-        w13_weight =
-        w13_weight_scale =
-            t2j(layer.w13_weight_scale, use_dlpack=False))
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
        w13_bias = t2j(layer.w13_bias, use_dlpack=False)
 
-        w2_weight =
-        w2_weight_scale =
-            t2j(layer.w2_weight_scale, use_dlpack=False))
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
        w2_bias = t2j(layer.w2_bias, use_dlpack=False)
 
-        #
-
-
-
-
-
-
-
-
-
-
-
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Wrap functions in jit to speedup requantization.
+        @jax.jit
+        def wrapper(w13_weight, w13_weight_scale, w13_bias, w2_weight,
+                    w2_weight_scale, w2_bias):
+            # Dequantize fp4 weights into fp32.
+            w13_weight = dequantize_tensor_from_mxfp4_packed(
+                w13_weight, w13_weight_scale, 2)
+            w2_weight = dequantize_tensor_from_mxfp4_packed(
+                w2_weight, w2_weight_scale, 2)
+
+            num_experts, orig_hidden_size, orig_intermediate_size = w2_weight.shape
+
+            # Requantize the weights into TPU friendly block size.
+            w13_weight, w13_weight_scale = quantize_tensor(
+                jnp.float4_e2m1fn, w13_weight, 2, REQUANTIZED_BLOCK_SIZE, True)
+            w2_weight, w2_weight_scale = quantize_tensor(
+                jnp.float4_e2m1fn, w2_weight, 2, REQUANTIZED_BLOCK_SIZE, True)
+
+            intermediate_size = w2_weight.shape[-1]
+            hidden_size = w13_weight.shape[-1]
+
+            # Dims may have been padded to align with subchannel size during
+            # quantization. We pad the corresponding dim on other weight.
+            # NOTE: We perform padding after quantization as padding value can
+            # affect quantization numerics.
+            intermediate_padding_size = 2 * (intermediate_size -
+                                             orig_intermediate_size)
+            w13_weight = jnp.pad(w13_weight,
+                                 ((0, 0), (0, intermediate_padding_size),
+                                  (0, 0)))
+            w13_weight_scale = jnp.pad(w13_weight_scale,
+                                       ((0, 0), (0, intermediate_padding_size),
+                                        (0, 0)))
+            w13_bias = jnp.pad(w13_bias,
+                               ((0, 0), (0, intermediate_padding_size)))
+
+            hidden_padding_size = hidden_size - orig_hidden_size
+            w2_weight = jnp.pad(w2_weight,
+                                ((0, 0), (0, hidden_padding_size), (0, 0)))
+            w2_weight_scale = jnp.pad(w2_weight_scale,
+                                      ((0, 0), (0, hidden_padding_size),
+                                       (0, 0)))
+            w2_bias = jnp.pad(w2_bias, ((0, 0), (0, hidden_padding_size)))
+
+            if layer.activation == "swigluoai":
+                # When using swigluoai, vLLM splits gmm output in a interleaved way.
+                # However, interleaved split is not performant on TPU. Therefore,
+                # we preprocess the weight so that splitting gmm output by middle
+                # can still get the same result.
+                w1_weight = w13_weight[:, ::2, :]
+                w3_weight = w13_weight[:, 1::2, :]
+                w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+                w1_weight_scale = w13_weight_scale[:, ::2, :]
+                w3_weight_scale = w13_weight_scale[:, 1::2, :]
+                w13_weight_scale = jnp.concat(
+                    [w1_weight_scale, w3_weight_scale], axis=1)
+
+                w1_bias = w13_bias[:, ::2]
+                w3_bias = w13_bias[:, 1::2]
+                w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
+            if self.use_kernel:
+                # Kernel expects:
+                # w13: (num_experts, 2, hidden_size, intermediate_size)
+                # w2: (num_experts, intermediate_size, hidden_size)
+                # Current format:
+                # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+                # w2_weight: (num_experts, hidden_size, intermediate_size)
+
+                w13_weight = w13_weight.reshape(num_experts, 2,
+                                                intermediate_size, hidden_size)
+
+                w13_weight_scale = w13_weight_scale.reshape(
+                    num_experts, 2, intermediate_size, 1, -1)
+                w2_weight_scale = w2_weight_scale.reshape(
+                    num_experts, hidden_size, 1, -1)
+
+                w13_bias = w13_bias.astype(jnp.float32).reshape(
+                    num_experts, 2, 1, intermediate_size)
+                w2_bias = w2_bias.astype(jnp.float32).reshape(
+                    num_experts, 1, hidden_size)
+
+                # Transpose non-constracting dim to right most dim
+                w13_weight = jnp.swapaxes(w13_weight, 2, 3)
+                w2_weight = jnp.swapaxes(w2_weight, 1, 2)
+
+                w13_weight_scale = jnp.swapaxes(w13_weight_scale, 2, 4)
+                w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 3)
 
                 # Apply EP sharding
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P("model", None, None))))
-
-            w13_bias = jax.device_put(
-                w13_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P("model", None))))
-            w2_bias = jax.device_put(
-                w2_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P("model", None))))
-
+                ep_sharding = NamedSharding(self.mesh, P("model"))
+
+                w13_weight = jax.lax.with_sharding_constraint(
+                    w13_weight, Format(Layout((0, 1, 2, 3)), ep_sharding))
+                w2_weight = jax.lax.with_sharding_constraint(
+                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
+
+                w13_weight_scale = jax.lax.with_sharding_constraint(
+                    w13_weight_scale,
+                    Format(Layout((0, 1, 2, 3, 4)), ep_sharding))
+                w2_weight_scale = jax.lax.with_sharding_constraint(
+                    w2_weight_scale, Format(Layout((0, 1, 2, 3)), ep_sharding))
+
+                w13_bias = jax.lax.with_sharding_constraint(
+                    w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
+                w2_bias = jax.lax.with_sharding_constraint(
+                    w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
            else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                w13_weight_scale = jnp.swapaxes(w13_weight_scale, 1, 2)
+                w13_weight_scale = jnp.expand_dims(w13_weight_scale, 2)
+                w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 2)
+                w2_weight_scale = jnp.expand_dims(w2_weight_scale, 2)
+
+                w13_bias = jnp.expand_dims(w13_bias, 1)
+                w2_bias = jnp.expand_dims(w2_bias, 1)
+
+                if layer.use_ep:
+                    ep_sharding = NamedSharding(self.mesh,
+                                                P(ShardingAxisName.EXPERT))
+
+                    w13_weight = jax.lax.with_sharding_constraint(
+                        w13_weight, ep_sharding)
+                    w2_weight = jax.lax.with_sharding_constraint(
+                        w2_weight, ep_sharding)
+
+                    w13_weight_scale = jax.lax.with_sharding_constraint(
+                        w13_weight_scale, ep_sharding)
+                    w2_weight_scale = jax.lax.with_sharding_constraint(
+                        w2_weight_scale, ep_sharding)
+
+                    w13_bias = jax.lax.with_sharding_constraint(
+                        w13_bias, ep_sharding)
+                    w2_bias = jax.lax.with_sharding_constraint(
+                        w2_bias, ep_sharding)
+
+                else:
+                    output_sizes = [intermediate_size, intermediate_size]
+                    n_shards = get_mesh_shape_product(
+                        self.mesh, ShardingAxisName.MLP_TENSOR)
+                    assert intermediate_size % n_shards == 0
+
+                    # Reorder w13 weights so that splitting w1 and w3 output
+                    # can happen locally without any collective operations.
+                    w13_weight = reorder_concatenated_tensor_for_sharding(
+                        w13_weight,
+                        output_sizes,
+                        n_shards,
+                        dim=1,
+                    )
+                    w13_weight_scale = reorder_concatenated_tensor_for_sharding(
+                        w13_weight_scale,
+                        output_sizes,
+                        n_shards,
+                        dim=3,
+                    )
+                    w13_bias = reorder_concatenated_tensor_for_sharding(
+                        w13_bias,
+                        output_sizes,
+                        n_shards,
+                        dim=2,
+                    )
+
+                    w13_weight = jax.lax.with_sharding_constraint(
+                        w13_weight,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, ShardingAxisName.MLP_TENSOR, None)))
+                    w2_weight = jax.lax.with_sharding_constraint(
+                        w2_weight,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, None, ShardingAxisName.MLP_TENSOR)))
+                    w13_weight_scale = jax.lax.with_sharding_constraint(
+                        w13_weight_scale,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, None, None, ShardingAxisName.MLP_TENSOR)))
+                    w2_weight_scale = jax.lax.with_sharding_constraint(
+                        w2_weight_scale,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, ShardingAxisName.MLP_TENSOR, None, None)))
+                    w13_bias = jax.lax.with_sharding_constraint(
+                        w13_bias,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, None, ShardingAxisName.MLP_TENSOR)))
+                    w2_bias = jax.lax.with_sharding_constraint(
+                        w2_bias, NamedSharding(self.mesh, P(None, None, None)))
+
+            return w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale, w2_bias
+
+        w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale, w2_bias = wrapper(
+            w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale,
+            w2_bias)
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
-        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
-
        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
 
-
+        layer.w13_weight_scale = Parameter(torch_view(w13_weight_scale),
+                                           requires_grad=False)
+        layer.w2_weight_scale = Parameter(torch_view(w2_weight_scale),
+                                          requires_grad=False)
+
+        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
 
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        assert isinstance(layer, FusedMoE)
-        if scoring_func != "softmax":
+        if layer.scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax is supported for scoring_func")
 
-
+        x = jax_view(x)
+        w13_weight = jax_view(layer.w13_weight)
+        w2_weight = jax_view(layer.w2_weight)
+        w13_weight_scale = jax_view(layer.w13_weight_scale)
+        w2_weight_scale = jax_view(layer.w2_weight_scale)
+        w13_bias = jax_view(layer.w13_bias)
+        w2_bias = jax_view(layer.w2_bias)
+        gating_output = jax_view(router_logits)
+
+        if self.use_kernel:
+            actual_hidden_size = x.shape[-1]
+            padding_size = w13_weight.shape[-2] - actual_hidden_size
+            x = jnp.pad(x, ((0, 0), (0, padding_size)))
            output = fused_ep_moe(
                mesh=self.mesh,
-                tokens=
-                w1=
-                w2=
-
-
-
-
+                tokens=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                w1_scale=w13_weight_scale,
+                w2_scale=w2_weight_scale,
+                b1=w13_bias,
+                b2=w2_bias,
+                gating_output=gating_output,
+                subc_quant_wsz=REQUANTIZED_BLOCK_SIZE,
+                top_k=layer.top_k,
                ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=renormalize,
-                act_fn=activation,
+                renormalize_topk_logits=layer.renormalize,
+                act_fn=layer.activation,
                **self.block_size,
-            )
+            )[:, :actual_hidden_size]
        else:
-
-
-
-
-
-
-
-
-
-
-                renormalize=renormalize,
-                reduce_results=layer.reduce_results,
+            output = fused_moe_func(
+                hidden_states=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                w1_scale=w13_weight_scale,
+                w2_scale=w2_weight_scale,
+                w1_bias=w13_bias,
+                w2_bias=w2_bias,
+                gating_output=gating_output,
+                topk=layer.top_k,
+                renormalize=layer.renormalize,
                mesh=self.mesh,
                use_ep=layer.use_ep,
-                activation=activation,
+                activation=layer.activation,
            )
 
        return torch_view(output)