tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import math
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
from typing import Optional, Tuple
|
|
18
|
+
|
|
19
|
+
import jax
|
|
20
|
+
from flax import nnx
|
|
21
|
+
from jax import numpy as jnp
|
|
22
|
+
from jax.experimental.layout import Layout, with_layout_constraint
|
|
23
|
+
from jax.sharding import NamedSharding, PartitionSpec
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(kw_only=True)
class RotaryEmbedding(nnx.Module):
    """Plain rotary positional embedding (RoPE) backed by a precomputed table.

    After ``initialize_cache()`` the attribute ``sin_cos_cache`` holds one row
    per position ``p`` in ``[0, original_max_position_embeddings)``, laid out as
    ``[cos(p * f_j) for j] + [sin(p * f_j) for j]`` where
    ``f_j = rope_theta ** (-2j / rotary_dim)``.

    ``apply_rope`` uses the "split halves" pairing: feature ``i`` is rotated
    together with feature ``i + rotary_dim / 2``. The last dimension of the
    input is therefore expected to equal ``rotary_dim`` (not checked here).
    """
    rotary_dim: int
    rope_theta: float
    original_max_position_embeddings: int
    dtype: jnp.dtype
    sin_cos_cache: Optional[jax.Array] = field(init=False, default=None)

    def initialize_cache(self):
        """Build ``sin_cos_cache`` on the first call; later calls are no-ops."""
        if self.sin_cos_cache is not None:
            return
        self.sin_cos_cache = self._compute_sin_cos()

    def _compute_inv_freq(self):
        # One inverse frequency per rotated pair: theta ** (-2j / rotary_dim).
        exponent_H = jnp.arange(0, self.rotary_dim, 2, dtype=jnp.float32)
        exponent_H = exponent_H / self.rotary_dim
        return 1.0 / (self.rope_theta**exponent_H)

    def _compute_sin_cos(self):
        positions_T = jnp.arange(self.original_max_position_embeddings,
                                 dtype=jnp.float32)
        angles_TH = jnp.einsum("...T,k->...Tk",
                               positions_T,
                               self._compute_inv_freq(),
                               precision=jax.lax.Precision.HIGHEST)
        # Row layout: [cos | sin], each half rotary_dim / 2 wide.
        return jnp.concatenate((jnp.cos(angles_TH), jnp.sin(angles_TH)),
                               axis=-1)

    def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
        """Rotate ``x_TNH`` by the cached angles for ``positions``.

        Args:
            positions: 1-D array of token positions used to gather table rows.
                JAX clamps out-of-range gather indices, so positions beyond the
                table are not rejected here.
            x_TNH: (tokens, heads, features) input to rotate.

        Returns:
            The rotated tensor, cast to ``self.dtype``.
        """
        assert x_TNH.ndim == 3
        assert self.sin_cos_cache is not None, "RoPE cache not initialized."
        # One table row per token, split into the cos and sin halves: (T, H/2).
        cos_TH, sin_TH = jnp.split(self.sin_cos_cache[positions], 2, axis=-1)
        assert sin_TH.ndim == 2 and cos_TH.ndim == 2
        # Add a heads axis so the angles broadcast over N: (T, 1, H/2).
        cos_T1H = cos_TH[:, jnp.newaxis, :]
        sin_T1H = sin_TH[:, jnp.newaxis, :]
        x1_TNH, x2_TNH = jnp.split(x_TNH, 2, axis=-1)
        rotated_TNH = jnp.concatenate(
            [
                x1_TNH * cos_T1H - x2_TNH * sin_T1H,
                x2_TNH * cos_T1H + x1_TNH * sin_T1H,
            ],
            axis=-1,
        )
        return rotated_TNH.astype(self.dtype)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclass(kw_only=True)
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """RoPE with YaRN frequency scaling and mscale, as used for DeepSeek.

    Differences from ``RotaryEmbedding``:
      * Inverse frequencies blend interpolated values (divided by
        ``scaling_factor``) with extrapolated (unscaled) values. The blend is a
        linear ramp between the YaRN correction bounds derived from
        ``beta_fast`` and ``beta_slow``.
      * cos/sin are multiplied by ``mscale``, the ratio of the YaRN mscale for
        ``mscale_value`` to the one for ``mscale_all_dim``.
      * The table spans ``original_max_position_embeddings * scaling_factor``
        positions. It is zero-padded on the feature axis to a multiple of 128
        columns and carries a ``major_to_minor=(1, 0)`` layout constraint.
      * ``apply_rope`` pairs adjacent (even, odd) features instead of halves.
    """
    scaling_factor: float
    beta_fast: int = 32
    beta_slow: int = 1
    mscale_value: float = 1
    mscale_all_dim: float = 0

    def initialize_cache(self, mesh: jax.sharding.Mesh):
        """Compute ``mscale`` (replicated on ``mesh``) and the sin/cos table.

        Returns early only when a concrete cache already exists. Under Qwix this
        method is first invoked on the abstract model, which leaves a
        ``jax.ShapeDtypeStruct`` placeholder. The later call on the concrete
        model must therefore recompute the table.
        """
        cached = self.sin_cos_cache
        if cached is not None and not isinstance(cached,
                                                 jax.ShapeDtypeStruct):
            return
        numerator = _yarn_get_mscale(self.scaling_factor, self.mscale_value)
        denominator = _yarn_get_mscale(self.scaling_factor,
                                       self.mscale_all_dim)
        replicated = NamedSharding(mesh, PartitionSpec())
        self.mscale = jax.device_put(numerator / denominator, replicated)
        self.sin_cos_cache = self._compute_sin_cos()

    def _compute_inv_freq(self):
        # theta ** (2j / rotary_dim) for each rotated pair j.
        base_powers = self.rope_theta**(
            jnp.arange(0, self.rotary_dim, 2, dtype=jnp.float32) /
            self.rotary_dim)
        extrapolated = 1.0 / base_powers
        interpolated = 1.0 / (self.scaling_factor * base_powers)
        low, high = _yarn_find_correction_range(
            self.beta_fast, self.beta_slow, self.rotary_dim, self.rope_theta,
            self.original_max_position_embeddings)
        # Weight of the extrapolated frequency: 1 at/below `low`,
        # 0 at/above `high`, linear in between.
        extrapolation_weight = 1 - _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2).astype(jnp.float32)
        return (interpolated * (1 - extrapolation_weight) +
                extrapolated * extrapolation_weight)

    @jax.jit
    def _compute_sin_cos(self):
        inv_freq_H = self._compute_inv_freq()
        positions_T = jnp.arange(self.original_max_position_embeddings *
                                 self.scaling_factor,
                                 dtype=jnp.float32)
        angles = jnp.einsum("...T,k->...Tk", positions_T, inv_freq_H)
        table = jnp.concatenate(
            (jnp.cos(angles) * self.mscale, jnp.sin(angles) * self.mscale),
            axis=-1)
        # Zero-pad the feature axis up to the next multiple of 128 columns
        # (presumably for TPU lane alignment - confirm).
        width = table.shape[1]
        padded_width = (width + 127) // 128 * 128
        table = jnp.pad(table, ((0, 0), (0, padded_width - self.rotary_dim)),
                        mode='constant')
        return with_layout_constraint(table, Layout(major_to_minor=(1, 0)))

    def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
        """Rotate ``x_TNH`` (tokens, heads, features) with interleaved pairing.

        Feature ``2j`` is rotated together with feature ``2j + 1`` by the angle
        for pair ``j``. The result is cast to ``self.dtype``.
        """
        assert x_TNH.ndim == 3
        assert self.sin_cos_cache is not None, "RoPE cache not initialized."
        # Gather per-token rows and drop the zero padding columns.
        cos_sin_TH = self.sin_cos_cache[positions][:, :self.rotary_dim]
        cos_TH, sin_TH = jnp.split(cos_sin_TH, 2, axis=-1)
        assert sin_TH.ndim == 2 and cos_TH.ndim == 2
        cos_T1H = cos_TH[:, jnp.newaxis, :]
        sin_T1H = sin_TH[:, jnp.newaxis, :]
        x_even = x_TNH[..., 0::2]
        x_odd = x_TNH[..., 1::2]
        pairs = jnp.stack((x_even * cos_T1H - x_odd * sin_T1H,
                           x_odd * cos_T1H + x_even * sin_T1H),
                          axis=-1)
        return pairs.reshape(x_TNH.shape).astype(self.dtype)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# Calculates the temperature scaling factor for YaRN to adjust
|
|
163
|
+
# RoPE embedding magnitudes.
|
|
164
|
+
def _yarn_get_mscale(scale, mscale):
|
|
165
|
+
return jnp.where(scale <= 1, 1.0, 0.1 * mscale * jnp.log(scale) + 1.0)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# Inverses dim formula to find dim based on number of rotations.
|
|
169
|
+
def _yarn_find_correction_dim(num_rotations,
|
|
170
|
+
dim,
|
|
171
|
+
base=10000,
|
|
172
|
+
max_position_embeddings=2048):
|
|
173
|
+
return (dim * math.log(max_position_embeddings /
|
|
174
|
+
(num_rotations * 2 * math.pi))) / (2 *
|
|
175
|
+
math.log(base))
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
# Finds dim range bounds based on rotations.
|
|
179
|
+
def _yarn_find_correction_range(low_rot,
|
|
180
|
+
high_rot,
|
|
181
|
+
dim,
|
|
182
|
+
base=10000,
|
|
183
|
+
max_position_embeddings=2048):
|
|
184
|
+
low = math.floor(
|
|
185
|
+
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
|
186
|
+
high = math.ceil(
|
|
187
|
+
_yarn_find_correction_dim(high_rot, dim, base,
|
|
188
|
+
max_position_embeddings))
|
|
189
|
+
return max(low, 0), min(high, dim - 1) # Clamp values just in case
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
# Creates a 1D mask that ramps linearly from 0 to 1 between min and max indices.
|
|
193
|
+
def _yarn_linear_ramp_mask(min, max, dim):
|
|
194
|
+
if min == max:
|
|
195
|
+
max += 0.001 # Prevent singularity
|
|
196
|
+
|
|
197
|
+
linear_func = (jnp.arange(dim, dtype=jnp.float32) - min) / (max - min)
|
|
198
|
+
ramp_func = jnp.clip(linear_func, 0, 1)
|
|
199
|
+
return ramp_func
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@dataclass(kw_only=True)
class GptOssRotaryEmbedding(nnx.Module):
    """
    JAX implementation of the Rotary Positional Embedding with YaRN scaling.

    cos/sin are computed on the fly for the requested positions (no cached
    table). They are applied with the "split halves" pairing: feature ``i``
    rotates together with feature ``i + head_dim / 2``. When
    ``rope_scaling_factor > 1``, the YaRN "NTK by parts" frequency blend and
    the concentration factor are applied.
    """
    head_dim: int
    rope_theta: float
    dtype: jnp.dtype
    initial_context_length: int = 4096
    rope_scaling_factor: float = 1.0
    rope_ntk_alpha: float = 1.0
    rope_ntk_beta: float = 32.0

    def _compute_concentration_and_inv_freq(self) -> Tuple[float, jax.Array]:
        """
        Computes the inverse frequencies and concentration factor for YaRN.
        See YaRN paper: https://arxiv.org/abs/2309.00071

        Returns:
            ``(concentration, inv_freq)``. ``concentration`` is the Python
            float 1.0 without scaling, and a JAX scalar otherwise.
            ``inv_freq`` has one entry per rotated feature pair.
        """
        freq = self.rope_theta**(
            jnp.arange(0, self.head_dim, 2, dtype=jnp.float32) / self.head_dim)

        if self.rope_scaling_factor > 1.0:
            concentration = 0.1 * jnp.log(self.rope_scaling_factor) + 1.0

            d_half = self.head_dim / 2
            # NTK by parts: pair indices at/below `low` keep the extrapolated
            # (unscaled) frequency, indices at/above `high` use the
            # interpolated one, and indices in between are blended linearly.
            low = (d_half * jnp.log(self.initial_context_length /
                                    (self.rope_ntk_beta * 2 * jnp.pi)) /
                   jnp.log(self.rope_theta))
            high = (d_half * jnp.log(self.initial_context_length /
                                     (self.rope_ntk_alpha * 2 * jnp.pi)) /
                    jnp.log(self.rope_theta))

            interpolation = 1.0 / (self.rope_scaling_factor * freq)
            extrapolation = 1.0 / freq

            # high == low when rope_ntk_alpha == rope_ntk_beta. Dividing by
            # that zero span makes the ramp +/-inf, and NaN wherever the index
            # equals `low`; the NaN would propagate into inv_freq. Nudge the
            # span by 0.001 (the same guard _yarn_linear_ramp_mask uses). For
            # high != low the span is exactly high - low, as before.
            span = jnp.where(high == low, 0.001, high - low)
            ramp = (jnp.arange(d_half, dtype=jnp.float32) - low) / span
            mask = 1 - jnp.clip(ramp, 0, 1)

            inv_freq = interpolation * (1 - mask) + extrapolation * mask
        else:
            concentration = 1.0
            inv_freq = 1.0 / freq

        return concentration, inv_freq

    def _compute_cos_sin(self,
                         positions: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """Computes cosine and sine embeddings for given positions.

        Args:
            positions: 1-D array of token positions (cast to float32).

        Returns:
            ``(cos, sin)``, each of shape (T, H/2) and scaled by the
            concentration factor.
        """
        concentration, inv_freq_H = self._compute_concentration_and_inv_freq()

        # freqs: (T, H/2)
        freqs = jnp.einsum("T,H->TH",
                           positions.astype(jnp.float32),
                           inv_freq_H,
                           precision=jax.lax.Precision.HIGHEST)

        cos = jnp.cos(freqs) * concentration
        sin = jnp.sin(freqs) * concentration
        return cos, sin

    def __call__(self, query_TNH: jax.Array, key_TNH: jax.Array,
                 positions: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """
        Applies rotary embeddings to query and key tensors.
        Args:
            query_TNH: Query tensor with shape (num_tokens, num_heads, head_dim)
            key_TNH: Key tensor with shape (num_tokens, num_kv_heads, head_dim)
            positions: A 1D array of token positions.

        Returns:
            ``(rotated_query, rotated_key)``, both cast to ``self.dtype``.
        """
        # cos, sin: (T, H/2)
        cos_TH, sin_TH = self._compute_cos_sin(positions)

        # Reshape for broadcasting: (T, 1, H/2)
        cos_T1H = cos_TH[:, None, :]
        sin_T1H = sin_TH[:, None, :]

        def _apply_rotation(x_TNH: jax.Array) -> jax.Array:
            # Split the last dimension into the two halves that form pairs.
            first_half, second_half = jnp.split(x_TNH, 2, axis=-1)

            # Apply rotation
            rotated_x = jnp.concatenate([
                first_half * cos_T1H - second_half * sin_T1H,
                second_half * cos_T1H + first_half * sin_T1H
            ],
                                        axis=-1)
            return rotated_x.astype(self.dtype)

        rotated_query = _apply_rotation(query_TNH)
        rotated_key = _apply_rotation(key_TNH)

        return rotated_query, rotated_key
|
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import math
|
|
16
|
+
from typing import Any, Dict
|
|
17
|
+
|
|
18
|
+
import jax
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def apply_rope(
    # (seq_len, num_heads, head_dim)
    inputs: jax.Array,
    # (3, seq_len) for M-RoPE, otherwise (seq_len,)
    positions: jax.Array,
    head_dim: int,
    rope_theta: float = 10000,
    rope_scaling: Dict[str, Any] = None,
    rope_input_ordering: str = "split",
) -> jax.Array:
    """
    Applies Rotary Positional Embedding using the sine and cosine strategy.

    This implementation assumes the input tensor has a shape that might include
    padding on the last dimension (head_dim).
    RoPE is applied only to the first `head_dim` features, and the result is
    padded back to the original dimension if necessary.

    If rope_input_ordering is "split", the input pairs for rotation are taken
    one from the first half and one from the second half of head_dim. If it is
    "interleaved", adjacent values are used as inputs for rotation. Any other
    value is treated as "split". The ordering applies to both standard RoPE and
    M-RoPE.

    Args:
        inputs: (seq_len, num_heads, padded_head_dim) tensor, with
            padded_head_dim >= head_dim.
        positions: (seq_len,) positions for standard RoPE, or (3, seq_len)
            positions for M-RoPE (one row per `mrope_section` entry).
        head_dim: Number of leading features that are rotated.
        rope_theta: RoPE base.
        rope_scaling: Optional config. For M-RoPE it must contain
            "mrope_section". For standard RoPE, a truthy value is forwarded
            to `apply_rope_scaling`.
        rope_input_ordering: "split" or "interleaved" pairing (see above).

    Returns:
        Array with the same shape and dtype as `inputs`. Any padding features
        beyond `head_dim` are zero.
    """
    # Both branches only build cos/sin of shape (seq_len, 1, head_dim // 2);
    # the rotation and the pairing order are shared below.

    # M-RoPE support for Qwen2.5-VL
    if positions.ndim == 2 and positions.shape[0] == 3:
        mrope_section = rope_scaling.get("mrope_section",
                                         None) if rope_scaling else None
        # NOTE: We assume mrope_section is always available
        # as Qwen2.5-VL is the only model using mrope
        assert mrope_section is not None

        split_indices = [mrope_section[0], mrope_section[0] + mrope_section[1]]

        # Split the rotated-pair frequency indices into one group per position
        # row. This is valid because split_indices are static.
        freq_indices_split = jnp.split(jnp.arange(head_dim // 2),
                                       split_indices)

        cos_list = []
        sin_list = []
        for i, current_indices in enumerate(freq_indices_split):
            if current_indices.size == 0:
                # This section is empty, skip.
                continue
            # inv_freq shape: (mrope_section[i],)
            inv_freq = 1.0 / (rope_theta**(current_indices * 2.0 / head_dim))
            # freqs shape: (seq_len, mrope_section[i])
            freqs = jnp.outer(positions[i], inv_freq)
            cos_list.append(jnp.cos(freqs))
            sin_list.append(jnp.sin(freqs))

        # Concatenate along the feature axis and add a heads axis:
        # (seq_len, 1, head_dim // 2).
        cos = jnp.concatenate(cos_list, axis=1)[:, jnp.newaxis, :]
        sin = jnp.concatenate(sin_list, axis=1)[:, jnp.newaxis, :]

    # Standard RoPE
    else:
        # Calculate inverse frequencies (timescale)
        fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
        timescale = 1.0 / (rope_theta**fraction)

        # Apply scaling if provided
        if rope_scaling:
            timescale = apply_rope_scaling(timescale, rope_scaling)

        # (seq_len, head_dim // 2), then broadcast over heads.
        sinusoid_inp = positions[..., jnp.newaxis] * timescale[jnp.newaxis, :]
        sinusoid_inp = sinusoid_inp[:, jnp.newaxis, ...]
        sin = jnp.sin(sinusoid_inp)
        cos = jnp.cos(sinusoid_inp)

    # Only the first head_dim features are rotated; any padding is dropped.
    rotary_inputs = inputs[..., :head_dim]
    if rope_input_ordering == "interleaved":
        # Group adjacent features (2j, 2j + 1) into rotation pairs.
        reshaped_inputs = rotary_inputs.reshape(*rotary_inputs.shape[:-1],
                                                -1, 2)
        first_half = reshaped_inputs[..., 0]
        second_half = reshaped_inputs[..., 1]
    else:
        first_half = inputs[..., :head_dim // 2]
        second_half = inputs[..., head_dim // 2:head_dim]

    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin

    # Combine the rotated parts back into the original pairing layout.
    if rope_input_ordering == "interleaved":
        out_stacked = jnp.stack([first_part, second_part], axis=-1)
        out = out_stacked.reshape(rotary_inputs.shape)
    else:
        out = jnp.concatenate([first_part, second_part], axis=-1)

    # If the original input was padded, pad the output with zeros to match.
    padded_head_dim = inputs.shape[-1]
    if padded_head_dim > head_dim:
        pad_width = padded_head_dim - head_dim
        pad_config = [(0, 0)] * (out.ndim - 1) + [(0, pad_width)]
        out = jnp.pad(out, pad_config)

    return out.astype(inputs.dtype)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def apply_longrope(
    inputs: jax.Array,
    positions: jax.Array,
    head_dim: int,
    rope_scaling: Dict[str, Any],
    original_max_position_embeddings: int,
    max_position_embeddings: int,
    rope_theta: float = 10000,
) -> jax.Array:
    """Apply Phi-3 style LongRoPE rotary embeddings to ``inputs``.

    Implementation based on
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py#L197-L235

    The per-frequency extension factors come from ``rope_scaling["long_factor"]``
    when ``max(positions) + 1 > original_max_position_embeddings`` and from
    ``rope_scaling["short_factor"]`` otherwise. The decision is based on the
    position values, not on the number of tokens in ``inputs``: the token
    count of a (possibly padded) batch says nothing about how far into a
    sequence those tokens sit.

    Args:
        inputs: Array whose first axis indexes tokens and whose last axis is
            the (possibly padded) head dimension. The sin/cos tables are
            broadcast over axis 1, so presumably (tokens, heads,
            padded_head_dim) -- confirm with callers.
        positions: Integer positions, assumed 1-D with one entry per token
            of ``inputs``.
        head_dim: Number of leading features on the last axis to rotate.
            Any features beyond ``head_dim`` are treated as padding and are
            zero in the output.
        rope_scaling: Mapping providing ``"short_factor"`` and
            ``"long_factor"``, each a sequence of ``head_dim // 2`` floats.
        original_max_position_embeddings: Context length the model was
            pre-trained with.
        max_position_embeddings: Extended context length.
        rope_theta: RoPE base.

    Returns:
        Array with the same shape and dtype as ``inputs``.

    Raises:
        ValueError: If either ``"short_factor"`` or ``"long_factor"`` is
            missing from ``rope_scaling``.
    """
    # Magnitude correction for the extended context. This is static config
    # math, so plain Python floats are enough.
    scale = max_position_embeddings / original_max_position_embeddings
    if scale <= 1.0:
        mscale = 1.0
    else:
        mscale = math.sqrt(1 + math.log(scale) /
                           math.log(original_max_position_embeddings))

    short_factor = rope_scaling.get("short_factor")
    long_factor = rope_scaling.get("long_factor")
    if short_factor is None or long_factor is None:
        raise ValueError("LongRoPE requires both 'short_factor' and "
                         "'long_factor' in rope_scaling.")

    # Mirror the HF reference (seq_len = max(position_ids) + 1). `initial=0`
    # keeps this defined for an empty batch, and jnp.where keeps the choice
    # traceable when `positions` is a traced value.
    use_long = (jnp.max(positions, initial=0) +
                1) > original_max_position_embeddings
    ext_factor = jnp.where(use_long, jnp.asarray(long_factor),
                           jnp.asarray(short_factor))

    base = rope_theta**((2 * jnp.arange(0, head_dim // 2)) / head_dim)
    timescale = 1.0 / (ext_factor * base)

    # (tokens, head_dim/2) -> (tokens, 1, head_dim/2) so the tables
    # broadcast over axis 1 of `inputs`.
    sinusoid_inp = positions[..., jnp.newaxis] * timescale[jnp.newaxis, :]
    sinusoid_inp = sinusoid_inp[:, jnp.newaxis, ...]
    sin = jnp.sin(sinusoid_inp) * mscale
    cos = jnp.cos(sinusoid_inp) * mscale

    # Rotate the first `head_dim` features using the split-half layout.
    first_half = inputs[..., :head_dim // 2]
    second_half = inputs[..., head_dim // 2:head_dim]
    first_part = first_half * cos - second_half * sin
    second_part = second_half * cos + first_half * sin
    out = jnp.concatenate([first_part, second_part], axis=-1)

    # Zero-fill any padding so the last axis matches the input again,
    # whatever the rank of `out`.
    padded_head_dim = inputs.shape[-1]
    if padded_head_dim > head_dim:
        pad_config = [(0, 0)] * (out.ndim - 1) + [
            (0, padded_head_dim - head_dim)
        ]
        out = jnp.pad(out, pad_config)

    return out.astype(inputs.dtype)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def apply_rope_scaling(freqs: jax.Array, rope_scaling: Dict[str,
                                                            Any]) -> jax.Array:
    """Rescale RoPE inverse frequencies with the Llama 3 piecewise scheme.

    Let ``L = original_max_position_embeddings`` and ``wavelen = 2*pi/freqs``.
    The frequencies are rescaled piecewise by wavelength:

    * ``wavelen > L / low_freq_factor``: divided by the scale factor.
    * ``wavelen < L / high_freq_factor``: kept unchanged.
    * Otherwise: blended linearly between the two, with weight
      ``smooth = (L / wavelen - low_freq_factor) /
      (high_freq_factor - low_freq_factor)``.

    Args:
        freqs: Inverse RoPE frequencies to rescale.
        rope_scaling: Optional keys ``"scale_factor"`` (falls back to the HF
            config key ``"factor"``, then 8.0), ``"low_freq_factor"``
            (default 1.0), ``"high_freq_factor"`` (default 4.0) and
            ``"original_max_position_embeddings"`` (default 8192).

    Returns:
        Array of rescaled frequencies with the same shape as ``freqs``.
    """
    # Values obtained from grid search.
    # "scale_factor" keeps precedence for existing callers; HF Llama 3
    # configs name the same value "factor" (e.g. 32.0 for Llama 3.2), which
    # used to be silently ignored in favour of the 8.0 default.
    scale_factor = rope_scaling.get("scale_factor",
                                    rope_scaling.get("factor", 8.0))
    low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
    high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
    old_context_len = rope_scaling.get("original_max_position_embeddings",
                                       8192)

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / freqs
    smooth = (old_context_len / wavelen -
              low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * freqs / scale_factor + smooth * freqs

    # Long wavelengths are scaled down, short ones are kept, and anything in
    # [high_freq_wavelen, low_freq_wavelen] gets the blended value.
    return jnp.where(
        wavelen > low_freq_wavelen,
        freqs / scale_factor,
        jnp.where(wavelen < high_freq_wavelen, freqs, smoothed),
    )
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|