xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (194)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +45 -10
  5. xinference/core/image_interface.py +9 -0
  6. xinference/core/model.py +8 -5
  7. xinference/core/scheduler.py +1 -2
  8. xinference/core/worker.py +49 -42
  9. xinference/deploy/cmdline.py +2 -2
  10. xinference/deploy/test/test_cmdline.py +7 -7
  11. xinference/model/audio/chattts.py +24 -9
  12. xinference/model/audio/core.py +8 -2
  13. xinference/model/audio/fish_speech.py +228 -0
  14. xinference/model/audio/model_spec.json +8 -0
  15. xinference/model/embedding/core.py +23 -1
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +49 -1
  19. xinference/model/llm/__init__.py +26 -27
  20. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  21. xinference/model/llm/llm_family.json +606 -1266
  22. xinference/model/llm/llm_family.py +16 -139
  23. xinference/model/llm/llm_family_modelscope.json +276 -313
  24. xinference/model/llm/lmdeploy/__init__.py +0 -0
  25. xinference/model/llm/lmdeploy/core.py +557 -0
  26. xinference/model/llm/memory.py +9 -9
  27. xinference/model/llm/sglang/core.py +2 -2
  28. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  29. xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
  30. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  31. xinference/model/llm/{pytorch → transformers}/core.py +3 -10
  32. xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
  33. xinference/model/llm/transformers/intern_vl.py +540 -0
  34. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  35. xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
  36. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
  37. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  38. xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
  39. xinference/model/llm/utils.py +85 -70
  40. xinference/model/llm/vllm/core.py +110 -11
  41. xinference/model/utils.py +1 -95
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/internvl/__init__.py +0 -0
  137. xinference/thirdparty/internvl/conversation.py +393 -0
  138. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  139. xinference/web/ui/build/asset-manifest.json +3 -3
  140. xinference/web/ui/build/index.html +1 -1
  141. xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
  142. xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  144. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  145. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  146. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  147. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  148. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  149. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  150. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  151. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  153. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  154. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  155. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  156. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
  157. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
  158. xinference/locale/utils.py +0 -39
  159. xinference/locale/zh_CN.json +0 -26
  160. xinference/model/llm/ggml/tools/__init__.py +0 -15
  161. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  162. xinference/model/llm/ggml/tools/gguf.py +0 -884
  163. xinference/model/llm/pytorch/__init__.py +0 -13
  164. xinference/model/llm/pytorch/baichuan.py +0 -81
  165. xinference/model/llm/pytorch/falcon.py +0 -138
  166. xinference/model/llm/pytorch/intern_vl.py +0 -352
  167. xinference/model/llm/pytorch/vicuna.py +0 -69
  168. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  169. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  170. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  171. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  172. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  173. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  174. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  175. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  176. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  177. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  178. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  179. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  180. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  181. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  182. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  183. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  184. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  185. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  186. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  187. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  188. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  189. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  190. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  191. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py
@@ -0,0 +1,779 @@
+import json
+import math
+from collections import OrderedDict
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from loguru import logger
+from torch import Tensor
+from torch.nn import functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
+from torch.utils.checkpoint import checkpoint
+from transformers import AutoTokenizer
+
+from fish_speech.conversation import SEMANTIC_TOKEN
+from fish_speech.utils import RankedLogger
+
+from .lora import LoraConfig, setup_lora
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
+@dataclass
+class BaseModelArgs:
+    model_type: str = "base"
+
+    vocab_size: int = 32000
+    n_layer: int = 32
+    n_head: int = 32
+    dim: int = 4096
+    intermediate_size: int = None
+    n_local_heads: int = -1
+    head_dim: int = 64
+    rope_base: float = 10000
+    norm_eps: float = 1e-5
+    max_seq_len: int = 2048
+    dropout: float = 0.0
+    tie_word_embeddings: bool = True
+    attention_qkv_bias: bool = False
+
+    # Codebook configs
+    codebook_size: int = 160
+    num_codebooks: int = 4
+
+    # Gradient checkpointing
+    use_gradient_checkpointing: bool = True
+
+    # Initialize the model
+    initializer_range: float = 0.02
+
+    def __post_init__(self):
+        if self.n_local_heads == -1:
+            self.n_local_heads = self.n_head
+        if self.intermediate_size is None:
+            hidden_dim = 4 * self.dim
+            n_hidden = int(2 * hidden_dim / 3)
+            self.intermediate_size = find_multiple(n_hidden, 256)
+        self.head_dim = self.dim // self.n_head
+
+    @staticmethod
+    def from_pretrained(path: str):
+        path = Path(path)
+
+        if path.is_dir():
+            path = path / "config.json"
+
+        with open(path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+
+        match data["model_type"]:
+            case "naive":
+                cls = NaiveModelArgs
+            case "dual_ar":
+                cls = DualARModelArgs
+            case _:
+                raise ValueError(f"Unknown model type: {data['model_type']}")
+
+        return cls(**data)
+
+    def save(self, path: str):
+        with open(path, "w") as f:
+            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
+
+
+@dataclass
+class NaiveModelArgs(BaseModelArgs):
+    model_type: str = "naive"
+
+
+@dataclass
+class DualARModelArgs(BaseModelArgs):
+    model_type: str = "dual_ar"
+    n_fast_layer: int = 4
+
+
+class KVCache(nn.Module):
+    def __init__(
+        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
+    ):
+        super().__init__()
+        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
+        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+    def update(self, input_pos, k_val, v_val):
+        # input_pos: [S], k_val: [B, H, S, D]
+        assert input_pos.shape[0] == k_val.shape[2]
+
+        k_out = self.k_cache
+        v_out = self.v_cache
+        k_out[:, :, input_pos] = k_val
+        v_out[:, :, input_pos] = v_val
+
+        return k_out, v_out
+
+
+@dataclass
+class TransformerForwardResult:
+    token_logits: Tensor
+    codebook_logits: Tensor
+
+
+@dataclass
+class BaseTransformerForwardResult:
+    logits: Tensor
+    hidden_states: Tensor
+
+
+class BaseTransformer(nn.Module):
+    def __init__(
+        self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.tokenizer = tokenizer
+
+        self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
+
+        # Slow transformer
+        self.embeddings = nn.Embedding(
+            config.vocab_size,
+            config.dim,
+        )
+        self.codebook_embeddings = nn.Embedding(
+            config.codebook_size * config.num_codebooks,
+            config.dim,
+        )
+        self.layers = nn.ModuleList(
+            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
+        )
+        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+
+        if self.config.tie_word_embeddings is False:
+            self.output = nn.Linear(
+                config.dim,
+                config.vocab_size,
+                bias=False,
+            )
+
+        self.register_buffer(
+            "freqs_cis",
+            precompute_freqs_cis(
+                config.max_seq_len,
+                config.dim // config.n_head,
+                config.rope_base,
+            ),
+            persistent=False,
+        )
+        self.register_buffer(
+            "causal_mask",
+            torch.tril(
+                torch.ones(
+                    config.max_seq_len,
+                    config.max_seq_len,
+                    dtype=torch.bool,
+                )
+            ),
+            persistent=False,
+        )
+
+        # For kv cache
+        self.max_batch_size = -1
+        self.max_seq_len = -1
+
+        if init_weights:
+            self.apply(self._init_weights)
+
+    def setup_caches(
+        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+    ):
+        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
+            return
+
+        head_dim = self.config.dim // self.config.n_head
+        max_seq_len = find_multiple(max_seq_len, 8)
+        self.max_seq_len = max_seq_len
+        self.max_batch_size = max_batch_size
+
+        for b in self.layers:
+            b.attention.kv_cache = KVCache(
+                max_batch_size,
+                max_seq_len,
+                self.config.n_local_heads,
+                head_dim,
+                dtype=dtype,
+            )
+
+    def embed(self, x: Tensor) -> Tensor:
+        vocab_embeds = [self.embeddings(x[:, 0])]
+        for i in range(self.config.num_codebooks):
+            emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
+            emb[x[:, 0] != self.semantic_token_id] = 0
+            vocab_embeds.append(emb)
+
+        x = torch.stack(vocab_embeds, dim=3)
+        x = x.sum(dim=3)
+
+        return x
+
+    def forward(
+        self,
+        inp: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+    ) -> BaseTransformerForwardResult:
+        seq_len = inp.size(2)
+
+        # Here we want to merge the embeddings of the codebooks
+        x = self.embed(inp)
+
+        freqs_cis = self.freqs_cis[:seq_len]
+
+        # Not that the causal mask here follows the definition of scaled_dot_product_attention
+        # That is, FALSE means masked out
+        # To maintain consistency, key_padding_mask use TRUE to mask out
+        mask = None
+        if key_padding_mask is not None:
+            mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
+            mask = mask & key_padding_mask[:, None, None, :].logical_not()
+
+        for layer in self.layers:
+            if self.config.use_gradient_checkpointing and self.training:
+                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
+            else:
+                x = layer(x, freqs_cis, mask)
+
+        # We got slow_out here
+        slow_out = self.norm(x)
+
+        if self.config.tie_word_embeddings:
+            token_logits = F.linear(slow_out, self.embeddings.weight)
+        else:
+            token_logits = self.output(slow_out)
+
+        return BaseTransformerForwardResult(
+            logits=token_logits,
+            hidden_states=x,
+        )
+
+    def forward_generate(
+        self,
+        x: Tensor,
+        input_pos: Optional[Tensor] = None,
+        return_all: bool = False,
+    ) -> BaseTransformerForwardResult:
+        # This is used for generation, optimized for torch compile
+        assert (
+            self.max_seq_len != -1 and self.max_batch_size != -1
+        ), "Please call setup_caches before forward_generate"
+
+        x = self.embed(x)
+
+        mask = self.causal_mask[
+            None, None, input_pos, : self.max_seq_len
+        ]  # (B, N, Q, K)
+        freqs_cis = self.freqs_cis[input_pos]
+
+        for layer in self.layers:
+            x = layer(x, freqs_cis, mask, input_pos=input_pos)
+
+        # If prefill, we only calculate the logits of last token
+        if x.size(1) > 1 and not return_all:
+            x = x[:, -1:]
+
+        # We got slow_out here
+        slow_out = self.norm(x)
+
+        if self.config.tie_word_embeddings:
+            token_logits = F.linear(slow_out, self.embeddings.weight)
+        else:
+            token_logits = self.output(slow_out)
+
+        return BaseTransformerForwardResult(
+            logits=token_logits,
+            hidden_states=x,
+        )
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @staticmethod
+    def from_pretrained(
+        path: str,
+        load_weights: bool = False,
+        max_length: int | None = None,
+        lora_config: LoraConfig | None = None,
+        rope_base: int | None = None,
+    ) -> "BaseTransformer":
+        config = BaseModelArgs.from_pretrained(str(path))
+        if max_length is not None:
+            config.max_seq_len = max_length
+            log.info(f"Override max_seq_len to {max_length}")
+
+        if rope_base is not None:
+            config.rope_base = rope_base
+            log.info(f"Override rope_base to {rope_base}")
+
+        match config.model_type:
+            case "naive":
+                model_cls = NaiveTransformer
+            case "dual_ar":
+                model_cls = DualARTransformer
+            case _:
+                raise ValueError(f"Unknown model type: {config.model_type}")
+
+        tokenizer = AutoTokenizer.from_pretrained(str(path))
+        log.info(f"Loading model from {path}, config: {config}")
+        model = model_cls(config, tokenizer=tokenizer)
+
+        if lora_config is not None:
+            setup_lora(model, lora_config)
+            log.info(f"LoRA setup: {lora_config}")
+
+        if load_weights is False:
+            log.info("Randomly initialized model")
+        else:
+
+            if "int8" in str(Path(path)):
+                logger.info("Using int8 weight-only quantization!")
+                from ...tools.llama.quantize import WeightOnlyInt8QuantHandler
+
+                simple_quantizer = WeightOnlyInt8QuantHandler(model)
+                model = simple_quantizer.convert_for_runtime()
+
+            if "int4" in str(Path(path)):
+                logger.info("Using int4 quantization!")
+                path_comps = path.name.split("-")
+                assert path_comps[-2].startswith("g")
+                groupsize = int(path_comps[-2][1:])
+                from ...tools.llama.quantize import WeightOnlyInt4QuantHandler
+
+                simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+                model = simple_quantizer.convert_for_runtime()
+
+            weights = torch.load(
+                Path(path) / "model.pth", map_location="cpu", mmap=True
+            )
+
+            if "state_dict" in weights:
+                logger.warning(
+                    "Using a TextToSemantic LightningModule checkpoint, "
+                    "please make sure it is a full model, not a LoRA model."
+                )
+                weights = weights["state_dict"]
+
+            if next(iter(weights.keys())).startswith("model."):
+                logger.info(
+                    f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
+                )
+                new_weights = OrderedDict()
+                for k, v in weights.items():
+                    new_weights[k.replace("model.", "")] = v
+                weights = new_weights
+
+            # Verify the name and shape of parameters since strict=False in load_state_dict.
+            for k, v in model.named_parameters():
+                if k not in weights:
+                    logger.warning(f"No weight for {k}")
+                elif v.shape != weights[k].shape:
+                    logger.warning(
+                        f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
+                    )
+
+            err = model.load_state_dict(weights, strict=False, assign=True)
+            log.info(f"Loaded weights with error: {err}")
+
+        return model
+
+    def save_pretrained(self, path: str, drop_lora: bool = False):
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        self.config.save(path / "config.json")
+        state_dict = self.state_dict()
+
+        if drop_lora:
+            for key in list(state_dict.keys()):
+                if "lora" not in key:
+                    continue
+
+                state_dict.pop(key)
+                log.info(f"Drop LoRA parameter: {key}")
+
+        torch.save(state_dict, path / "model.pth")
+        self.tokenizer.save_pretrained(path)
+
+
+class NaiveTransformer(BaseTransformer):
+    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+        super().__init__(config, init_weights=False, tokenizer=tokenizer)
+
+        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.codebook_output = nn.Linear(
+            config.dim,
+            config.codebook_size * config.num_codebooks,
+            bias=False,
+        )
+
+        self.apply(self._init_weights)
+
+    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
+        token_logits = result.logits
+        x = result.hidden_states
+
+        # Codebook
+        codebook_logits = self.codebook_output(self.codebook_norm(x))
+        codebook_logits = rearrange(
+            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
+        )
+
+        return TransformerForwardResult(
+            token_logits=token_logits,
+            codebook_logits=codebook_logits,
+        )
+
+    def forward(
+        self,
+        inp: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+    ) -> TransformerForwardResult:
+        result = super().forward(
+            inp=inp,
+            key_padding_mask=key_padding_mask,
+        )
+        return self.decode(result)
+
+    def forward_generate(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> TransformerForwardResult:
+        result = super().forward_generate(x, input_pos)
+        return self.decode(result)
+
+
+class DualARTransformer(BaseTransformer):
+    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+        super().__init__(config, init_weights=False, tokenizer=tokenizer)
+
+        # Fast transformer
+        self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
+
+        # The equivalent bs is so large that sdpa doesn't work
+        self.fast_layers = nn.ModuleList(
+            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
+        )
+        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.fast_output = nn.Linear(
+            config.dim,
+            config.codebook_size,
+            bias=False,
+        )
+
+        self.apply(self._init_weights)
+
+    def setup_caches(
+        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
+    ):
+        super().setup_caches(max_batch_size, max_seq_len, dtype)
+
+        head_dim = self.config.dim // self.config.n_head
+
+        # Fast transformer
+        # The max seq len here is the number of codebooks
+        for b in self.fast_layers:
+            b.attention.kv_cache = KVCache(
+                max_batch_size,
+                self.config.num_codebooks,
+                self.config.n_local_heads,
+                head_dim,
+                dtype=dtype,
+            )
+
+    def forward(
+        self,
+        inp: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+    ) -> TransformerForwardResult:
+        parent_result = super().forward(inp, key_padding_mask)
+        token_logits = parent_result.logits
+        x = parent_result.hidden_states
+
+        # Fast transformer
+        fast_seq_len = self.config.num_codebooks
+        fast_mask = self.causal_mask[
+            None, None, :fast_seq_len, :fast_seq_len
+        ]  # (B, N, Q, K)
+        fast_freqs_cis = self.freqs_cis[:fast_seq_len]
+
+        # Drop the last token and rotate left
+        codebooks = inp[:, 1:-1, 1:]
+        codebooks = F.pad(codebooks, (0, 1), value=0)
+        codebook_embeddings = self.fast_embeddings(codebooks)
+        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
+        b, s = x.size(0), x.size(2)
+        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
+
+        # Remove padded part
+        codebooks = rearrange(codebooks, "b n s -> (b s) n")
+        codebook_mask = (codebooks == 0).all(dim=-1)
+
+        if torch.all(codebook_mask):
+            # If all codebooks are padded, we keep first 8 to make sure the model runs
+            codebook_mask[:8] = False
+
+        x_bs, x_len = x.size(0), x.size(1)
+        x = x[~codebook_mask]
+
+        for layer in self.fast_layers:
+            if self.config.use_gradient_checkpointing and self.training:
+                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
+            else:
+                x = layer(x, fast_freqs_cis, fast_mask)
+
+        # unflatten the batch and num_codebooks
+        fast_out = self.fast_norm(x)
+        codebook_logits = self.fast_output(fast_out)
+
+        # Re-pad the codebook_logits
+        buffer = torch.zeros(
+            x_bs,
+            x_len,
+            codebook_logits.size(-1),
+            device=codebook_logits.device,
+            dtype=codebook_logits.dtype,
+        )
+        buffer[~codebook_mask] = codebook_logits
+        codebook_logits = buffer
+
+        assert codebook_logits.shape[1] == self.config.num_codebooks
+        codebook_logits = rearrange(
+            codebook_logits,
+            "(b s) n d -> b s n d",
+            b=b,
+            s=s,
+            n=self.config.num_codebooks,
+        )
+
+        return TransformerForwardResult(
+            token_logits=token_logits,
+            codebook_logits=codebook_logits,
+        )
+
+    def forward_generate_fast(
+        self, x: Tensor, input_pos: Optional[Tensor] = None
+    ) -> Tensor:
+        # Fast transformer
+        x = x.view(1, 1, -1)
+
+        fast_mask = self.causal_mask[
+            None, None, input_pos, : self.config.num_codebooks
+        ]  # (B, N, Q, K)
+        fast_freqs_cis = self.freqs_cis[input_pos]
+
+        for layer in self.fast_layers:
+            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
+
+        # unflatten the batch and num_codebooks
+        fast_out = self.fast_norm(x)  # only take the last token
+        codebook_logits = self.fast_output(fast_out)
+
+        return codebook_logits
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
+        super().__init__()
+        self.attention = Attention(config, use_sdpa=use_sdpa)
+        self.feed_forward = FeedForward(config)
+        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+        self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+    def forward(
+        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
+    ) -> Tensor:
+        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+
+class Attention(nn.Module):
+    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
+        super().__init__()
+        assert config.dim % config.n_head == 0
+
+        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+        # key, query, value projections for all heads, but in a batch
+        self.wqkv = nn.Linear(
+            config.dim, total_head_dim, bias=config.attention_qkv_bias
+        )
+        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.kv_cache = None
+
+        self.dropout = config.dropout
+        self.n_head = config.n_head
+        self.head_dim = config.head_dim
+        self.n_local_heads = config.n_local_heads
+        self.dim = config.dim
+        self.use_sdpa = use_sdpa
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(self, state_dict, prefix, *args):
+        if prefix + "wq.weight" in state_dict:
+            wq = state_dict.pop(prefix + "wq.weight")
+            wk = state_dict.pop(prefix + "wk.weight")
+            wv = state_dict.pop(prefix + "wv.weight")
+            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+    def forward(
+        self,
+        x: Tensor,
+        freqs_cis: Tensor,
+        mask: Tensor,
+        input_pos: Optional[Tensor] = None,
+    ) -> Tensor:
+        bsz, seqlen, _ = x.shape
+
+        kv_size = self.n_local_heads * self.head_dim
+        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, freqs_cis)
+
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+        if self.kv_cache is not None:
+            k, v = self.kv_cache.update(input_pos, k, v)
+
+        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+
+        if self.use_sdpa:
+            if mask is None:
+                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+                    y = F.scaled_dot_product_attention(
+                        q,
+                        k,
+                        v,
+                        dropout_p=self.dropout if self.training else 0.0,
+                        is_causal=True,
+                        # No third party attn_mask here to use flash_attention
+                    )
+            else:
+                y = F.scaled_dot_product_attention(
+                    q,
+                    k,
+                    v,
+                    attn_mask=mask,
+                    dropout_p=self.dropout if self.training else 0.0,
+                )
+        else:
+            y = self.eq_scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask,
+                dropout_p=self.dropout if self.training else 0.0,
+            )
+
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+        return self.wo(y)
+
+    def eq_scaled_dot_product_attention(
+        self,
+        query,
+        key,
+        value,
+        attn_mask=None,
+        dropout_p=0.0,
+    ) -> torch.Tensor:
+        # This is a standard scaled dot product attention
+        # It's low efficient, but it doesn't raise cuda error
+
+        L, S = query.size(-2), key.size(-2)
+        scale_factor = 1 / math.sqrt(query.size(-1))
+        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attn_mask
+
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_weight += attn_bias
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+        return attn_weight @ value
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config: BaseModelArgs) -> None:
+        super().__init__()
+        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x: Tensor) -> Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+    freqs = 1.0 / (
+        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+    )
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+    return cache.to(dtype=torch.bfloat16)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+        ],
+        -1,
+    )
+
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)