telefuser 0.1.0.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (296)
  1. telefuser/__init__.py +6 -0
  2. telefuser/_logo.py +12 -0
  3. telefuser/_version.py +24 -0
  4. telefuser/cache/__init__.py +5 -0
  5. telefuser/cache/kv_cache.py +438 -0
  6. telefuser/cache_mem/__init__.py +27 -0
  7. telefuser/cache_mem/cache_types.py +40 -0
  8. telefuser/cache_mem/config.py +83 -0
  9. telefuser/cache_mem/connection.py +197 -0
  10. telefuser/cache_mem/encoders.py +398 -0
  11. telefuser/cache_mem/encoding/__init__.py +0 -0
  12. telefuser/cache_mem/encoding/interfaces.py +27 -0
  13. telefuser/cache_mem/latent_cache.py +213 -0
  14. telefuser/cache_mem/log_monitor.py +77 -0
  15. telefuser/cache_mem/metadata.py +268 -0
  16. telefuser/cache_mem/src/__init__.py +0 -0
  17. telefuser/cache_mem/src/models/__init__.py +0 -0
  18. telefuser/cache_mem/src/models/qwen3_vl_embedding.py +346 -0
  19. telefuser/cache_mem/src/models/qwen3_vl_reranker.py +437 -0
  20. telefuser/cache_mem/state/__init__.py +0 -0
  21. telefuser/cache_mem/state/interfaces.py +67 -0
  22. telefuser/cache_mem/storage/__init__.py +11 -0
  23. telefuser/cache_mem/storage/fluxon.py +24 -0
  24. telefuser/cache_mem/storage/interfaces.py +25 -0
  25. telefuser/cache_mem/storage/local_file.py +112 -0
  26. telefuser/cache_mem/storage/memory.py +24 -0
  27. telefuser/cache_mem/strategies.py +819 -0
  28. telefuser/cache_mem/vector_store/__init__.py +5 -0
  29. telefuser/cache_mem/vector_store/faiss.py +298 -0
  30. telefuser/cache_mem/vector_store/interfaces.py +42 -0
  31. telefuser/cache_mem/vector_store/qdrant.py +46 -0
  32. telefuser/client/__init__.py +34 -0
  33. telefuser/client/openai/__init__.py +34 -0
  34. telefuser/client/openai/client.py +146 -0
  35. telefuser/client/openai/images.py +221 -0
  36. telefuser/client/openai/videos.py +307 -0
  37. telefuser/client/tf_client.py +1016 -0
  38. telefuser/core/__init__.py +37 -0
  39. telefuser/core/base_model.py +262 -0
  40. telefuser/core/base_pipeline.py +421 -0
  41. telefuser/core/base_stage.py +169 -0
  42. telefuser/core/config.py +409 -0
  43. telefuser/core/config_serializer.py +54 -0
  44. telefuser/core/model_registry.py +108 -0
  45. telefuser/core/module_manager.py +412 -0
  46. telefuser/distributed/__init__.py +95 -0
  47. telefuser/distributed/device_mesh.py +347 -0
  48. telefuser/distributed/fsdp.py +143 -0
  49. telefuser/distributed/parallel_shard.py +250 -0
  50. telefuser/distributed/pp_comm.py +306 -0
  51. telefuser/distributed/ring.py +357 -0
  52. telefuser/distributed/tp_parallelize.py +63 -0
  53. telefuser/distributed/ulysses_comm.py +250 -0
  54. telefuser/entrypoints/__init__.py +3 -0
  55. telefuser/entrypoints/cli/main.py +257 -0
  56. telefuser/feature_cache/__init__.py +47 -0
  57. telefuser/feature_cache/ada_taylor_cache/__init__.py +28 -0
  58. telefuser/feature_cache/ada_taylor_cache/ada_taylor_cache.py +656 -0
  59. telefuser/feature_cache/ada_taylor_cache/params/HunyuanVideo15-I2V-480P.json +111 -0
  60. telefuser/feature_cache/ada_taylor_cache/params/HunyuanVideo15-T2V-480P.json +111 -0
  61. telefuser/feature_cache/ada_taylor_cache/params/Qwen-Image-2512.json +111 -0
  62. telefuser/feature_cache/ada_taylor_cache/params/Wan2_1-FL2V-14B-720P.json +89 -0
  63. telefuser/feature_cache/ada_taylor_cache/params/Wan2_1-I2V-14B-480P.json +89 -0
  64. telefuser/feature_cache/ada_taylor_cache/params/Wan2_1-I2V-14B-720P.json +89 -0
  65. telefuser/feature_cache/ada_taylor_cache/params/Wan2_1-T2V-14B.json +109 -0
  66. telefuser/feature_cache/ada_taylor_cache/params/Wan2_1-T2V-1_3B.json +109 -0
  67. telefuser/feature_cache/ada_taylor_cache/params/Wan2_2-FL2V-A14B.json +89 -0
  68. telefuser/feature_cache/ada_taylor_cache/params/Wan2_2-I2V-A14B-Camera.json +109 -0
  69. telefuser/feature_cache/ada_taylor_cache/params/Wan2_2-I2V-A14B.json +89 -0
  70. telefuser/feature_cache/ada_taylor_cache/params/Wan2_2-T2V-A14B.json +89 -0
  71. telefuser/feature_cache/base.py +150 -0
  72. telefuser/kernel/__init__.py +55 -0
  73. telefuser/kernel/triton/__init__.py +43 -0
  74. telefuser/kernel/triton/merge_attn_states.py +115 -0
  75. telefuser/kernel/triton/norm.py +816 -0
  76. telefuser/kernel/triton/quant.py +147 -0
  77. telefuser/kernel/triton/quant_per_block.py +154 -0
  78. telefuser/kernel/triton/rotary.py +162 -0
  79. telefuser/kernel/triton/scale_shift.py +1064 -0
  80. telefuser/kernel/triton/sparse_int8_attn.py +280 -0
  81. telefuser/metrics/__init__.py +101 -0
  82. telefuser/metrics/collector.py +442 -0
  83. telefuser/metrics/config.py +111 -0
  84. telefuser/metrics/exporters.py +113 -0
  85. telefuser/metrics/registry.py +485 -0
  86. telefuser/metrics/service_metrics.py +436 -0
  87. telefuser/metrics/stage_metrics.py +207 -0
  88. telefuser/models/TCDecoder.py +352 -0
  89. telefuser/models/__init__.py +24 -0
  90. telefuser/models/flashvsr_dit.py +608 -0
  91. telefuser/models/flux2_dit.py +1126 -0
  92. telefuser/models/hunyuan_video_byt5.py +433 -0
  93. telefuser/models/hunyuan_video_dit.py +2124 -0
  94. telefuser/models/hunyuan_video_image_encoder.py +222 -0
  95. telefuser/models/hunyuan_video_text_encoder.py +461 -0
  96. telefuser/models/hunyuan_video_upsampler.py +320 -0
  97. telefuser/models/hunyuan_video_vae.py +850 -0
  98. telefuser/models/lingbot_world_fast_dit.py +573 -0
  99. telefuser/models/liveact_dit.py +1213 -0
  100. telefuser/models/longcat_video_dit.py +1214 -0
  101. telefuser/models/ltx_audio_vae.py +1183 -0
  102. telefuser/models/ltx_dit.py +2202 -0
  103. telefuser/models/ltx_gemma_text_encoder.py +1004 -0
  104. telefuser/models/ltx_upsampler.py +416 -0
  105. telefuser/models/ltx_video_vae.py +2668 -0
  106. telefuser/models/qwen_image_dit.py +780 -0
  107. telefuser/models/qwen_image_text_encoder.py +196 -0
  108. telefuser/models/qwen_image_vae.py +643 -0
  109. telefuser/models/realesrgan.py +356 -0
  110. telefuser/models/rift_hdv3.py +353 -0
  111. telefuser/models/t5_tokenizer.py +96 -0
  112. telefuser/models/video_projector.py +457 -0
  113. telefuser/models/wan22_video_vae.py +1548 -0
  114. telefuser/models/wan_video_dit.py +1586 -0
  115. telefuser/models/wan_video_image_encoder.py +534 -0
  116. telefuser/models/wan_video_text_encoder.py +317 -0
  117. telefuser/models/wan_video_vae.py +1519 -0
  118. telefuser/models/wav2vec2.py +154 -0
  119. telefuser/models/xlm_roberta.py +157 -0
  120. telefuser/models/z_image_dit.py +695 -0
  121. telefuser/models/z_image_text_encoder.py +81 -0
  122. telefuser/offload/__init__.py +26 -0
  123. telefuser/offload/async_offload.py +417 -0
  124. telefuser/offload/model_offload.py +35 -0
  125. telefuser/offload/sequential_offload.py +318 -0
  126. telefuser/ops/__init__.py +33 -0
  127. telefuser/ops/activations.py +187 -0
  128. telefuser/ops/attention/__init__.py +29 -0
  129. telefuser/ops/attention/attention_impl.py +529 -0
  130. telefuser/ops/attention/backends.py +209 -0
  131. telefuser/ops/attention/bsa.py +250 -0
  132. telefuser/ops/attention/local_sparse_attn.py +547 -0
  133. telefuser/ops/attention/sparse_patterns.py +622 -0
  134. telefuser/ops/attention/sparse_sage.py +80 -0
  135. telefuser/ops/base.py +145 -0
  136. telefuser/ops/custom_op.py +121 -0
  137. telefuser/ops/ffn.py +69 -0
  138. telefuser/ops/fp8_gemm.py +348 -0
  139. telefuser/ops/normalization.py +274 -0
  140. telefuser/ops/quantized_linear.py +164 -0
  141. telefuser/ops/rotary.py +138 -0
  142. telefuser/orchestrator/__init__.py +22 -0
  143. telefuser/orchestrator/artifact_save_stage.py +119 -0
  144. telefuser/orchestrator/pipeline_orchestrator.py +358 -0
  145. telefuser/orchestrator/stage_wrapper.py +276 -0
  146. telefuser/pipelines/__init__.py +9 -0
  147. telefuser/pipelines/common/realesrgan_upscale.py +92 -0
  148. telefuser/pipelines/common/rift_vfi.py +54 -0
  149. telefuser/pipelines/flashvsr/__init__.py +4 -0
  150. telefuser/pipelines/flashvsr/dit_denoising.py +312 -0
  151. telefuser/pipelines/flashvsr/flashvsr_stream.py +197 -0
  152. telefuser/pipelines/flashvsr/vae.py +57 -0
  153. telefuser/pipelines/flux2_klein/__init__.py +5 -0
  154. telefuser/pipelines/flux2_klein/dit_denoising.py +329 -0
  155. telefuser/pipelines/flux2_klein/pipeline.py +427 -0
  156. telefuser/pipelines/flux2_klein/text_encoding.py +201 -0
  157. telefuser/pipelines/flux2_klein/vae.py +215 -0
  158. telefuser/pipelines/hunyuan_video_1_5/__init__.py +55 -0
  159. telefuser/pipelines/hunyuan_video_1_5/dit_denoising.py +270 -0
  160. telefuser/pipelines/hunyuan_video_1_5/image_encoding.py +87 -0
  161. telefuser/pipelines/hunyuan_video_1_5/pipeline.py +324 -0
  162. telefuser/pipelines/hunyuan_video_1_5/sr_dit_denoising.py +363 -0
  163. telefuser/pipelines/hunyuan_video_1_5/text_encoding.py +291 -0
  164. telefuser/pipelines/hunyuan_video_1_5/upsampler.py +95 -0
  165. telefuser/pipelines/hunyuan_video_1_5/vae.py +133 -0
  166. telefuser/pipelines/lingbot_world_fast/__init__.py +31 -0
  167. telefuser/pipelines/lingbot_world_fast/control.py +208 -0
  168. telefuser/pipelines/lingbot_world_fast/denoising.py +85 -0
  169. telefuser/pipelines/lingbot_world_fast/pipeline.py +592 -0
  170. telefuser/pipelines/lingbot_world_fast/service.py +483 -0
  171. telefuser/pipelines/lingbot_world_fast/session.py +76 -0
  172. telefuser/pipelines/liveact/__init__.py +16 -0
  173. telefuser/pipelines/liveact/audio_encoding.py +365 -0
  174. telefuser/pipelines/liveact/denoising.py +306 -0
  175. telefuser/pipelines/liveact/pipeline.py +337 -0
  176. telefuser/pipelines/longcat_video/__init__.py +12 -0
  177. telefuser/pipelines/longcat_video/dit_denoising.py +297 -0
  178. telefuser/pipelines/longcat_video/longcat_video.py +542 -0
  179. telefuser/pipelines/longcat_video/refine_denoise.py +235 -0
  180. telefuser/pipelines/longcat_video/text_encoding.py +118 -0
  181. telefuser/pipelines/ltx_video/__init__.py +1 -0
  182. telefuser/pipelines/ltx_video/dit_denoising.py +1010 -0
  183. telefuser/pipelines/ltx_video/gemma_text_encoding.py +165 -0
  184. telefuser/pipelines/ltx_video/ltx23_video.py +518 -0
  185. telefuser/pipelines/ltx_video/upsampler.py +29 -0
  186. telefuser/pipelines/ltx_video/vae.py +195 -0
  187. telefuser/pipelines/qwen_image/__init__.py +11 -0
  188. telefuser/pipelines/qwen_image/dit_denoising.py +228 -0
  189. telefuser/pipelines/qwen_image/qwen_image.py +301 -0
  190. telefuser/pipelines/qwen_image/qwen_image_edit.py +209 -0
  191. telefuser/pipelines/qwen_image/text_encoding.py +223 -0
  192. telefuser/pipelines/qwen_image/vae.py +91 -0
  193. telefuser/pipelines/wan_video/__init__.py +6 -0
  194. telefuser/pipelines/wan_video/async_wan22_video.py +467 -0
  195. telefuser/pipelines/wan_video/clip_encoding.py +54 -0
  196. telefuser/pipelines/wan_video/latent_data_utils.py +53 -0
  197. telefuser/pipelines/wan_video/moe_dit_denoising.py +409 -0
  198. telefuser/pipelines/wan_video/single_dit_denoising.py +262 -0
  199. telefuser/pipelines/wan_video/text_encoding.py +58 -0
  200. telefuser/pipelines/wan_video/ti2v_denoising.py +396 -0
  201. telefuser/pipelines/wan_video/vae.py +237 -0
  202. telefuser/pipelines/wan_video/wan21_video.py +353 -0
  203. telefuser/pipelines/wan_video/wan22_ti2v.py +372 -0
  204. telefuser/pipelines/wan_video/wan22_video.py +318 -0
  205. telefuser/pipelines/z_image/__init__.py +8 -0
  206. telefuser/pipelines/z_image/dit_denoising.py +281 -0
  207. telefuser/pipelines/z_image/text_encoding.py +117 -0
  208. telefuser/pipelines/z_image/vae.py +49 -0
  209. telefuser/pipelines/z_image/z_image.py +139 -0
  210. telefuser/platforms/__init__.py +80 -0
  211. telefuser/platforms/cpu.py +30 -0
  212. telefuser/platforms/cuda.py +86 -0
  213. telefuser/platforms/interface.py +99 -0
  214. telefuser/platforms/npu.py +78 -0
  215. telefuser/platforms/rocm.py +72 -0
  216. telefuser/schedulers/__init__.py +13 -0
  217. telefuser/schedulers/flow_match.py +377 -0
  218. telefuser/schedulers/flow_match_discrete.py +325 -0
  219. telefuser/schedulers/lcm.py +81 -0
  220. telefuser/schedulers/unipc.py +697 -0
  221. telefuser/service/__init__.py +38 -0
  222. telefuser/service/api/__init__.py +41 -0
  223. telefuser/service/api/api_server.py +327 -0
  224. telefuser/service/api/middleware.py +334 -0
  225. telefuser/service/api/openai/__init__.py +58 -0
  226. telefuser/service/api/openai/adapter.py +423 -0
  227. telefuser/service/api/openai/image_routes.py +417 -0
  228. telefuser/service/api/openai/protocol.py +287 -0
  229. telefuser/service/api/openai/video_routes.py +416 -0
  230. telefuser/service/api/routers/__init__.py +20 -0
  231. telefuser/service/api/routers/files.py +65 -0
  232. telefuser/service/api/routers/service.py +143 -0
  233. telefuser/service/api/routers/stream.py +88 -0
  234. telefuser/service/api/routers/tasks.py +432 -0
  235. telefuser/service/api/routers/webrtc.py +164 -0
  236. telefuser/service/api/schema.py +80 -0
  237. telefuser/service/api/stream_schema.py +76 -0
  238. telefuser/service/api/task_contract_runtime.py +148 -0
  239. telefuser/service/api/utils.py +139 -0
  240. telefuser/service/cache/__init__.py +4 -0
  241. telefuser/service/cache/cache_factory.py +176 -0
  242. telefuser/service/cache/cache_service.py +389 -0
  243. telefuser/service/core/__init__.py +30 -0
  244. telefuser/service/core/config.py +249 -0
  245. telefuser/service/core/container.py +264 -0
  246. telefuser/service/core/contract_templates.py +147 -0
  247. telefuser/service/core/file_service.py +269 -0
  248. telefuser/service/core/pipeline_contract.py +339 -0
  249. telefuser/service/core/pipeline_loader.py +94 -0
  250. telefuser/service/core/pipeline_pool.py +280 -0
  251. telefuser/service/core/pipeline_runner.py +205 -0
  252. telefuser/service/core/pipeline_service.py +311 -0
  253. telefuser/service/core/replica_worker.py +298 -0
  254. telefuser/service/core/stream_pipeline_service.py +261 -0
  255. telefuser/service/core/task_manager.py +416 -0
  256. telefuser/service/core/task_processor.py +162 -0
  257. telefuser/service/core/task_service.py +156 -0
  258. telefuser/service/main.py +99 -0
  259. telefuser/service/media/__init__.py +17 -0
  260. telefuser/service/media/media_base.py +298 -0
  261. telefuser/service/security/__init__.py +32 -0
  262. telefuser/service/security/security_validator.py +797 -0
  263. telefuser/service/webrtc/__init__.py +32 -0
  264. telefuser/service/webrtc/chunk_router.py +111 -0
  265. telefuser/service/webrtc/session_manager.py +322 -0
  266. telefuser/service/webrtc/track.py +307 -0
  267. telefuser/service_types.py +83 -0
  268. telefuser/utils/__init__.py +25 -0
  269. telefuser/utils/audio.py +51 -0
  270. telefuser/utils/func.py +31 -0
  271. telefuser/utils/hf_model_analyzer.py +382 -0
  272. telefuser/utils/hf_model_utils.py +209 -0
  273. telefuser/utils/hf_utils.py +256 -0
  274. telefuser/utils/logging.py +749 -0
  275. telefuser/utils/lora_loader.py +295 -0
  276. telefuser/utils/lora_network.py +212 -0
  277. telefuser/utils/memory_snapshot.py +423 -0
  278. telefuser/utils/model_weight.py +163 -0
  279. telefuser/utils/profiler.py +1079 -0
  280. telefuser/utils/stage_bench_harness.py +740 -0
  281. telefuser/utils/system.py +228 -0
  282. telefuser/utils/torch_compile.py +83 -0
  283. telefuser/utils/utils.py +49 -0
  284. telefuser/utils/video.py +464 -0
  285. telefuser/worker/__init__.py +18 -0
  286. telefuser/worker/native_worker.py +125 -0
  287. telefuser/worker/parallel_worker.py +292 -0
  288. telefuser/worker/ray_worker.py +107 -0
  289. telefuser-0.1.0.post3.dist-info/METADATA +379 -0
  290. telefuser-0.1.0.post3.dist-info/RECORD +296 -0
  291. telefuser-0.1.0.post3.dist-info/WHEEL +5 -0
  292. telefuser-0.1.0.post3.dist-info/entry_points.txt +2 -0
  293. telefuser-0.1.0.post3.dist-info/licenses/LICENSE +201 -0
  294. telefuser-0.1.0.post3.dist-info/scm_file_list.json +763 -0
  295. telefuser-0.1.0.post3.dist-info/scm_version.json +8 -0
  296. telefuser-0.1.0.post3.dist-info/top_level.txt +1 -0
telefuser/__init__.py ADDED
@@ -0,0 +1,6 @@
1
try:
    from ._version import __version__
except ModuleNotFoundError as missing:
    # _version.py is generated at build time and not tracked in version control,
    # so it may be absent (e.g. in an unbuilt checkout). Only that specific
    # module is tolerated; any other missing module is a real error.
    if missing.name == "telefuser._version":
        __version__ = "0.0.0+unknown"
    else:
        raise
telefuser/_logo.py ADDED
@@ -0,0 +1,12 @@
1
+ """Lightweight package branding constants."""
2
+
3
+ from __future__ import annotations
4
+
5
# Multi-line "TELEFUSER" banner text. The value starts with a newline (right
# after the opening quotes) and ends with a newline before the closing quotes.
TELEFUSER_LOGO = r"""
████████╗███████╗██╗ ███████╗███████╗██╗ ██╗███████╗███████╗█████████╗
╚══██╔══╝██╔════╝██║ ██╔════╝██╔════╝██║ ██║██╔════╝██╔════╝██╔════██║
██║ █████╗ ██║ █████╗ █████╗ ██║ ██║███████╗█████╗ ███████╔═╝
██║ ██╔══╝ ██║ ██╔══╝ ██╔══╝ ██║ ██║╚════██║██╔══╝ ██╔══██║
██║ ███████╗███████╗███████╗██║ ╚██████╔╝███████║███████╗██║ ████╗
╚═╝ ╚══════╝╚══════╝ ╚═════╝╚═╝ ╚═════╝ ╚══════╝╚══════╝╚═╝ ╚═══╝
"""
telefuser/_version.py ADDED
@@ -0,0 +1,24 @@
1
+ # file generated by vcs-versioning
2
+ # don't change, don't track in version control
3
+ from __future__ import annotations
4
+
5
# Names exported by this generated module (each value is provided under both a
# dunder and a plain alias).
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Type declarations for the values assigned below.
version: str
__version__: str
__version_tuple__: tuple[int | str, ...]
version_tuple: tuple[int | str, ...]
commit_id: str | None
__commit_id__: str | None

# Package version string and its component tuple.
__version__ = version = '0.1.0.post3'
__version_tuple__ = version_tuple = (0, 1, 0, 'post3')

# Commit identifier recorded by the version generator at build time.
__commit_id__ = commit_id = 'gca0bc08c7'
telefuser/cache/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """TeleFuser Cache Module."""
2
+
3
+ from telefuser.cache.kv_cache import KVCache, KVCacheConfig, KVCacheManager
4
+
5
# Public API of the package: the KV cache classes re-exported from
# telefuser.cache.kv_cache.
__all__ = ["KVCache", "KVCacheConfig", "KVCacheManager"]
telefuser/cache/kv_cache.py ADDED
@@ -0,0 +1,438 @@
1
+ """KVCache module for LiveAct - List-based structure for torch.compile optimization."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any
7
+
8
+ import torch
9
+
10
+
11
@dataclass
class KVCacheConfig:
    """Settings controlling how KV cache tensors are stored.

    Attributes:
        fp8_kv_cache: Store K/V as ``torch.float8_e4m3fn`` plus float32 scales
            to reduce memory.
        offload_cache: Keep cache tensors in CPU memory between uses.
        cache_frames: Number of frames retained after compression; each cache
            holds ``tokens_per_frame * cache_frames`` tokens (default 6).
    """

    # FP8 quantization of cached K/V (disabled by default).
    fp8_kv_cache: bool = False
    # CPU offload of cached K/V (disabled by default).
    offload_cache: bool = False
    # Frames of history kept per cache entry.
    cache_frames: int = 6
24
+
25
+
26
+ class KVCache:
27
+ """Minimal KV cache for a single (timestep, layer) entry.
28
+
29
+ Preserves exact original behavior:
30
+ - Direct dict access (k, v, k_scale, v_scale)
31
+ - FP8 quantization support
32
+ - CPU offload support
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ fp8_kv_cache: bool = False,
38
+ offload_cache: bool = False,
39
+ ):
40
+ """Initialize KV cache.
41
+
42
+ Args:
43
+ fp8_kv_cache: Enable FP8 quantization
44
+ offload_cache: Enable CPU offload
45
+ """
46
+ self.fp8_kv_cache = fp8_kv_cache
47
+ self.offload_cache = offload_cache
48
+
49
+ # Storage tensors
50
+ self.k: torch.Tensor | None = None
51
+ self.v: torch.Tensor | None = None
52
+ self.k_scale: torch.Tensor | None = None
53
+ self.v_scale: torch.Tensor | None = None
54
+
55
+ def allocate(self, shape: tuple[int, ...], dtype: torch.dtype, device: str | torch.device) -> None:
56
+ """Allocate cache tensors.
57
+
58
+ Args:
59
+ shape: Shape of K/V tensor [batch, seq, heads, head_dim]
60
+ dtype: Storage dtype (bf16 or fp8)
61
+ device: Device for storage
62
+ """
63
+ storage_dtype = torch.float8_e4m3fn if self.fp8_kv_cache else dtype
64
+ self.k = torch.zeros(shape, dtype=storage_dtype, device=device)
65
+ self.v = torch.zeros(shape, dtype=storage_dtype, device=device)
66
+
67
+ if self.fp8_kv_cache:
68
+ # Scale shape: [batch, seq, heads, 1]
69
+ scale_shape = (shape[0], shape[1], shape[2], 1)
70
+ self.k_scale = torch.ones(scale_shape, dtype=torch.float32, device=device)
71
+ self.v_scale = torch.ones(scale_shape, dtype=torch.float32, device=device)
72
+
73
+ def clear(self) -> None:
74
+ """Clear cache."""
75
+ self.k = None
76
+ self.v = None
77
+ self.k_scale = None
78
+ self.v_scale = None
79
+
80
+ def load(self, device: str | torch.device, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
81
+ """Load K/V tensors to compute device.
82
+
83
+ Args:
84
+ device: Target device
85
+ dtype: Target dtype
86
+
87
+ Returns:
88
+ (k, v) tensors on device
89
+ """
90
+ # Move to device if offloaded
91
+ if self.offload_cache:
92
+ self._move_to_device(device)
93
+
94
+ # Dequantize if FP8
95
+ if self.fp8_kv_cache:
96
+ k = self._dequantize(self.k, self.k_scale, dtype)
97
+ v = self._dequantize(self.v, self.v_scale, dtype)
98
+ else:
99
+ if self.k.dtype != dtype:
100
+ self.k = self.k.to(dtype=dtype)
101
+ if self.v.dtype != dtype:
102
+ self.v = self.v.to(dtype=dtype)
103
+ k = self.k
104
+ v = self.v
105
+
106
+ return k, v
107
+
108
+ def store(self, k: torch.Tensor, v: torch.Tensor) -> None:
109
+ """Store K/V tensors.
110
+
111
+ Args:
112
+ k: Key tensor
113
+ v: Value tensor
114
+ """
115
+ if self.fp8_kv_cache:
116
+ self.k, self.k_scale = self._quantize(k)
117
+ self.v, self.v_scale = self._quantize(v)
118
+ else:
119
+ self.k = k
120
+ self.v = v
121
+
122
+ if self.offload_cache:
123
+ self._move_to_device("cpu")
124
+
125
+ def _move_to_device(self, device: str | torch.device) -> None:
126
+ """Move cache tensors to device."""
127
+ self.k = self.k.to(device=device, non_blocking=True)
128
+ self.v = self.v.to(device=device, non_blocking=True)
129
+ if self.k_scale is not None:
130
+ self.k_scale = self.k_scale.to(device=device, non_blocking=True)
131
+ if self.v_scale is not None:
132
+ self.v_scale = self.v_scale.to(device=device, non_blocking=True)
133
+
134
+ def _quantize(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
135
+ """Quantize tensor to FP8."""
136
+ fp8_max = torch.finfo(torch.float8_e4m3fn).max
137
+ scale = tensor.detach().abs().amax(dim=-1, keepdim=True).to(torch.float32)
138
+ scale = torch.clamp(scale / fp8_max, min=1e-12)
139
+ q_tensor = (tensor / scale.to(dtype=tensor.dtype)).to(torch.float8_e4m3fn)
140
+ return q_tensor.contiguous(), scale.contiguous()
141
+
142
+ def _dequantize(self, q_tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
143
+ """Dequantize FP8 tensor."""
144
+ return q_tensor.to(dtype=dtype) * scale.to(device=q_tensor.device, dtype=dtype)
145
+
146
+ def to_dict(self) -> dict:
147
+ """Convert to dict for backward compatibility."""
148
+ return {
149
+ "k": self.k,
150
+ "v": self.v,
151
+ "k_scale": self.k_scale,
152
+ "v_scale": self.v_scale,
153
+ "fp8_kv_cache": self.fp8_kv_cache,
154
+ "offload_cache": self.offload_cache,
155
+ }
156
+
157
+ @classmethod
158
+ def from_dict(cls, d: dict) -> "KVCache":
159
+ """Create from dict."""
160
+ cache = cls(
161
+ fp8_kv_cache=d.get("fp8_kv_cache", False),
162
+ offload_cache=d.get("offload_cache", False),
163
+ )
164
+ cache.k = d.get("k")
165
+ cache.v = d.get("v")
166
+ cache.k_scale = d.get("k_scale")
167
+ cache.v_scale = d.get("v_scale")
168
+ return cache
169
+
170
+
171
class KVCacheManager:
    """Manager for nested KV cache structure (timestep -> layer -> KVCache).

    Caches are held in nested lists indexed as ``[t_idx][layer_idx]``. Lists
    were chosen over dicts for torch.compile friendliness (integer indexing,
    no dynamic dict keys).

    Usage:
        config = KVCacheConfig(fp8_kv_cache=False, offload_cache=True, cache_frames=6)
        manager = KVCacheManager.from_dit_model(
            dit_model,
            config=config,
            tokens_per_frame=520,
            num_timesteps=3,
        )
        # tokens_per_frame was recorded by from_dit_model, so it may be omitted.
        manager.allocate(device="cuda", dtype=torch.bfloat16)

        # Access cache for specific (t_idx, layer_idx)
        cache = manager.get_cache(t_idx=0, layer_idx=5)
        k, v = cache.load(device, dtype)

        # Get all layer caches for a timestep (returns list)
        kv_list = manager.get_timestep_caches(0)
    """

    def __init__(
        self,
        config: KVCacheConfig,
        num_timesteps: int,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        sp_size: int = 1,
    ):
        """Initialize KV cache manager (no tensors are allocated yet).

        Args:
            config: KVCacheConfig instance.
            num_timesteps: Number of denoising timesteps.
            num_layers: Number of transformer layers.
            num_heads: Total number of attention heads.
            head_dim: Dimension per head.
            sp_size: Sequence parallel world size (heads are sharded).

        Raises:
            ValueError: If ``sp_size`` is not positive or does not evenly divide
                ``num_heads``. Floor division would otherwise silently drop
                heads from the per-rank cache shape.
        """
        if sp_size < 1:
            raise ValueError(f"sp_size must be >= 1, got {sp_size}")
        if num_heads % sp_size != 0:
            raise ValueError(f"num_heads ({num_heads}) must be divisible by sp_size ({sp_size})")

        self.config = config
        self.num_timesteps = num_timesteps
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.sp_size = sp_size

        # Heads held by this rank after SP sharding.
        self.local_heads = num_heads // sp_size

        # Cache storage: list[timestep][layer] = KVCache
        self._caches: list[list[KVCache]] = []

        # Set by allocate() (or recorded early by from_dit_model()).
        self._shape: tuple[int, ...] | None = None
        self._tokens_per_frame: int | None = None

    @classmethod
    def from_dit_model(
        cls,
        dit_model: Any,
        config: KVCacheConfig,
        tokens_per_frame: int,
        num_timesteps: int = 3,
        sp_size: int | None = None,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> "KVCacheManager":
        """Create a KV cache manager sized from a DiT model.

        Reads ``len(dit_model.blocks)``, ``dit_model.num_heads`` and
        ``dit_model.dim``; head_dim is ``dim // num_heads``.

        Args:
            dit_model: Model exposing ``blocks``, ``num_heads`` and ``dim``
                (and optionally ``device_mesh``).
            config: KVCacheConfig instance.
            tokens_per_frame: Number of tokens per frame (h * w).
            num_timesteps: Number of denoising timesteps.
            sp_size: Sequence parallel size. None: derived from
                ``dit_model.device_mesh`` if present, else 1.
            device: Allocation device. Caches are allocated immediately only
                when BOTH ``device`` and ``dtype`` are given.
            dtype: Allocation dtype (see ``device``). Otherwise
                ``tokens_per_frame`` is recorded and ``allocate()`` must be
                called later.

        Returns:
            KVCacheManager instance.
        """
        num_layers = len(dit_model.blocks)
        num_heads = dit_model.num_heads
        head_dim = dit_model.dim // dit_model.num_heads

        if sp_size is None:
            device_mesh = getattr(dit_model, "device_mesh", None)
            if device_mesh is not None:
                from telefuser.distributed.ulysses_comm import get_ulysses_world_size

                sp_size = get_ulysses_world_size(device_mesh) or 1
            else:
                sp_size = 1

        manager = cls(
            config=config,
            num_timesteps=num_timesteps,
            num_layers=num_layers,
            num_heads=num_heads,
            head_dim=head_dim,
            sp_size=sp_size,
        )

        if device is not None and dtype is not None:
            manager.allocate(tokens_per_frame, device, dtype)
        else:
            manager._tokens_per_frame = tokens_per_frame

        return manager

    def allocate(
        self,
        tokens_per_frame: int | None = None,
        device: str | torch.device = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        """Allocate a fresh KVCache for every (timestep, layer) pair.

        Any previously allocated caches are replaced.

        Args:
            tokens_per_frame: Number of tokens per frame. None: reuse the value
                recorded by ``from_dit_model`` or a previous ``allocate``.
            device: Target device; storage goes to CPU instead when
                ``config.offload_cache`` is set.
            dtype: Storage dtype; ``torch.float8_e4m3fn`` is used instead when
                ``config.fp8_kv_cache`` is set.

        Raises:
            ValueError: If ``tokens_per_frame`` is None and none was recorded.
        """
        if tokens_per_frame is None:
            tokens_per_frame = self._tokens_per_frame
        if tokens_per_frame is None:
            raise ValueError(
                "tokens_per_frame was not given and no value was recorded; "
                "pass it explicitly or create the manager via from_dit_model()"
            )
        self._tokens_per_frame = tokens_per_frame

        storage_device = "cpu" if self.config.offload_cache else device
        storage_dtype = torch.float8_e4m3fn if self.config.fp8_kv_cache else dtype

        # Shape: [batch, cache_tokens, local_heads, head_dim]
        cache_tokens = tokens_per_frame * self.config.cache_frames
        self._shape = (1, cache_tokens, self.local_heads, self.head_dim)

        caches: list[list[KVCache]] = []
        for _ in range(self.num_timesteps):
            layer_caches: list[KVCache] = []
            for _ in range(self.num_layers):
                cache = KVCache(
                    fp8_kv_cache=self.config.fp8_kv_cache,
                    offload_cache=self.config.offload_cache,
                )
                cache.allocate(self._shape, storage_dtype, storage_device)
                layer_caches.append(cache)
            caches.append(layer_caches)
        self._caches = caches

    def clear(self) -> None:
        """Release all caches and forget the recorded shape/tokens_per_frame."""
        for layer_caches in self._caches:
            for cache in layer_caches:
                cache.clear()
        self._caches = []
        self._shape = None
        self._tokens_per_frame = None

    def get_cache(self, t_idx: int, layer_idx: int) -> KVCache:
        """Get KVCache for specific timestep and layer.

        Args:
            t_idx: Timestep index.
            layer_idx: Layer index.

        Returns:
            KVCache instance.

        Raises:
            IndexError: If either index is past the allocated range.
        """
        if t_idx >= len(self._caches):
            raise IndexError(f"Timestep {t_idx} out of range (max: {len(self._caches) - 1})")
        if layer_idx >= len(self._caches[t_idx]):
            raise IndexError(
                f"Layer {layer_idx} out of range for timestep {t_idx} (max: {len(self._caches[t_idx]) - 1})"
            )
        return self._caches[t_idx][layer_idx]

    def get_timestep_caches(self, t_idx: int) -> list[KVCache]:
        """Get all layer caches for a timestep (the internal list, not a copy).

        Args:
            t_idx: Timestep index.

        Returns:
            List of KVCache for each layer.

        Raises:
            IndexError: If ``t_idx`` is past the allocated range.
        """
        if t_idx >= len(self._caches):
            raise IndexError(f"Timestep {t_idx} out of range (max: {len(self._caches) - 1})")
        return self._caches[t_idx]

    def to_dict(self) -> dict[int, dict[int, dict]]:
        """Convert to nested dict for serialization/debugging.

        Returns:
            Dict: {t_idx: {layer_idx: {k, v, k_scale, v_scale, ...}}}
        """
        return {
            t_idx: {layer_idx: cache.to_dict() for layer_idx, cache in enumerate(layer_caches)}
            for t_idx, layer_caches in enumerate(self._caches)
        }

    @classmethod
    def from_dict(cls, d: dict) -> "KVCacheManager":
        """Create from nested dict (inverse of ``to_dict``).

        Structure, config flags and head layout are inferred from ``d[0][0]``.
        sp_size is assumed to be 1 (so num_heads == local heads). ``cache_frames``
        is not serialized and falls back to 6. ``tokens_per_frame`` is not restored.

        Args:
            d: Nested dict {t_idx: {layer_idx: {k, v, ...}}}.

        Returns:
            KVCacheManager instance.

        Raises:
            ValueError: If the dict is empty or the first entry has no ``k``.
        """
        num_timesteps = len(d)
        num_layers = len(d[0]) if num_timesteps > 0 else 0
        first_entry = d[0][0] if num_timesteps > 0 and num_layers > 0 else {}

        k_tensor = first_entry.get("k")
        if k_tensor is None:
            raise ValueError("Cannot infer shape from empty cache dict")
        shape = k_tensor.shape
        local_heads = shape[2]
        head_dim = shape[3]

        config = KVCacheConfig(
            fp8_kv_cache=first_entry.get("fp8_kv_cache", False),
            offload_cache=first_entry.get("offload_cache", False),
            cache_frames=6,
        )

        manager = cls(
            config=config,
            num_timesteps=num_timesteps,
            num_layers=num_layers,
            num_heads=local_heads,
            head_dim=head_dim,
            sp_size=1,
        )
        manager._caches = [
            [KVCache.from_dict(d[t_idx][layer_idx]) for layer_idx in range(num_layers)]
            for t_idx in range(num_timesteps)
        ]
        manager._shape = shape
        return manager

    @property
    def shape(self) -> tuple[int, ...] | None:
        """Cache tensor shape, or None before allocation."""
        return self._shape

    @property
    def is_allocated(self) -> bool:
        """Whether any caches are currently held."""
        return bool(self._caches)
@@ -0,0 +1,27 @@
1
+ from __future__ import annotations
2
+
3
+ from importlib import import_module
4
+ from typing import Any
5
+
6
# Public API of this package. None of these names is imported eagerly; each
# is resolved on first access by the module-level ``__getattr__``.
__all__ = ["CacheConfig", "CacheResult", "LatentCache"]
7
+
8
+
9
def __getattr__(name: str) -> Any:
    """Resolve public names lazily (PEP 562) so importing the package stays light.

    Args:
        name: Attribute requested on this module.

    Returns:
        The requested class, imported from its defining submodule. For
        ``CacheConfig`` only, ``None`` is returned if its module cannot be
        imported.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
    """
    # Map each lazily exported symbol to the submodule that defines it.
    lazy_targets = {
        "CacheResult": "telefuser.cache_mem.cache_types",
        "LatentCache": "telefuser.cache_mem.latent_cache",
        "CacheConfig": "telefuser.cache_mem.config",
    }
    module_path = lazy_targets.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    if name != "CacheConfig":
        return getattr(import_module(module_path), name)

    # CacheConfig is best-effort: an import failure yields None instead of raising.
    try:
        return getattr(import_module(module_path), name)
    except (ImportError, ModuleNotFoundError):
        return None
24
+
25
+
26
def __dir__() -> list[str]:
    """Return module globals plus the lazily exported names, sorted."""
    visible = set(globals())
    visible.update(__all__)
    return sorted(visible)
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import torch
7
+
8
+
9
@dataclass
class CacheResult:
    """Result of a cache lookup."""

    hit: bool  # lookup outcome flag
    skip_step: int = 0  # step value tied to the hit; presumably steps that can be skipped — TODO confirm
    cache_type: str = "none"  # one of "approximate", "continue", "exact", "none"
    similarity: float = 0.0  # similarity score of the match (0.0 when unset)
    latent_state: Optional[torch.Tensor] = None  # cached latent tensor, if any
    cached_prompt: str = ""  # prompt stored with the matched entry
    session_id: Optional[str] = None  # optional session identifier
20
+
21
+
22
@dataclass
class IndexEntry:
    """A single entry in the cache index."""

    cache_id: str  # identifier of the cached item
    prompt: str  # prompt the entry was stored under
    saved_steps: List[int]  # step indices recorded for this entry — TODO confirm semantics
    cache_type: str = "approximate_cache"  # entry category label
30
+
31
+
32
@dataclass
class VectorSearchResult:
    """One result returned by a vector search."""

    cache_id: str  # id of the matched cache entry
    similarity: float  # score reported for the match
    prompt: str  # prompt stored with the matched entry
    saved_steps: List[int]  # step indices stored for that entry — TODO confirm semantics
    payload: Dict[str, Any]  # extra metadata; schema is not defined here
@@ -0,0 +1,83 @@
1
+ from dataclasses import dataclass, field
2
+ from enum import Enum
3
+ from typing import List, Optional
4
+
5
+
6
class CacheMode(Enum):
    """Selects which cache operations are enabled."""

    READ_WRITE = "read_write"  # read from and write to the cache (default)
    READ_ONLY = "read_only"  # only read from the cache
    WRITE_ONLY = "write_only"  # only write to the cache
10
+
11
+
12
@dataclass
class CacheConfig:
    """Cache configuration shared across stages/pipelines.

    Fields are grouped by concern: basic cache and logging, KV store, vector
    store, lookup strategy, text/video embedding models, vector search and
    rerank, and asynchronous saving.
    """

    # --- Basic cache ---
    enable_latent_cache: bool = False
    cache_mode: CacheMode = CacheMode.READ_WRITE  # read_write | read_only | write_only
    latent_cache_dir: str = "./latent_cache"
    max_cache_size_gb: int = 10
    cache_log_enabled: bool = True
    cache_log_dir: Optional[str] = None  # intended default: {latent_cache_dir}/logs (resolved elsewhere)
    cache_log_level: str = "DEBUG"
    cache_log_rotation: str = "100 MB"
    cache_log_retention: str = "7 days"

    # --- KV store (key-value cache for latents etc.) ---
    kv_store_type: str = "local_file"  # "local_file" | "fluxon"
    fluxon_config_path: Optional[str] = ""

    # --- Vector store (embedding retrieval) ---
    vector_store_type: str = "faiss"  # "qdrant" | "faiss"
    qdrant_url: Optional[str] = ""
    qdrant_api_key: Optional[str] = None
    faiss_index_dir: Optional[str] = None
    vector_dim: int = 2048  # needed to initialise FAISS; should match the embedding model's output size
    cache_strategy_type: str = "video_approximate"  # key into STRATEGY_REGISTRY

    # --- Similarity & lookup strategy ---
    key_steps: List[int] = field(default_factory=lambda: list(range(6)))  # steps eligible for cache reuse
    lookup_mode: str = "video"  # lookup mode, e.g. "video"

    # --- Text (prompt) embedding model ---
    text_embedding_model_path: str = ""
    text_embedding_instruction: str = "Represent the user's input"
    text_embedding_device_id: Optional[int] = None
    text_embedding_torch_dtype: Optional[str] = None
    text_embedding_attn_impl: Optional[str] = None

    # --- Video embedding model ---
    video_embedding_enabled: bool = True
    video_embedding_model_path: str = "Qwen/Qwen3-VL-Embedding-2B"
    video_embedding_instruction: str = "Represent the user's input"
    video_embedding_fps: float = 1.0
    video_embedding_max_frames: int = 16
    video_embedding_max_length: int = 8192
    video_embedding_min_pixels: int = 4096
    video_embedding_max_pixels: int = 1843200
    video_embedding_total_pixels: int = 7864320
    video_embedding_device_id: Optional[int] = None
    video_embedding_torch_dtype: Optional[str] = None
    video_embedding_attn_impl: Optional[str] = None

    # --- Video vector search & rerank ---
    video_similarity_threshold: Optional[float] = 0.10
    video_vector_collection: str = "video"
    rerank_enabled: bool = False
    rerank_model_path: str = "Qwen/Qwen3-VL-Reranker-2B"
    rerank_top_k: int = 5
    rerank_batch_size: int = 2
    rerank_device_id: Optional[int] = None
    rerank_torch_dtype: Optional[str] = None
    rerank_score_threshold: float = 0.90

    # --- Asynchronous save (write-behind) ---
    save_async_enabled: bool = True
    save_queue_size: int = 2
    save_on_full: str = "drop"  # drop | sync | downgrade
    # NOTE(review): the default warn threshold (8) exceeds the default
    # save_queue_size (2); confirm the threshold is reachable.
    save_queue_warn_threshold: int = 8
    vector_wait_warn_s: float = 2.0
    vector_wait_poll_s: float = 0.05
    vector_wait_timeout_s: float = 120.0
    flush_on_shutdown: bool = True