telefuser 0.1.0.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- telefuser/__init__.py +6 -0
- telefuser/_logo.py +12 -0
- telefuser/_version.py +24 -0
- telefuser/cache/__init__.py +5 -0
- telefuser/cache/kv_cache.py +438 -0
- telefuser/cache_mem/__init__.py +27 -0
- telefuser/cache_mem/cache_types.py +40 -0
- telefuser/cache_mem/config.py +83 -0
- telefuser/cache_mem/connection.py +197 -0
- telefuser/cache_mem/encoders.py +398 -0
- telefuser/cache_mem/encoding/__init__.py +0 -0
- telefuser/cache_mem/encoding/interfaces.py +27 -0
- telefuser/cache_mem/latent_cache.py +213 -0
- telefuser/cache_mem/log_monitor.py +77 -0
- telefuser/cache_mem/metadata.py +268 -0
- telefuser/cache_mem/src/__init__.py +0 -0
- telefuser/cache_mem/src/models/__init__.py +0 -0
- telefuser/cache_mem/src/models/qwen3_vl_embedding.py +346 -0
- telefuser/cache_mem/src/models/qwen3_vl_reranker.py +437 -0
- telefuser/cache_mem/state/__init__.py +0 -0
- telefuser/cache_mem/state/interfaces.py +67 -0
- telefuser/cache_mem/storage/__init__.py +11 -0
- telefuser/cache_mem/storage/fluxon.py +24 -0
- telefuser/cache_mem/storage/interfaces.py +25 -0
- telefuser/cache_mem/storage/local_file.py +112 -0
- telefuser/cache_mem/storage/memory.py +24 -0
- telefuser/cache_mem/strategies.py +819 -0
- telefuser/cache_mem/vector_store/__init__.py +5 -0
- telefuser/cache_mem/vector_store/faiss.py +298 -0
- telefuser/cache_mem/vector_store/interfaces.py +42 -0
- telefuser/cache_mem/vector_store/qdrant.py +46 -0
- telefuser/client/__init__.py +34 -0
- telefuser/client/openai/__init__.py +34 -0
- telefuser/client/openai/client.py +146 -0
- telefuser/client/openai/images.py +221 -0
- telefuser/client/openai/videos.py +307 -0
- telefuser/client/tf_client.py +1016 -0
- telefuser/core/__init__.py +37 -0
- telefuser/core/base_model.py +262 -0
- telefuser/core/base_pipeline.py +421 -0
- telefuser/core/base_stage.py +169 -0
- telefuser/core/config.py +409 -0
- telefuser/core/config_serializer.py +54 -0
- telefuser/core/model_registry.py +108 -0
- telefuser/core/module_manager.py +412 -0
- telefuser/distributed/__init__.py +95 -0
- telefuser/distributed/device_mesh.py +347 -0
- telefuser/distributed/fsdp.py +143 -0
- telefuser/distributed/parallel_shard.py +250 -0
- telefuser/distributed/pp_comm.py +306 -0
- telefuser/distributed/ring.py +357 -0
- telefuser/distributed/tp_parallelize.py +63 -0
- telefuser/distributed/ulysses_comm.py +250 -0
- telefuser/entrypoints/__init__.py +3 -0
- telefuser/entrypoints/cli/main.py +257 -0
- telefuser/feature_cache/__init__.py +47 -0
- telefuser/feature_cache/ada_taylor_cache/__init__.py +28 -0
- telefuser/feature_cache/ada_taylor_cache/ada_taylor_cache.py +656 -0
- telefuser/feature_cache/ada_taylor_cache/params/HunyuanVideo15-I2V-480P.json +111 -0
- telefuser/feature_cache/ada_taylor_cache/params/HunyuanVideo15-T2V-480P.json +111 -0
- telefuser/feature_cache/ada_taylor_cache/params/Qwen-Image-2512.json +111 -0
- telefuser/feature_cache/ada_taylor_cache/params/Wan2_1-FL2V-14B-720P.json +89 -0
- telefuser/feature_cache/ada_taylor_cache/params/Wan2_1-I2V-14B-480P.json +89 -0
- telefuser/feature_cache/ada_taylor_cache/params/Wan2_1-I2V-14B-720P.json +89 -0
- telefuser/feature_cache/ada_taylor_cache/params/Wan2_1-T2V-14B.json +109 -0
- telefuser/feature_cache/ada_taylor_cache/params/Wan2_1-T2V-1_3B.json +109 -0
- telefuser/feature_cache/ada_taylor_cache/params/Wan2_2-FL2V-A14B.json +89 -0
- telefuser/feature_cache/ada_taylor_cache/params/Wan2_2-I2V-A14B-Camera.json +109 -0
- telefuser/feature_cache/ada_taylor_cache/params/Wan2_2-I2V-A14B.json +89 -0
- telefuser/feature_cache/ada_taylor_cache/params/Wan2_2-T2V-A14B.json +89 -0
- telefuser/feature_cache/base.py +150 -0
- telefuser/kernel/__init__.py +55 -0
- telefuser/kernel/triton/__init__.py +43 -0
- telefuser/kernel/triton/merge_attn_states.py +115 -0
- telefuser/kernel/triton/norm.py +816 -0
- telefuser/kernel/triton/quant.py +147 -0
- telefuser/kernel/triton/quant_per_block.py +154 -0
- telefuser/kernel/triton/rotary.py +162 -0
- telefuser/kernel/triton/scale_shift.py +1064 -0
- telefuser/kernel/triton/sparse_int8_attn.py +280 -0
- telefuser/metrics/__init__.py +101 -0
- telefuser/metrics/collector.py +442 -0
- telefuser/metrics/config.py +111 -0
- telefuser/metrics/exporters.py +113 -0
- telefuser/metrics/registry.py +485 -0
- telefuser/metrics/service_metrics.py +436 -0
- telefuser/metrics/stage_metrics.py +207 -0
- telefuser/models/TCDecoder.py +352 -0
- telefuser/models/__init__.py +24 -0
- telefuser/models/flashvsr_dit.py +608 -0
- telefuser/models/flux2_dit.py +1126 -0
- telefuser/models/hunyuan_video_byt5.py +433 -0
- telefuser/models/hunyuan_video_dit.py +2124 -0
- telefuser/models/hunyuan_video_image_encoder.py +222 -0
- telefuser/models/hunyuan_video_text_encoder.py +461 -0
- telefuser/models/hunyuan_video_upsampler.py +320 -0
- telefuser/models/hunyuan_video_vae.py +850 -0
- telefuser/models/lingbot_world_fast_dit.py +573 -0
- telefuser/models/liveact_dit.py +1213 -0
- telefuser/models/longcat_video_dit.py +1214 -0
- telefuser/models/ltx_audio_vae.py +1183 -0
- telefuser/models/ltx_dit.py +2202 -0
- telefuser/models/ltx_gemma_text_encoder.py +1004 -0
- telefuser/models/ltx_upsampler.py +416 -0
- telefuser/models/ltx_video_vae.py +2668 -0
- telefuser/models/qwen_image_dit.py +780 -0
- telefuser/models/qwen_image_text_encoder.py +196 -0
- telefuser/models/qwen_image_vae.py +643 -0
- telefuser/models/realesrgan.py +356 -0
- telefuser/models/rift_hdv3.py +353 -0
- telefuser/models/t5_tokenizer.py +96 -0
- telefuser/models/video_projector.py +457 -0
- telefuser/models/wan22_video_vae.py +1548 -0
- telefuser/models/wan_video_dit.py +1586 -0
- telefuser/models/wan_video_image_encoder.py +534 -0
- telefuser/models/wan_video_text_encoder.py +317 -0
- telefuser/models/wan_video_vae.py +1519 -0
- telefuser/models/wav2vec2.py +154 -0
- telefuser/models/xlm_roberta.py +157 -0
- telefuser/models/z_image_dit.py +695 -0
- telefuser/models/z_image_text_encoder.py +81 -0
- telefuser/offload/__init__.py +26 -0
- telefuser/offload/async_offload.py +417 -0
- telefuser/offload/model_offload.py +35 -0
- telefuser/offload/sequential_offload.py +318 -0
- telefuser/ops/__init__.py +33 -0
- telefuser/ops/activations.py +187 -0
- telefuser/ops/attention/__init__.py +29 -0
- telefuser/ops/attention/attention_impl.py +529 -0
- telefuser/ops/attention/backends.py +209 -0
- telefuser/ops/attention/bsa.py +250 -0
- telefuser/ops/attention/local_sparse_attn.py +547 -0
- telefuser/ops/attention/sparse_patterns.py +622 -0
- telefuser/ops/attention/sparse_sage.py +80 -0
- telefuser/ops/base.py +145 -0
- telefuser/ops/custom_op.py +121 -0
- telefuser/ops/ffn.py +69 -0
- telefuser/ops/fp8_gemm.py +348 -0
- telefuser/ops/normalization.py +274 -0
- telefuser/ops/quantized_linear.py +164 -0
- telefuser/ops/rotary.py +138 -0
- telefuser/orchestrator/__init__.py +22 -0
- telefuser/orchestrator/artifact_save_stage.py +119 -0
- telefuser/orchestrator/pipeline_orchestrator.py +358 -0
- telefuser/orchestrator/stage_wrapper.py +276 -0
- telefuser/pipelines/__init__.py +9 -0
- telefuser/pipelines/common/realesrgan_upscale.py +92 -0
- telefuser/pipelines/common/rift_vfi.py +54 -0
- telefuser/pipelines/flashvsr/__init__.py +4 -0
- telefuser/pipelines/flashvsr/dit_denoising.py +312 -0
- telefuser/pipelines/flashvsr/flashvsr_stream.py +197 -0
- telefuser/pipelines/flashvsr/vae.py +57 -0
- telefuser/pipelines/flux2_klein/__init__.py +5 -0
- telefuser/pipelines/flux2_klein/dit_denoising.py +329 -0
- telefuser/pipelines/flux2_klein/pipeline.py +427 -0
- telefuser/pipelines/flux2_klein/text_encoding.py +201 -0
- telefuser/pipelines/flux2_klein/vae.py +215 -0
- telefuser/pipelines/hunyuan_video_1_5/__init__.py +55 -0
- telefuser/pipelines/hunyuan_video_1_5/dit_denoising.py +270 -0
- telefuser/pipelines/hunyuan_video_1_5/image_encoding.py +87 -0
- telefuser/pipelines/hunyuan_video_1_5/pipeline.py +324 -0
- telefuser/pipelines/hunyuan_video_1_5/sr_dit_denoising.py +363 -0
- telefuser/pipelines/hunyuan_video_1_5/text_encoding.py +291 -0
- telefuser/pipelines/hunyuan_video_1_5/upsampler.py +95 -0
- telefuser/pipelines/hunyuan_video_1_5/vae.py +133 -0
- telefuser/pipelines/lingbot_world_fast/__init__.py +31 -0
- telefuser/pipelines/lingbot_world_fast/control.py +208 -0
- telefuser/pipelines/lingbot_world_fast/denoising.py +85 -0
- telefuser/pipelines/lingbot_world_fast/pipeline.py +592 -0
- telefuser/pipelines/lingbot_world_fast/service.py +483 -0
- telefuser/pipelines/lingbot_world_fast/session.py +76 -0
- telefuser/pipelines/liveact/__init__.py +16 -0
- telefuser/pipelines/liveact/audio_encoding.py +365 -0
- telefuser/pipelines/liveact/denoising.py +306 -0
- telefuser/pipelines/liveact/pipeline.py +337 -0
- telefuser/pipelines/longcat_video/__init__.py +12 -0
- telefuser/pipelines/longcat_video/dit_denoising.py +297 -0
- telefuser/pipelines/longcat_video/longcat_video.py +542 -0
- telefuser/pipelines/longcat_video/refine_denoise.py +235 -0
- telefuser/pipelines/longcat_video/text_encoding.py +118 -0
- telefuser/pipelines/ltx_video/__init__.py +1 -0
- telefuser/pipelines/ltx_video/dit_denoising.py +1010 -0
- telefuser/pipelines/ltx_video/gemma_text_encoding.py +165 -0
- telefuser/pipelines/ltx_video/ltx23_video.py +518 -0
- telefuser/pipelines/ltx_video/upsampler.py +29 -0
- telefuser/pipelines/ltx_video/vae.py +195 -0
- telefuser/pipelines/qwen_image/__init__.py +11 -0
- telefuser/pipelines/qwen_image/dit_denoising.py +228 -0
- telefuser/pipelines/qwen_image/qwen_image.py +301 -0
- telefuser/pipelines/qwen_image/qwen_image_edit.py +209 -0
- telefuser/pipelines/qwen_image/text_encoding.py +223 -0
- telefuser/pipelines/qwen_image/vae.py +91 -0
- telefuser/pipelines/wan_video/__init__.py +6 -0
- telefuser/pipelines/wan_video/async_wan22_video.py +467 -0
- telefuser/pipelines/wan_video/clip_encoding.py +54 -0
- telefuser/pipelines/wan_video/latent_data_utils.py +53 -0
- telefuser/pipelines/wan_video/moe_dit_denoising.py +409 -0
- telefuser/pipelines/wan_video/single_dit_denoising.py +262 -0
- telefuser/pipelines/wan_video/text_encoding.py +58 -0
- telefuser/pipelines/wan_video/ti2v_denoising.py +396 -0
- telefuser/pipelines/wan_video/vae.py +237 -0
- telefuser/pipelines/wan_video/wan21_video.py +353 -0
- telefuser/pipelines/wan_video/wan22_ti2v.py +372 -0
- telefuser/pipelines/wan_video/wan22_video.py +318 -0
- telefuser/pipelines/z_image/__init__.py +8 -0
- telefuser/pipelines/z_image/dit_denoising.py +281 -0
- telefuser/pipelines/z_image/text_encoding.py +117 -0
- telefuser/pipelines/z_image/vae.py +49 -0
- telefuser/pipelines/z_image/z_image.py +139 -0
- telefuser/platforms/__init__.py +80 -0
- telefuser/platforms/cpu.py +30 -0
- telefuser/platforms/cuda.py +86 -0
- telefuser/platforms/interface.py +99 -0
- telefuser/platforms/npu.py +78 -0
- telefuser/platforms/rocm.py +72 -0
- telefuser/schedulers/__init__.py +13 -0
- telefuser/schedulers/flow_match.py +377 -0
- telefuser/schedulers/flow_match_discrete.py +325 -0
- telefuser/schedulers/lcm.py +81 -0
- telefuser/schedulers/unipc.py +697 -0
- telefuser/service/__init__.py +38 -0
- telefuser/service/api/__init__.py +41 -0
- telefuser/service/api/api_server.py +327 -0
- telefuser/service/api/middleware.py +334 -0
- telefuser/service/api/openai/__init__.py +58 -0
- telefuser/service/api/openai/adapter.py +423 -0
- telefuser/service/api/openai/image_routes.py +417 -0
- telefuser/service/api/openai/protocol.py +287 -0
- telefuser/service/api/openai/video_routes.py +416 -0
- telefuser/service/api/routers/__init__.py +20 -0
- telefuser/service/api/routers/files.py +65 -0
- telefuser/service/api/routers/service.py +143 -0
- telefuser/service/api/routers/stream.py +88 -0
- telefuser/service/api/routers/tasks.py +432 -0
- telefuser/service/api/routers/webrtc.py +164 -0
- telefuser/service/api/schema.py +80 -0
- telefuser/service/api/stream_schema.py +76 -0
- telefuser/service/api/task_contract_runtime.py +148 -0
- telefuser/service/api/utils.py +139 -0
- telefuser/service/cache/__init__.py +4 -0
- telefuser/service/cache/cache_factory.py +176 -0
- telefuser/service/cache/cache_service.py +389 -0
- telefuser/service/core/__init__.py +30 -0
- telefuser/service/core/config.py +249 -0
- telefuser/service/core/container.py +264 -0
- telefuser/service/core/contract_templates.py +147 -0
- telefuser/service/core/file_service.py +269 -0
- telefuser/service/core/pipeline_contract.py +339 -0
- telefuser/service/core/pipeline_loader.py +94 -0
- telefuser/service/core/pipeline_pool.py +280 -0
- telefuser/service/core/pipeline_runner.py +205 -0
- telefuser/service/core/pipeline_service.py +311 -0
- telefuser/service/core/replica_worker.py +298 -0
- telefuser/service/core/stream_pipeline_service.py +261 -0
- telefuser/service/core/task_manager.py +416 -0
- telefuser/service/core/task_processor.py +162 -0
- telefuser/service/core/task_service.py +156 -0
- telefuser/service/main.py +99 -0
- telefuser/service/media/__init__.py +17 -0
- telefuser/service/media/media_base.py +298 -0
- telefuser/service/security/__init__.py +32 -0
- telefuser/service/security/security_validator.py +797 -0
- telefuser/service/webrtc/__init__.py +32 -0
- telefuser/service/webrtc/chunk_router.py +111 -0
- telefuser/service/webrtc/session_manager.py +322 -0
- telefuser/service/webrtc/track.py +307 -0
- telefuser/service_types.py +83 -0
- telefuser/utils/__init__.py +25 -0
- telefuser/utils/audio.py +51 -0
- telefuser/utils/func.py +31 -0
- telefuser/utils/hf_model_analyzer.py +382 -0
- telefuser/utils/hf_model_utils.py +209 -0
- telefuser/utils/hf_utils.py +256 -0
- telefuser/utils/logging.py +749 -0
- telefuser/utils/lora_loader.py +295 -0
- telefuser/utils/lora_network.py +212 -0
- telefuser/utils/memory_snapshot.py +423 -0
- telefuser/utils/model_weight.py +163 -0
- telefuser/utils/profiler.py +1079 -0
- telefuser/utils/stage_bench_harness.py +740 -0
- telefuser/utils/system.py +228 -0
- telefuser/utils/torch_compile.py +83 -0
- telefuser/utils/utils.py +49 -0
- telefuser/utils/video.py +464 -0
- telefuser/worker/__init__.py +18 -0
- telefuser/worker/native_worker.py +125 -0
- telefuser/worker/parallel_worker.py +292 -0
- telefuser/worker/ray_worker.py +107 -0
- telefuser-0.1.0.post3.dist-info/METADATA +379 -0
- telefuser-0.1.0.post3.dist-info/RECORD +296 -0
- telefuser-0.1.0.post3.dist-info/WHEEL +5 -0
- telefuser-0.1.0.post3.dist-info/entry_points.txt +2 -0
- telefuser-0.1.0.post3.dist-info/licenses/LICENSE +201 -0
- telefuser-0.1.0.post3.dist-info/scm_file_list.json +763 -0
- telefuser-0.1.0.post3.dist-info/scm_version.json +8 -0
- telefuser-0.1.0.post3.dist-info/top_level.txt +1 -0
telefuser/__init__.py
ADDED
telefuser/_logo.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Lightweight package branding constants."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
# Multi-line ASCII-art wordmark; a raw string so no escape sequences are processed.
TELEFUSER_LOGO = r"""
|
|
6
|
+
████████╗███████╗██╗ ███████╗███████╗██╗ ██╗███████╗███████╗█████████╗
|
|
7
|
+
╚══██╔══╝██╔════╝██║ ██╔════╝██╔════╝██║ ██║██╔════╝██╔════╝██╔════██║
|
|
8
|
+
██║ █████╗ ██║ █████╗ █████╗ ██║ ██║███████╗█████╗ ███████╔═╝
|
|
9
|
+
██║ ██╔══╝ ██║ ██╔══╝ ██╔══╝ ██║ ██║╚════██║██╔══╝ ██╔══██║
|
|
10
|
+
██║ ███████╗███████╗███████╗██║ ╚██████╔╝███████║███████╗██║ ████╗
|
|
11
|
+
╚═╝ ╚══════╝╚══════╝ ╚═════╝╚═╝ ╚═════╝ ╚══════╝╚══════╝╚═╝ ╚═══╝
|
|
12
|
+
"""
|
telefuser/_version.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# file generated by vcs-versioning
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
# Names exported by this generated version module.
__all__ = [
|
|
6
|
+
"__version__",
|
|
7
|
+
"__version_tuple__",
|
|
8
|
+
"version",
|
|
9
|
+
"version_tuple",
|
|
10
|
+
"__commit_id__",
|
|
11
|
+
"commit_id",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
# Declared types of the module-level version attributes.
version: str
|
|
15
|
+
__version__: str
|
|
16
|
+
__version_tuple__: tuple[int | str, ...]
|
|
17
|
+
version_tuple: tuple[int | str, ...]
|
|
18
|
+
commit_id: str | None
|
|
19
|
+
__commit_id__: str | None
|
|
20
|
+
|
|
21
|
+
# Package version string and its tuple form; the "post3" suffix stays a str element.
__version__ = version = '0.1.0.post3'
|
|
22
|
+
__version_tuple__ = version_tuple = (0, 1, 0, 'post3')
|
|
23
|
+
|
|
24
|
+
# Commit identifier recorded by vcs-versioning.
__commit_id__ = commit_id = 'gca0bc08c7'
|
|
@@ -0,0 +1,438 @@
|
|
|
1
|
+
"""KVCache module for LiveAct - List-based structure for torch.compile optimization."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class KVCacheConfig:
    """Settings that control how KV caches are stored.

    Attributes:
        fp8_kv_cache: When True, K/V are stored as ``float8_e4m3fn`` together
            with a float32 scale per last-dimension row, trading precision for
            memory.
        offload_cache: When True, cached tensors are kept in CPU memory
            between uses.
        cache_frames: Number of frames' worth of tokens each cache slot holds
            (default 6).
    """

    # Quantize K/V to FP8 when storing.
    fp8_kv_cache: bool = False
    # Park cached tensors on the CPU between uses.
    offload_cache: bool = False
    # Slot length in frames; multiplied by tokens_per_frame when allocating.
    cache_frames: int = 6
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class KVCache:
|
|
27
|
+
"""Minimal KV cache for a single (timestep, layer) entry.
|
|
28
|
+
|
|
29
|
+
Preserves exact original behavior:
|
|
30
|
+
- Direct dict access (k, v, k_scale, v_scale)
|
|
31
|
+
- FP8 quantization support
|
|
32
|
+
- CPU offload support
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
fp8_kv_cache: bool = False,
|
|
38
|
+
offload_cache: bool = False,
|
|
39
|
+
):
|
|
40
|
+
"""Initialize KV cache.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
fp8_kv_cache: Enable FP8 quantization
|
|
44
|
+
offload_cache: Enable CPU offload
|
|
45
|
+
"""
|
|
46
|
+
self.fp8_kv_cache = fp8_kv_cache
|
|
47
|
+
self.offload_cache = offload_cache
|
|
48
|
+
|
|
49
|
+
# Storage tensors
|
|
50
|
+
self.k: torch.Tensor | None = None
|
|
51
|
+
self.v: torch.Tensor | None = None
|
|
52
|
+
self.k_scale: torch.Tensor | None = None
|
|
53
|
+
self.v_scale: torch.Tensor | None = None
|
|
54
|
+
|
|
55
|
+
def allocate(self, shape: tuple[int, ...], dtype: torch.dtype, device: str | torch.device) -> None:
|
|
56
|
+
"""Allocate cache tensors.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
shape: Shape of K/V tensor [batch, seq, heads, head_dim]
|
|
60
|
+
dtype: Storage dtype (bf16 or fp8)
|
|
61
|
+
device: Device for storage
|
|
62
|
+
"""
|
|
63
|
+
storage_dtype = torch.float8_e4m3fn if self.fp8_kv_cache else dtype
|
|
64
|
+
self.k = torch.zeros(shape, dtype=storage_dtype, device=device)
|
|
65
|
+
self.v = torch.zeros(shape, dtype=storage_dtype, device=device)
|
|
66
|
+
|
|
67
|
+
if self.fp8_kv_cache:
|
|
68
|
+
# Scale shape: [batch, seq, heads, 1]
|
|
69
|
+
scale_shape = (shape[0], shape[1], shape[2], 1)
|
|
70
|
+
self.k_scale = torch.ones(scale_shape, dtype=torch.float32, device=device)
|
|
71
|
+
self.v_scale = torch.ones(scale_shape, dtype=torch.float32, device=device)
|
|
72
|
+
|
|
73
|
+
def clear(self) -> None:
|
|
74
|
+
"""Clear cache."""
|
|
75
|
+
self.k = None
|
|
76
|
+
self.v = None
|
|
77
|
+
self.k_scale = None
|
|
78
|
+
self.v_scale = None
|
|
79
|
+
|
|
80
|
+
def load(self, device: str | torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
|
|
81
|
+
"""Load K/V tensors to compute device.
|
|
82
|
+
|
|
83
|
+
Args:
|
|
84
|
+
device: Target device
|
|
85
|
+
dtype: Target dtype
|
|
86
|
+
|
|
87
|
+
Returns:
|
|
88
|
+
(k, v) tensors on device
|
|
89
|
+
"""
|
|
90
|
+
# Move to device if offloaded
|
|
91
|
+
if self.offload_cache:
|
|
92
|
+
self._move_to_device(device)
|
|
93
|
+
|
|
94
|
+
# Dequantize if FP8
|
|
95
|
+
if self.fp8_kv_cache:
|
|
96
|
+
k = self._dequantize(self.k, self.k_scale, dtype)
|
|
97
|
+
v = self._dequantize(self.v, self.v_scale, dtype)
|
|
98
|
+
else:
|
|
99
|
+
if self.k.dtype != dtype:
|
|
100
|
+
self.k = self.k.to(dtype=dtype)
|
|
101
|
+
if self.v.dtype != dtype:
|
|
102
|
+
self.v = self.v.to(dtype=dtype)
|
|
103
|
+
k = self.k
|
|
104
|
+
v = self.v
|
|
105
|
+
|
|
106
|
+
return k, v
|
|
107
|
+
|
|
108
|
+
def store(self, k: torch.Tensor, v: torch.Tensor) -> None:
|
|
109
|
+
"""Store K/V tensors.
|
|
110
|
+
|
|
111
|
+
Args:
|
|
112
|
+
k: Key tensor
|
|
113
|
+
v: Value tensor
|
|
114
|
+
"""
|
|
115
|
+
if self.fp8_kv_cache:
|
|
116
|
+
self.k, self.k_scale = self._quantize(k)
|
|
117
|
+
self.v, self.v_scale = self._quantize(v)
|
|
118
|
+
else:
|
|
119
|
+
self.k = k
|
|
120
|
+
self.v = v
|
|
121
|
+
|
|
122
|
+
if self.offload_cache:
|
|
123
|
+
self._move_to_device("cpu")
|
|
124
|
+
|
|
125
|
+
def _move_to_device(self, device: str | torch.device) -> None:
|
|
126
|
+
"""Move cache tensors to device."""
|
|
127
|
+
self.k = self.k.to(device=device, non_blocking=True)
|
|
128
|
+
self.v = self.v.to(device=device, non_blocking=True)
|
|
129
|
+
if self.k_scale is not None:
|
|
130
|
+
self.k_scale = self.k_scale.to(device=device, non_blocking=True)
|
|
131
|
+
if self.v_scale is not None:
|
|
132
|
+
self.v_scale = self.v_scale.to(device=device, non_blocking=True)
|
|
133
|
+
|
|
134
|
+
def _quantize(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
135
|
+
"""Quantize tensor to FP8."""
|
|
136
|
+
fp8_max = torch.finfo(torch.float8_e4m3fn).max
|
|
137
|
+
scale = tensor.detach().abs().amax(dim=-1, keepdim=True).to(torch.float32)
|
|
138
|
+
scale = torch.clamp(scale / fp8_max, min=1e-12)
|
|
139
|
+
q_tensor = (tensor / scale.to(dtype=tensor.dtype)).to(torch.float8_e4m3fn)
|
|
140
|
+
return q_tensor.contiguous(), scale.contiguous()
|
|
141
|
+
|
|
142
|
+
def _dequantize(self, q_tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
|
143
|
+
"""Dequantize FP8 tensor."""
|
|
144
|
+
return q_tensor.to(dtype=dtype) * scale.to(device=q_tensor.device, dtype=dtype)
|
|
145
|
+
|
|
146
|
+
def to_dict(self) -> dict:
|
|
147
|
+
"""Convert to dict for backward compatibility."""
|
|
148
|
+
return {
|
|
149
|
+
"k": self.k,
|
|
150
|
+
"v": self.v,
|
|
151
|
+
"k_scale": self.k_scale,
|
|
152
|
+
"v_scale": self.v_scale,
|
|
153
|
+
"fp8_kv_cache": self.fp8_kv_cache,
|
|
154
|
+
"offload_cache": self.offload_cache,
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
@classmethod
|
|
158
|
+
def from_dict(cls, d: dict) -> "KVCache":
|
|
159
|
+
"""Create from dict."""
|
|
160
|
+
cache = cls(
|
|
161
|
+
fp8_kv_cache=d.get("fp8_kv_cache", False),
|
|
162
|
+
offload_cache=d.get("offload_cache", False),
|
|
163
|
+
)
|
|
164
|
+
cache.k = d.get("k")
|
|
165
|
+
cache.v = d.get("v")
|
|
166
|
+
cache.k_scale = d.get("k_scale")
|
|
167
|
+
cache.v_scale = d.get("v_scale")
|
|
168
|
+
return cache
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class KVCacheManager:
|
|
172
|
+
"""Manager for nested KV cache structure (timestep -> layer -> KVCache).
|
|
173
|
+
|
|
174
|
+
Uses list structure for optimal torch.compile performance:
|
|
175
|
+
- List indexing is faster than dict hashing
|
|
176
|
+
- No graph breaks from dynamic dict keys
|
|
177
|
+
- Memory contiguous for compiler optimization
|
|
178
|
+
|
|
179
|
+
Usage:
|
|
180
|
+
config = KVCacheConfig(fp8_kv_cache=False, offload_cache=True, cache_frames=6)
|
|
181
|
+
manager = KVCacheManager.from_dit_model(
|
|
182
|
+
dit_model,
|
|
183
|
+
config=config,
|
|
184
|
+
tokens_per_frame=520,
|
|
185
|
+
num_timesteps=3,
|
|
186
|
+
)
|
|
187
|
+
manager.allocate(device="cuda", dtype=torch.bfloat16)
|
|
188
|
+
|
|
189
|
+
# Access cache for specific (t_idx, layer_idx)
|
|
190
|
+
cache = manager.get_cache(t_idx=0, layer_idx=5)
|
|
191
|
+
k, v = cache.load(device, dtype)
|
|
192
|
+
|
|
193
|
+
# Get all layer caches for a timestep (returns list)
|
|
194
|
+
kv_list = manager.get_timestep_caches(0)
|
|
195
|
+
"""
|
|
196
|
+
|
|
197
|
+
def __init__(
|
|
198
|
+
self,
|
|
199
|
+
config: KVCacheConfig,
|
|
200
|
+
num_timesteps: int,
|
|
201
|
+
num_layers: int,
|
|
202
|
+
num_heads: int,
|
|
203
|
+
head_dim: int,
|
|
204
|
+
sp_size: int = 1,
|
|
205
|
+
):
|
|
206
|
+
"""Initialize KV cache manager.
|
|
207
|
+
|
|
208
|
+
Args:
|
|
209
|
+
config: KVCacheConfig instance
|
|
210
|
+
num_timesteps: Number of denoising timesteps
|
|
211
|
+
num_layers: Number of transformer layers
|
|
212
|
+
num_heads: Total number of attention heads
|
|
213
|
+
head_dim: Dimension per head
|
|
214
|
+
sp_size: Sequence parallel world size (heads are sharded)
|
|
215
|
+
"""
|
|
216
|
+
self.config = config
|
|
217
|
+
self.num_timesteps = num_timesteps
|
|
218
|
+
self.num_layers = num_layers
|
|
219
|
+
self.num_heads = num_heads
|
|
220
|
+
self.head_dim = head_dim
|
|
221
|
+
self.sp_size = sp_size
|
|
222
|
+
|
|
223
|
+
# Local heads after SP sharding
|
|
224
|
+
self.local_heads = num_heads // sp_size
|
|
225
|
+
|
|
226
|
+
# Cache storage: list[timestep][layer] = KVCache
|
|
227
|
+
self._caches: list[list[KVCache]] = []
|
|
228
|
+
|
|
229
|
+
# Pre-allocated shape (set during allocate)
|
|
230
|
+
self._shape: tuple[int, ...] | None = None
|
|
231
|
+
self._tokens_per_frame: int | None = None
|
|
232
|
+
|
|
233
|
+
@classmethod
|
|
234
|
+
def from_dit_model(
|
|
235
|
+
cls,
|
|
236
|
+
dit_model: Any,
|
|
237
|
+
config: KVCacheConfig,
|
|
238
|
+
tokens_per_frame: int,
|
|
239
|
+
num_timesteps: int = 3,
|
|
240
|
+
sp_size: int | None = None,
|
|
241
|
+
device: str | torch.device | None = None,
|
|
242
|
+
dtype: torch.dtype | None = None,
|
|
243
|
+
) -> "KVCacheManager":
|
|
244
|
+
"""Create KV cache manager from DiT model.
|
|
245
|
+
|
|
246
|
+
Args:
|
|
247
|
+
dit_model: LiveActDiT model with blocks, num_heads, dim attributes
|
|
248
|
+
config: KVCacheConfig instance
|
|
249
|
+
tokens_per_frame: Number of tokens per frame (h * w)
|
|
250
|
+
num_timesteps: Number of denoising timesteps
|
|
251
|
+
sp_size: Sequence parallel size (None: auto-detect from dit_model)
|
|
252
|
+
device: Target device (None: cuda)
|
|
253
|
+
dtype: Target dtype (None: bf16)
|
|
254
|
+
|
|
255
|
+
Returns:
|
|
256
|
+
KVCacheManager instance
|
|
257
|
+
"""
|
|
258
|
+
num_layers = len(dit_model.blocks)
|
|
259
|
+
num_heads = dit_model.num_heads
|
|
260
|
+
head_dim = dit_model.dim // dit_model.num_heads
|
|
261
|
+
|
|
262
|
+
# Auto-detect sp_size from dit_model if not provided
|
|
263
|
+
if sp_size is None:
|
|
264
|
+
device_mesh = getattr(dit_model, "device_mesh", None)
|
|
265
|
+
if device_mesh is not None:
|
|
266
|
+
from telefuser.distributed.ulysses_comm import get_ulysses_world_size
|
|
267
|
+
|
|
268
|
+
sp_size = get_ulysses_world_size(device_mesh) or 1
|
|
269
|
+
else:
|
|
270
|
+
sp_size = 1
|
|
271
|
+
|
|
272
|
+
manager = cls(
|
|
273
|
+
config=config,
|
|
274
|
+
num_timesteps=num_timesteps,
|
|
275
|
+
num_layers=num_layers,
|
|
276
|
+
num_heads=num_heads,
|
|
277
|
+
head_dim=head_dim,
|
|
278
|
+
sp_size=sp_size,
|
|
279
|
+
)
|
|
280
|
+
|
|
281
|
+
# Allocate immediately if device/dtype provided
|
|
282
|
+
if device is not None and dtype is not None:
|
|
283
|
+
manager.allocate(tokens_per_frame, device, dtype)
|
|
284
|
+
else:
|
|
285
|
+
manager._tokens_per_frame = tokens_per_frame
|
|
286
|
+
|
|
287
|
+
return manager
|
|
288
|
+
|
|
289
|
+
def allocate(
|
|
290
|
+
self,
|
|
291
|
+
tokens_per_frame: int,
|
|
292
|
+
device: str | torch.device = "cuda",
|
|
293
|
+
dtype: torch.dtype = torch.bfloat16,
|
|
294
|
+
) -> None:
|
|
295
|
+
"""Allocate all cache tensors.
|
|
296
|
+
|
|
297
|
+
Args:
|
|
298
|
+
tokens_per_frame: Number of tokens per frame
|
|
299
|
+
device: Target device (may be offloaded to CPU)
|
|
300
|
+
dtype: Storage dtype
|
|
301
|
+
"""
|
|
302
|
+
self._tokens_per_frame = tokens_per_frame
|
|
303
|
+
|
|
304
|
+
# Storage device (CPU if offload enabled)
|
|
305
|
+
storage_device = "cpu" if self.config.offload_cache else device
|
|
306
|
+
storage_dtype = torch.float8_e4m3fn if self.config.fp8_kv_cache else dtype
|
|
307
|
+
|
|
308
|
+
# Shape: [batch, cache_tokens, local_heads, head_dim]
|
|
309
|
+
cache_tokens = tokens_per_frame * self.config.cache_frames
|
|
310
|
+
self._shape = (1, cache_tokens, self.local_heads, self.head_dim)
|
|
311
|
+
|
|
312
|
+
# Create KVCache for each (t_idx, layer_idx) as list structure
|
|
313
|
+
self._caches = []
|
|
314
|
+
for t_idx in range(self.num_timesteps):
|
|
315
|
+
layer_caches = []
|
|
316
|
+
for layer_idx in range(self.num_layers):
|
|
317
|
+
cache = KVCache(
|
|
318
|
+
fp8_kv_cache=self.config.fp8_kv_cache,
|
|
319
|
+
offload_cache=self.config.offload_cache,
|
|
320
|
+
)
|
|
321
|
+
cache.allocate(self._shape, storage_dtype, storage_device)
|
|
322
|
+
layer_caches.append(cache)
|
|
323
|
+
self._caches.append(layer_caches)
|
|
324
|
+
|
|
325
|
+
def clear(self) -> None:
|
|
326
|
+
"""Clear all caches."""
|
|
327
|
+
for layer_caches in self._caches:
|
|
328
|
+
for cache in layer_caches:
|
|
329
|
+
cache.clear()
|
|
330
|
+
self._caches = []
|
|
331
|
+
self._shape = None
|
|
332
|
+
self._tokens_per_frame = None
|
|
333
|
+
|
|
334
|
+
def get_cache(self, t_idx: int, layer_idx: int) -> KVCache:
|
|
335
|
+
"""Get KVCache for specific timestep and layer.
|
|
336
|
+
|
|
337
|
+
Args:
|
|
338
|
+
t_idx: Timestep index
|
|
339
|
+
layer_idx: Layer index
|
|
340
|
+
|
|
341
|
+
Returns:
|
|
342
|
+
KVCache instance
|
|
343
|
+
"""
|
|
344
|
+
if t_idx >= len(self._caches):
|
|
345
|
+
raise IndexError(f"Timestep {t_idx} out of range (max: {len(self._caches) - 1})")
|
|
346
|
+
if layer_idx >= len(self._caches[t_idx]):
|
|
347
|
+
raise IndexError(
|
|
348
|
+
f"Layer {layer_idx} out of range for timestep {t_idx} (max: {len(self._caches[t_idx]) - 1})"
|
|
349
|
+
)
|
|
350
|
+
return self._caches[t_idx][layer_idx]
|
|
351
|
+
|
|
352
|
+
def get_timestep_caches(self, t_idx: int) -> list[KVCache]:
|
|
353
|
+
"""Get all layer caches for a timestep.
|
|
354
|
+
|
|
355
|
+
Args:
|
|
356
|
+
t_idx: Timestep index
|
|
357
|
+
|
|
358
|
+
Returns:
|
|
359
|
+
List of KVCache for each layer
|
|
360
|
+
"""
|
|
361
|
+
if t_idx >= len(self._caches):
|
|
362
|
+
raise IndexError(f"Timestep {t_idx} out of range (max: {len(self._caches) - 1})")
|
|
363
|
+
return self._caches[t_idx]
|
|
364
|
+
|
|
365
|
+
def to_dict(self) -> dict[int, dict[int, dict]]:
|
|
366
|
+
"""Convert to nested dict for serialization/debugging.
|
|
367
|
+
|
|
368
|
+
Returns:
|
|
369
|
+
Dict: {t_idx: {layer_idx: {k, v, k_scale, v_scale, ...}}}
|
|
370
|
+
"""
|
|
371
|
+
result = {}
|
|
372
|
+
for t_idx, layer_caches in enumerate(self._caches):
|
|
373
|
+
result[t_idx] = {}
|
|
374
|
+
for layer_idx, cache in enumerate(layer_caches):
|
|
375
|
+
result[t_idx][layer_idx] = cache.to_dict()
|
|
376
|
+
return result
|
|
377
|
+
|
|
378
|
+
@classmethod
|
|
379
|
+
def from_dict(cls, d: dict) -> "KVCacheManager":
|
|
380
|
+
"""Create from nested dict (for deserialization).
|
|
381
|
+
|
|
382
|
+
Args:
|
|
383
|
+
d: Nested dict {t_idx: {layer_idx: {k, v, ...}}}
|
|
384
|
+
|
|
385
|
+
Returns:
|
|
386
|
+
KVCacheManager instance
|
|
387
|
+
"""
|
|
388
|
+
# Extract structure info from dict
|
|
389
|
+
num_timesteps = len(d)
|
|
390
|
+
num_layers = len(d[0]) if num_timesteps > 0 else 0
|
|
391
|
+
|
|
392
|
+
# Get config from first entry
|
|
393
|
+
first_entry = d[0][0] if num_timesteps > 0 and num_layers > 0 else {}
|
|
394
|
+
config = KVCacheConfig(
|
|
395
|
+
fp8_kv_cache=first_entry.get("fp8_kv_cache", False),
|
|
396
|
+
offload_cache=first_entry.get("offload_cache", False),
|
|
397
|
+
cache_frames=6,
|
|
398
|
+
)
|
|
399
|
+
|
|
400
|
+
# Infer shape from k tensor
|
|
401
|
+
k_tensor = first_entry.get("k")
|
|
402
|
+
if k_tensor is not None:
|
|
403
|
+
shape = k_tensor.shape
|
|
404
|
+
local_heads = shape[2]
|
|
405
|
+
head_dim = shape[3]
|
|
406
|
+
# We don't know num_heads without sp_size, assume sp_size=1
|
|
407
|
+
num_heads = local_heads
|
|
408
|
+
else:
|
|
409
|
+
raise ValueError("Cannot infer shape from empty cache dict")
|
|
410
|
+
|
|
411
|
+
manager = cls(
|
|
412
|
+
config=config,
|
|
413
|
+
num_timesteps=num_timesteps,
|
|
414
|
+
num_layers=num_layers,
|
|
415
|
+
num_heads=num_heads,
|
|
416
|
+
head_dim=head_dim,
|
|
417
|
+
sp_size=1,
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
# Restore caches as list structure
|
|
421
|
+
for t_idx in range(num_timesteps):
|
|
422
|
+
layer_caches = []
|
|
423
|
+
for layer_idx in range(num_layers):
|
|
424
|
+
layer_caches.append(KVCache.from_dict(d[t_idx][layer_idx]))
|
|
425
|
+
manager._caches.append(layer_caches)
|
|
426
|
+
|
|
427
|
+
manager._shape = shape
|
|
428
|
+
return manager
|
|
429
|
+
|
|
430
|
+
@property
def shape(self) -> tuple[int, ...] | None:
    """Shape recorded for the cached K/V tensors; may be ``None`` per the annotation."""
    cached_shape = self._shape
    return cached_shape
|
|
434
|
+
|
|
435
|
+
@property
def is_allocated(self) -> bool:
    """Return True once at least one timestep's cache list has been stored."""
    # _caches is a list (entries are appended to it), so truthiness == non-empty.
    return bool(self._caches)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from importlib import import_module
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
# Public API of this package. None of these names is bound at import time;
# each one is resolved on first access by the module-level __getattr__.
__all__ = ["CacheConfig", "CacheResult", "LatentCache"]
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def __getattr__(name: str) -> Any:
    """Lazily resolve this package's public symbols (PEP 562 module ``__getattr__``).

    Keeps ``import telefuser.cache_mem`` lightweight: the submodule defining a
    symbol is imported only when that symbol is first accessed.

    Args:
        name: Attribute name looked up on the package.

    Returns:
        The requested object. ``CacheConfig`` is best-effort: if its module
        cannot be imported, ``None`` is returned and a ``RuntimeWarning`` is
        emitted so the failure is visible instead of silently swallowed.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported symbols.
        ImportError: If the module backing ``CacheResult`` or ``LatentCache``
            cannot be imported.
    """
    lazy_exports = {
        "CacheResult": "telefuser.cache_mem.cache_types",
        "LatentCache": "telefuser.cache_mem.latent_cache",
        "CacheConfig": "telefuser.cache_mem.config",
    }
    module_path = lazy_exports.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    if name != "CacheConfig":
        return getattr(import_module(module_path), name)

    # CacheConfig keeps its historical best-effort contract (None on import
    # failure), but the failure is now reported rather than hidden.
    # ModuleNotFoundError subclasses ImportError, so one clause covers both.
    try:
        module = import_module(module_path)
    except ImportError as exc:
        import warnings

        warnings.warn(
            f"{__name__}.CacheConfig is unavailable because "
            f"{module_path!r} failed to import ({exc!r}); returning None",
            RuntimeWarning,
            stacklevel=2,
        )
        return None
    return getattr(module, name)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def __dir__() -> list[str]:
    """List module attributes, including lazily resolved names not yet bound."""
    visible = {*globals(), *__all__}
    return sorted(visible)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class CacheResult:
    """Result of a cache lookup."""

    # Presumably True when the lookup matched a reusable entry — confirm with the producing strategy.
    hit: bool
    # Presumably how many steps can be skipped on a hit — TODO confirm.
    skip_step: int = 0
    cache_type: str = "none"  # "approximate", "continue", "exact", "none"
    # Similarity score of the match; 0.0 when nothing matched.
    similarity: float = 0.0
    # Cached latent tensor, if one was loaded.
    latent_state: Optional[torch.Tensor] = None
    # Prompt text of the cached entry that was matched.
    cached_prompt: str = ""
    session_id: Optional[str] = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class IndexEntry:
    """Index entry."""

    cache_id: str
    prompt: str
    # Steps for which latents were saved under this entry.
    saved_steps: List[int]
    # NOTE(review): default is "approximate_cache", whereas CacheResult.cache_type
    # uses "approximate"; confirm whether the two vocabularies are meant to differ.
    cache_type: str = "approximate_cache"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class VectorSearchResult:
    """Vector search result."""

    cache_id: str
    # Similarity score reported by the vector store for this hit.
    similarity: float
    prompt: str
    saved_steps: List[int]
    # Extra metadata stored with the vector; its schema is not defined here.
    payload: Dict[str, Any]
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CacheMode(Enum):
    """Access mode for the latent cache."""

    READ_WRITE = "read_write"  # read from and write to the cache (default)
    READ_ONLY = "read_only"  # only read from the cache
    WRITE_ONLY = "write_only"  # only write to the cache
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class CacheConfig:
    """Cache configuration shared across stages/pipelines.

    Plain settings container. Fields are grouped into:
    - basic latent-cache options;
    - KV storage;
    - vector storage;
    - similarity/lookup strategy;
    - text and video embedding models;
    - video vector search / reranking;
    - asynchronous saving.

    No validation or normalization is performed here.
    """

    # Basic cache
    enable_latent_cache: bool = False
    # NOTE(review): no coercion is done, so a plain string such as "read_only"
    # stays a str and will not compare equal to a CacheMode member; confirm that
    # callers always pass the enum.
    cache_mode: CacheMode = CacheMode.READ_WRITE  # read_write | read_only | write_only
    latent_cache_dir: str = "./latent_cache"
    max_cache_size_gb: int = 10
    cache_log_enabled: bool = True
    # Default: {latent_cache_dir}/logs (presumably resolved by the consumer when None).
    cache_log_dir: Optional[str] = None
    cache_log_level: str = "DEBUG"
    cache_log_rotation: str = "100 MB"
    cache_log_retention: str = "7 days"

    # KV store (used for key-value caches such as latents)
    kv_store_type: str = "local_file"  # "local_file" | "fluxon"
    # NOTE(review): typed Optional but defaults to "" rather than None; consumers
    # may need to treat both as "unset".
    fluxon_config_path: Optional[str] = ""

    # Vector store (used for embedding retrieval)
    vector_store_type: str = "faiss"  # "qdrant" | "faiss"
    # NOTE(review): typed Optional but defaults to "" rather than None.
    qdrant_url: Optional[str] = ""
    qdrant_api_key: Optional[str] = None
    faiss_index_dir: Optional[str] = None
    # Vector dimension: required to initialize FAISS and must match the
    # embedding model's output dimension.
    vector_dim: int = 2048
    # Strategy type; a key into STRATEGY_REGISTRY.
    cache_strategy_type: str = "video_approximate"

    # Similarity & lookup strategy
    # Steps that participate in cache reuse (default_factory gives each instance its own list).
    key_steps: List[int] = field(default_factory=lambda: [0, 1, 2, 3, 4, 5])
    lookup_mode: str = "video"  # lookup mode, e.g. "video"

    # Text embedding (prompt/text embedding model)
    text_embedding_model_path: str = ""
    text_embedding_instruction: str = "Represent the user's input"
    text_embedding_device_id: Optional[int] = None
    text_embedding_torch_dtype: Optional[str] = None
    text_embedding_attn_impl: Optional[str] = None

    # Video embedding model
    video_embedding_enabled: bool = True
    video_embedding_model_path: str = "Qwen/Qwen3-VL-Embedding-2B"
    video_embedding_instruction: str = "Represent the user's input"
    video_embedding_fps: float = 1.0
    video_embedding_max_frames: int = 16
    video_embedding_max_length: int = 8192
    video_embedding_min_pixels: int = 4096
    video_embedding_max_pixels: int = 1843200
    video_embedding_total_pixels: int = 7864320
    video_embedding_device_id: Optional[int] = None
    video_embedding_torch_dtype: Optional[str] = None
    video_embedding_attn_impl: Optional[str] = None

    # Video vector search & rerank
    video_similarity_threshold: Optional[float] = 0.10
    video_vector_collection: str = "video"
    rerank_enabled: bool = False
    rerank_model_path: str = "Qwen/Qwen3-VL-Reranker-2B"
    rerank_top_k: int = 5
    rerank_batch_size: int = 2
    rerank_device_id: Optional[int] = None
    rerank_torch_dtype: Optional[str] = None
    rerank_score_threshold: float = 0.90

    # Async save (write-behind)
    save_async_enabled: bool = True
    save_queue_size: int = 2
    # Presumably the policy applied when the save queue is full.
    save_on_full: str = "drop"  # drop | sync | downgrade
    # NOTE(review): the default warn threshold (8) exceeds the default queue size
    # (2); if the size is a hard cap, this warning can never fire with defaults —
    # confirm intent.
    save_queue_warn_threshold: int = 8
    # The *_s suffix presumably means seconds — confirm with the consumer.
    vector_wait_warn_s: float = 2.0
    vector_wait_poll_s: float = 0.05
    vector_wait_timeout_s: float = 120.0
    flush_on_shutdown: bool = True
|