tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Any, Dict, Optional
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def apply_rope(
    # (seq_len, num_heads, head_dim)
    inputs: jax.Array,
    # (3, seq_len) for M-RoPE, otherwise (seq_len,)
    positions: jax.Array,
    head_dim: int,
    rope_theta: float = 10000,
    rope_scaling: Optional[Dict[str, Any]] = None,
    rope_input_ordering: str = "split",
) -> jax.Array:
    """Applies Rotary Positional Embedding using the sine and cosine strategy.

    The input may be padded on its last dimension beyond ``head_dim``. RoPE is
    applied only to the first ``head_dim`` features, and the result is
    zero-padded back to the input's last-dimension size if necessary.

    If ``rope_input_ordering`` is ``"split"``, each rotated pair takes one
    feature from the first half and one from the second half of ``head_dim``.
    If it is ``"interleaved"``, adjacent features form a pair. Any other value
    is treated like ``"split"``. The M-RoPE path always uses split ordering.

    Args:
        inputs: Array whose first ``head_dim`` features on the last axis are
            rotated (expected layout is given by the comment on the
            parameter).
        positions: Token positions. A rank-2 array with leading size 3 selects
            M-RoPE (used by Qwen2.5-VL); otherwise a 1-D array is expected.
        head_dim: Number of unpadded features to rotate.
        rope_theta: Base of the geometric frequency progression.
        rope_scaling: Optional scaling config. The M-RoPE path requires an
            ``"mrope_section"`` entry and reads no other keys. For standard
            RoPE, any non-empty dict is forwarded to ``apply_rope_scaling``.
        rope_input_ordering: ``"split"`` or ``"interleaved"``.

    Returns:
        The rotated array, cast back to ``inputs.dtype``.

    Raises:
        ValueError: If M-RoPE positions are given but ``rope_scaling`` has no
            ``"mrope_section"`` entry.
    """

    # M-RoPE support for Qwen2.5-VL
    if positions.ndim == 2 and positions.shape[0] == 3:
        mrope_section = (rope_scaling.get("mrope_section", None)
                         if rope_scaling else None)
        if mrope_section is None:
            # Explicit raise (not `assert`) so the check survives `python -O`;
            # otherwise the indexing below fails with an obscure TypeError.
            raise ValueError(
                "M-RoPE positions of shape (3, seq_len) require "
                "rope_scaling['mrope_section'] to be set.")

        # Static split points over the head_dim // 2 frequency indices: the
        # first mrope_section[0] indices use positions[0], the next
        # mrope_section[1] use positions[1], the remainder use positions[2].
        split_indices = [mrope_section[0], mrope_section[0] + mrope_section[1]]
        all_freq_indices = jnp.arange(head_dim // 2)
        freq_indices_split = jnp.split(all_freq_indices, split_indices)

        cos_list = []
        sin_list = []
        for i in range(3):  # One entry per position dimension.
            current_indices = freq_indices_split[i]
            if current_indices.size == 0:
                # This section is empty, skip.
                continue
            inv_freq = 1.0 / (rope_theta**(current_indices * 2.0 / head_dim))
            # freqs shape: (seq_len, len(current_indices))
            freqs = jnp.outer(positions[i], inv_freq)
            cos_list.append(jnp.cos(freqs))
            sin_list.append(jnp.sin(freqs))

        # (seq_len, head_dim // 2) -> (seq_len, 1, head_dim // 2) so the
        # angles broadcast over the heads axis.
        cos = jnp.concatenate(cos_list, axis=1)[:, jnp.newaxis, :]
        sin = jnp.concatenate(sin_list, axis=1)[:, jnp.newaxis, :]

        inputs_real = inputs[..., :head_dim // 2]
        inputs_imag = inputs[..., head_dim // 2:head_dim]
        outputs_real = inputs_real * cos - inputs_imag * sin
        outputs_imag = inputs_real * sin + inputs_imag * cos
        out = jnp.concatenate([outputs_real, outputs_imag], axis=-1)

    # Standard RoPE
    else:
        # Inverse frequencies (timescale): theta ** (-2i / head_dim).
        fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
        timescale = 1.0 / (rope_theta**fraction)

        if rope_scaling:
            timescale = apply_rope_scaling(timescale, rope_scaling)

        # Angles of shape (num_tokens, head_dim // 2), then a heads axis is
        # inserted for broadcasting (assumes 1-D positions).
        sinusoid_inp = positions[..., jnp.newaxis] * timescale[jnp.newaxis, :]
        sinusoid_inp = sinusoid_inp[:, jnp.newaxis, ...]
        sin = jnp.sin(sinusoid_inp)
        cos = jnp.cos(sinusoid_inp)

        interleaved = rope_input_ordering == "interleaved"
        if interleaved:
            # Take only the non-padded features, then group adjacent pairs.
            rotary_inputs = inputs[..., :head_dim]
            reshaped_inputs = rotary_inputs.reshape(
                *rotary_inputs.shape[:-1], -1, 2)
            first_half = reshaped_inputs[..., 0]
            second_half = reshaped_inputs[..., 1]
        else:
            first_half = inputs[..., :head_dim // 2]
            second_half = inputs[..., head_dim // 2:head_dim]

        first_part = first_half * cos - second_half * sin
        second_part = second_half * cos + first_half * sin

        # Recombine the rotated parts in the same ordering they came from.
        if interleaved:
            out_stacked = jnp.stack([first_part, second_part], axis=-1)
            out = out_stacked.reshape(rotary_inputs.shape)
        else:
            out = jnp.concatenate([first_part, second_part], axis=-1)

    # If the original input was padded, pad the output with zeros to match.
    padded_head_dim = inputs.shape[-1]
    if padded_head_dim > head_dim:
        pad_width = padded_head_dim - head_dim
        pad_config = [(0, 0)] * (out.ndim - 1) + [(0, pad_width)]
        out = jnp.pad(out, pad_config)

    return out.astype(inputs.dtype)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def apply_longrope(
    inputs: jax.Array,
    positions: jax.Array,
    head_dim: int,
    rope_scaling: Dict[str, Any],
    original_max_position_embeddings: int,
    max_position_embeddings: int,
    rope_theta: float = 10000,
) -> jax.Array:
    """Applies LongRoPE rotary embeddings (the Phi-3 variant).

    Implementation based on
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py#L197-L235

    Per-frequency rescaling factors come from ``rope_scaling``.
    ``"long_factor"`` is used when the effective sequence length
    ``max(positions) + 1`` exceeds ``original_max_position_embeddings``;
    otherwise ``"short_factor"`` is used. The choice is made from
    ``positions``, not from the number of tokens in ``inputs``. In a flattened
    batch the token count is unrelated to sequence length. For example, a
    single decode token can sit at a position far beyond the original window.
    If only one of the two factors is present, it is always used.

    Sin/cos are scaled by ``mscale`` when the context window is extended
    (``max_position_embeddings > original_max_position_embeddings``).
    Features are paired in "split" ordering (feature ``i`` with
    ``i + head_dim // 2``). The output is zero-padded back to
    ``inputs.shape[-1]`` when the input is padded beyond ``head_dim``.

    Args:
        inputs: Array whose first ``head_dim`` features on the last axis are
            rotated; presumably (num_tokens, num_heads, padded_head_dim).
        positions: 1-D array of token positions, one per row of ``inputs``.
        head_dim: Number of unpadded features to rotate.
        rope_scaling: Dict holding ``"short_factor"`` and/or
            ``"long_factor"``, each of length ``head_dim // 2``.
        original_max_position_embeddings: Context length the model was
            pre-trained with.
        max_position_embeddings: Extended context length.
        rope_theta: Base of the geometric frequency progression.

    Returns:
        The rotated array, cast back to ``inputs.dtype``.

    Raises:
        ValueError: If ``rope_scaling`` has neither ``"short_factor"`` nor
            ``"long_factor"``.
    """
    # Magnitude correction applied when the context window is extended.
    scale = max_position_embeddings / original_max_position_embeddings
    if scale <= 1.0:
        mscale = 1.0
    else:
        mscale = math.sqrt(1 + (math.log(scale) /
                                math.log(original_max_position_embeddings)))

    # rope_theta ** (2i / head_dim) for i in [0, head_dim // 2).
    base_timescale = rope_theta**(2 * jnp.arange(0, head_dim // 2) / head_dim)

    short_factor = rope_scaling.get("short_factor")
    long_factor = rope_scaling.get("long_factor")
    if short_factor is None and long_factor is None:
        raise ValueError("LongRoPE requires 'short_factor' and/or "
                         "'long_factor' in rope_scaling.")
    if long_factor is None:
        ext_factor = jnp.array(short_factor)
    elif short_factor is None:
        ext_factor = jnp.array(long_factor)
    else:
        # HF: seq_len = max(position_ids) + 1 > original_max. For integer
        # positions this is `any(positions >= original_max)`, which is also
        # well-defined for an empty batch. jnp.where keeps it jit-traceable.
        use_long = jnp.any(positions >= original_max_position_embeddings)
        ext_factor = jnp.where(use_long, jnp.array(long_factor),
                               jnp.array(short_factor))
    timescale = 1.0 / (ext_factor * base_timescale)

    # Angles of shape (num_tokens, head_dim // 2), then a heads axis is
    # inserted for broadcasting (assumes 1-D positions).
    sinusoid_inp = positions[..., jnp.newaxis] * timescale[jnp.newaxis, :]
    sinusoid_inp = sinusoid_inp[:, jnp.newaxis, ...]
    sin = jnp.sin(sinusoid_inp) * mscale
    cos = jnp.cos(sinusoid_inp) * mscale

    # Apply RoPE mechanism ("split" ordering).
    first_half = inputs[..., :head_dim // 2]
    second_half = inputs[..., head_dim // 2:head_dim]
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    out = jnp.concatenate([first_part, second_part], axis=-1)

    # If the original input was padded, pad the output with zeros to match.
    padded_head_dim = inputs.shape[-1]
    if padded_head_dim > head_dim:
        pad_config = ([(0, 0)] * (out.ndim - 1) +
                      [(0, padded_head_dim - head_dim)])
        out = jnp.pad(out, pad_config)

    return out.astype(inputs.dtype)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def apply_rope_scaling(freqs: jax.Array, rope_scaling: Dict[str,
                                                           Any]) -> jax.Array:
    """Rescales RoPE inverse frequencies band by band according to wavelength.

    This appears to follow the Llama 3.1 frequency-scaling formula. Each
    frequency is classified by its wavelength ``2 * pi / freq``:

    * shorter than ``original_ctx / high_freq_factor``: kept unchanged;
    * longer than ``original_ctx / low_freq_factor``: divided by
      ``scale_factor``;
    * between the two cutoffs (inclusive): a linear blend of the unchanged
      and divided values, weighted by a smoothing coefficient.

    The three band contributions are summed. Elements outside a band
    contribute zero for that band.

    Args:
        freqs: Inverse frequencies (timescales) to rescale.
        rope_scaling: Config dict. Reads ``scale_factor`` (default 8.0),
            ``low_freq_factor`` (1.0), ``high_freq_factor`` (4.0) and
            ``original_max_position_embeddings`` (8192).

    Returns:
        Rescaled frequencies with the same shape as ``freqs``.
    """
    # Values obtained from grid search
    factor = rope_scaling.get("scale_factor", 8.0)
    lo_factor = rope_scaling.get("low_freq_factor", 1.0)
    hi_factor = rope_scaling.get("high_freq_factor", 4.0)
    orig_ctx = rope_scaling.get("original_max_position_embeddings", 8192)

    long_wave_cutoff = orig_ctx / lo_factor
    short_wave_cutoff = orig_ctx / hi_factor

    wave = 2 * math.pi / freqs
    blend = (orig_ctx / wave - lo_factor) / (hi_factor - lo_factor)

    in_high_band = wave < short_wave_cutoff
    in_low_band = wave > long_wave_cutoff
    in_mid_band = (wave >= short_wave_cutoff) & (wave <= long_wave_cutoff)

    kept = jnp.where(in_high_band, freqs, 0)
    shrunk = jnp.where(in_low_band, freqs / factor, 0)
    blended = jnp.where(
        in_mid_band,
        (1 - blend) * freqs / factor + blend * freqs,
        0,
    )
    return kept + shrunk + blended
|
|
File without changes
|