xinference 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as possibly problematic.
- xinference/_compat.py +2 -0
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +72 -66
- xinference/core/model.py +78 -25
- xinference/core/supervisor.py +81 -10
- xinference/core/utils.py +12 -8
- xinference/core/worker.py +32 -0
- xinference/model/audio/core.py +5 -0
- xinference/model/audio/cosyvoice.py +25 -3
- xinference/model/audio/f5tts.py +15 -10
- xinference/model/audio/f5tts_mlx.py +260 -0
- xinference/model/audio/fish_speech.py +35 -111
- xinference/model/audio/model_spec.json +19 -3
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/audio/utils.py +32 -0
- xinference/model/image/core.py +69 -1
- xinference/model/image/model_spec.json +145 -4
- xinference/model/image/model_spec_modelscope.json +150 -4
- xinference/model/image/stable_diffusion/core.py +45 -13
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +143 -0
- xinference/model/llm/llm_family.py +15 -36
- xinference/model/llm/llm_family_modelscope.json +148 -0
- xinference/model/llm/mlx/core.py +37 -32
- xinference/model/llm/transformers/cogagent.py +272 -0
- xinference/model/llm/transformers/core.py +2 -0
- xinference/model/llm/transformers/qwen2_vl.py +12 -1
- xinference/model/llm/utils.py +28 -3
- xinference/model/llm/vllm/core.py +48 -9
- xinference/model/llm/vllm/xavier/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/allocator.py +74 -0
- xinference/model/llm/vllm/xavier/block.py +112 -0
- xinference/model/llm/vllm/xavier/block_manager.py +71 -0
- xinference/model/llm/vllm/xavier/block_tracker.py +116 -0
- xinference/model/llm/vllm/xavier/engine.py +247 -0
- xinference/model/llm/vllm/xavier/executor.py +132 -0
- xinference/model/llm/vllm/xavier/scheduler.py +422 -0
- xinference/model/llm/vllm/xavier/test/__init__.py +13 -0
- xinference/model/llm/vllm/xavier/test/test_xavier.py +122 -0
- xinference/model/llm/vllm/xavier/transfer.py +298 -0
- xinference/model/video/diffusers.py +14 -0
- xinference/model/video/model_spec.json +15 -0
- xinference/model/video/model_spec_modelscope.json +16 -0
- xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
- xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
- xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
- xinference/thirdparty/cosyvoice/bin/train.py +42 -8
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
- xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
- xinference/thirdparty/cosyvoice/cli/model.py +330 -80
- xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
- xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
- xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
- xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
- xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
- xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
- xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
- xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
- xinference/thirdparty/cosyvoice/utils/common.py +28 -1
- xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
- xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
- xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
- xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
- xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
- xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
- xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
- xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
- xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
- xinference/thirdparty/fish_speech/tools/schema.py +11 -28
- xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
- xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
- xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
- xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
- xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
- xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
- xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
- xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
- xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
- xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
- xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
- xinference/thirdparty/matcha/utils/utils.py +2 -2
- xinference/types.py +13 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.51a587ff.css +2 -0
- xinference/web/ui/build/static/css/main.51a587ff.css.map +1 -0
- xinference/web/ui/build/static/js/main.1eb206d1.js +3 -0
- xinference/web/ui/build/static/js/main.1eb206d1.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/03c4052f1b91f6ba0c5389bdcf49c43319b4076c08e4b8585dab312538ae290a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/1786b83003b8e9605a0f5f855a185d4d16e38fc893dfb326a2a9cca206b4240a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/17cbc181dd674b9150b80c73ed6a82656de0082d857f6e5f66d9716129ac0b38.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/185ceb8872d562e032b47e79df6a45670e06345b8ed70aad1a131e0476783c5c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/26b8c9f34b0bed789b3a833767672e39302d1e0c09b4276f4d58d1df7b6bd93b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2b484da66c724d0d56a40849c109327408796a668b1381511b6e9e03baa48658.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2cbbbce9b84df73330d4c42b82436ed881b3847628f2fbc346aa62e2859fd88c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/2ec9b14431ed33ce6901bf9f27007be4e6e472709c99d6e22b50ce528e4b78ee.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3b966db018f96be4a055d6ca205f0990d4d0b370e2980c17d8bca2c9a021819c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3eefb411b24c2b3ce053570ef50daccf154022f0e168be5ed0fec21394baf9f4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/522b229e3cac219123f0d69673f5570e191c2d2a505dc65b312d336eae2279c0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/52e45f17ba300580ea3fcc9f9228ccba194bb092b76f25e9255af311f8b05aab.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5a0bc4631f936459afc1a3b1d3ec2420118b1f00e11f60ccac3e08088f3f27a8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/611fa2c6c53b66039991d06dfb0473b5ab37fc63b4564e0f6e1718523768a045.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/6329bc76c406fe5eb305412383fbde5950f847bb5e43261f73f37622c365acb4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/63c8e07687ea53a4f8a910ee5e42e0eb26cd1acbfbe820f3e3248a786ee51401.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69b2d5001684174ec9da57e07914eed3eac4960018bceb6cbfa801d861301d7c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/710c1acda69e561e30a933b98c6a56d50197868b15c21e2aad55ab6d46649eb6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/720deca1fce5a1dc5056048fa8258fd138a82ea855f350b6613f104a73fb761f.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/76a23b92d26a499c57e61eea2b895fbc9771bd0849a72e66f8e633192017978b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/858063f23b34dfe600254eb5afd85518b0002ec4b30b7386616c45600826e3b2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/920b82c1c89124cf217109eeedbfcd3aae3b917be50c9dfb6bbb4ce26bdfd2e7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/94d8b7aeb0076f2ce07db598cea0e87b13bc8d5614eb530b8d6e696c2daf6f88.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9e917fe7022d01b2ccbe5cc0ce73d70bb72bee584ff293bad71bdff6695dee28.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/9f28fdb8399f1d0474f0aca86f1658dc94f5bf0c90f6146352de150692de8862.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/a0dfafa06b2bb7cba8cad41c482503f61944f759f4318139362602ef5cc47ccb.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/afb8084f539534cd594755ea2205ecd5bd1f62dddcfdf75a2eace59a28131278.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b57b1438b77294c1f3f6cfce12ac487d8106c6f016975ba0aec94d98997e2e1e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b9917b0bf8e4d55ccbac1c334aa04d6ff3c5b6ed9e5d38b9ea2c687fa7d3f5a9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bbcc94b0149963d1d6f267ee1f4f03d3925b758392ce2f516c3fe8af0e0169fc.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bdee44abeadc4abc17d41c52eb49c6e19a4b1a267b6e16876ce91bdeeebfc52d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/beb112b70f4a56db95920a9e20efb6c97c37b68450716730217a9ee1a9ae92be.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/c88db97be0cdf440193b3995996e83510a04cb00048135485fc0e26d197e80b5.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d49e5314d34310a62d01a03067ce1bec5da00abce84c5196aa9c6842fa79a430.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d7664d18c4ddbad9c3a6a31b91f7c00fb0dde804608674a9860ee50f33e54708.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/d9072c318b819b7c90a0f7e9cc0b6413b4dbeb8e9859898e53d75ea882fcde99.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/db16a983bc08a05f0439cc61ca0840e49e1d8400eef678909f16c032a418a3d6.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/dc249829767b8abcbc3677e0b07b6d3ecbfdfe6d08cfe23a665eb33373a9aa9d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e242c583c2dbc2784f0fcf513523975f7d5df447e106c1c17e49e8578a6fc3ed.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/eac5f1296513e69e4b96f750ddccd4d0264e2bae4e4c449144e83274a48698d9.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/ed57202cb79649bb716400436590245547df241988fc7c8e1d85d132299542d2.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f125bf72e773a14cdaebd0c343e80adb909d12e317ee5c00cd4a57442fbe2c62.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f91af913d7f91c410719ab13136aaed3aaf0f8dda06652f25c42cb5231587398.json +1 -0
- xinference/web/ui/node_modules/.package-lock.json +67 -3
- xinference/web/ui/node_modules/@babel/runtime/package.json +592 -538
- xinference/web/ui/node_modules/html-parse-stringify/package.json +50 -0
- xinference/web/ui/node_modules/i18next/dist/esm/package.json +1 -0
- xinference/web/ui/node_modules/i18next/package.json +129 -0
- xinference/web/ui/node_modules/react-i18next/.eslintrc.json +74 -0
- xinference/web/ui/node_modules/react-i18next/dist/es/package.json +1 -0
- xinference/web/ui/node_modules/react-i18next/package.json +162 -0
- xinference/web/ui/node_modules/void-elements/package.json +34 -0
- xinference/web/ui/package-lock.json +69 -3
- xinference/web/ui/package.json +2 -0
- xinference/web/ui/src/locales/en.json +186 -0
- xinference/web/ui/src/locales/zh.json +186 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/METADATA +19 -11
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/RECORD +178 -111
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +0 -943
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
- xinference/thirdparty/fish_speech/tools/webui.py +0 -548
- xinference/web/ui/build/static/css/main.5061c4c3.css +0 -2
- xinference/web/ui/build/static/css/main.5061c4c3.css.map +0 -1
- xinference/web/ui/build/static/js/main.4eb4ee80.js +0 -3
- xinference/web/ui/build/static/js/main.4eb4ee80.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1130403f9e46f5738a23b45ac59b57de8f360c908c713e2c0670c2cce9bd367a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1f269fb2a368363c1cb2237825f1dba093b6bdd8c44cc05954fd19ec2c1fff03.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/8d33354bd2100c8602afc3341f131a88cc36aaeecd5a4b365ed038514708e350.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/9375a35b05d56989b2755bf72161fa707c92f28569d33765a75f91a568fda6e9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a158a9ffa0c9b169aee53dd4a0c44501a596755b4e4f6ede7746d65a72e2a71f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/c7bf40bab396765f67d0fed627ed3665890608b2d0edaa3e8cb7cfc96310db45.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d6c643278a0b28320e6f33a60f5fb64c053997cbdc39a60e53ccc574688ade9e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +0 -1
- /xinference/web/ui/build/static/js/{main.4eb4ee80.js.LICENSE.txt → main.1eb206d1.js.LICENSE.txt} +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/LICENSE +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/WHEEL +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.1.0.dist-info → xinference-1.2.0.dist-info}/top_level.txt +0 -0

xinference/thirdparty/fish_speech/fish_speech/conversation.py
@@ -2,41 +2,10 @@ from dataclasses import dataclass, field
 from typing import Literal

 import torch
-from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
-
-IM_START_TOKEN = "<|im_start|>"
-IM_END_TOKEN = "<|im_end|>"
-SEMANTIC_TOKEN = "<|semantic|>"
-MEL_TOKEN = "<|mel|>"
-PHONEME_START_TOKEN = "<|phoneme_start|>"
-PHONEME_END_TOKEN = "<|phoneme_end|>"
-ALL_SPECIAL_TOKENS = [
-    IM_START_TOKEN,
-    IM_END_TOKEN,
-    SEMANTIC_TOKEN,
-    MEL_TOKEN,
-    PHONEME_START_TOKEN,
-    PHONEME_END_TOKEN,
-]
-
-CODEBOOK_PAD_TOKEN_ID = 0
-
-
-class FishTokenizerConfig(PretrainedConfig):
-    share_codebook_embeddings: bool = True
-    codebook_size: int = 1024
-    num_codebooks: int = 8

+from .tokenizer import MODALITY_TOKENS, FishTokenizer

-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True)
-        self.codebook_size = kwargs.pop("codebook_size", 1024)
-        self.num_codebooks = kwargs.pop("num_codebooks", 8)
-
-
-AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)
+CODEBOOK_PAD_TOKEN_ID = 0


 @dataclass(kw_only=True)
@@ -54,77 +23,72 @@ class TextPart(BasePart):
     text: str


-@dataclass(kw_only=True)
-class MelPart(BasePart):
-    mels: torch.Tensor
-
-
 @dataclass(kw_only=True)
 class EncodedMessage:
     tokens: torch.Tensor
     labels: torch.Tensor
+    vq_mask_tokens: torch.Tensor | None = None
+    vq_mask_labels: torch.Tensor | None = None
     vq_parts: list[torch.Tensor]
-    mel_parts: list[torch.Tensor]
     vq_require_losses: torch.Tensor | None = None


 @dataclass(kw_only=True)
 class Message:
     role: Literal["system", "user", "assistant"]
-    parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
+    parts: list[VQPart | TextPart] = field(default_factory=list)
     add_im_start: bool = True
     add_im_end: bool = True
     cal_loss: bool = False
+    modality: Literal["text", "voice", "interleave"] | None = None

     # By default, ignore the loss of the auto-generated im_start token
     ignore_im_start_loss: bool = True

     def encode(
         self: "Message",
-        tokenizer: AutoTokenizer,
+        tokenizer: FishTokenizer,
     ) -> EncodedMessage:
         all_tokens = []
         all_labels = []

         # Multi-modal tokens
         vq_parts = []
-
-
-        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
-            [SEMANTIC_TOKEN, MEL_TOKEN]
-        )
+        vq_masks = []

         parts = self.parts.copy()
         if self.add_im_start:
-
+            modality_token = MODALITY_TOKENS[self.modality] if self.modality else ""
+            parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n{modality_token}"))

         if self.add_im_end:
             parts.append(TextPart(text="<|im_end|>"))

         for part in parts:
             if isinstance(part, TextPart):
-                tokens = tokenizer.encode(
-                    part.text,
-
-
-                    return_tensors="pt",
-                ).int()[0]
+                tokens = torch.tensor(
+                    tokenizer.encode(part.text),
+                    dtype=torch.int,
+                )
             elif isinstance(part, VQPart):
-
-
-
-
-
-
-
-
-
-                tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
-                mel_parts.append(part.mels)
+                curr_codes = part.codes.clone()
+                tokens = torch.tensor(
+                    [
+                        tokenizer.semantic_id_to_token_id[i.item()]
+                        for i in curr_codes[0].int()
+                    ],
+                    dtype=torch.int,
+                )
+                vq_parts.append(curr_codes)
             else:
                 raise ValueError(f"Unsupported part type: {type(part)}")

             all_tokens.append(tokens)
+            if isinstance(part, VQPart):
+                vq_masks.append(torch.ones_like(tokens, dtype=torch.bool))
+            else:
+                vq_masks.append(torch.zeros_like(tokens, dtype=torch.bool))
+
             if self.cal_loss:
                 all_labels.append(tokens.clone())
             else:
@@ -132,7 +96,9 @@ class Message:

         tokens = torch.cat(all_tokens, dim=0)
         labels = torch.cat(all_labels, dim=0)
-
+        vq_masks = torch.cat(vq_masks, dim=0)
+
+        assert tokens.shape == labels.shape == vq_masks.shape

         if self.ignore_im_start_loss and self.add_im_start:
             labels[: len(all_tokens[0])] = -100
@@ -141,7 +107,8 @@ class Message:
             tokens=tokens,
             labels=labels,
             vq_parts=vq_parts,
-
+            vq_mask_tokens=vq_masks,
+            vq_mask_labels=vq_masks,
         )


@@ -149,17 +116,23 @@ class Message:
 class Conversation:
     messages: list[Message]

+    def __init__(self: "Conversation", messages: list[Message] | None = None):
+        self.messages = messages or []
+
     def encode(
         self: "Conversation",
-        tokenizer: AutoTokenizer,
+        tokenizer: FishTokenizer,
         add_shift: bool = True,
+        ignore_loss_tokens: list[str] = [],
     ) -> EncodedMessage:
         # Build the input_ids and labels
         tokens = []
         labels = []
         vq_parts = []
-
+        vq_mask_tokens = []
+        vq_mask_labels = []
         vq_require_losses = []
+        ignore_loss_token_ids = [tokenizer.get_token_id(i) for i in ignore_loss_tokens]

         for message in self.messages:
             encoded = message.encode(
@@ -168,16 +141,25 @@ class Conversation:
             tokens.append(encoded.tokens)
             labels.append(encoded.labels)
             vq_parts.extend(encoded.vq_parts)
-
+            vq_mask_tokens.append(encoded.vq_mask_tokens)
+            vq_mask_labels.append(encoded.vq_mask_labels)
             vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))

         tokens = torch.cat(tokens, dim=0)
         labels = torch.cat(labels, dim=0)
+        vq_mask_tokens = torch.cat(vq_mask_tokens, dim=0)
+        vq_mask_labels = torch.cat(vq_mask_labels, dim=0)
         vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)

         if add_shift:
             tokens = tokens[:-1]
             labels = labels[1:]
+            vq_mask_tokens = vq_mask_tokens[:-1]
+            vq_mask_labels = vq_mask_labels[1:]
+
+        for i in ignore_loss_token_ids:
+            assert i != -100 and i is not None
+            labels[labels == i] = -100

         assert tokens.dtype in [
             torch.int,
@@ -188,15 +170,18 @@ class Conversation:
             tokens=tokens,
             labels=labels,
             vq_parts=vq_parts,
-
+            vq_mask_tokens=vq_mask_tokens,
+            vq_mask_labels=vq_mask_labels,
             vq_require_losses=vq_require_losses,
         )

     def encode_for_inference(
         self: "Conversation",
-        tokenizer: AutoTokenizer,
+        tokenizer: FishTokenizer,
         num_codebooks: int,
     ) -> EncodedMessage:
+        # self.visualize(tokenizer)
+
         encoded = self.encode(tokenizer, add_shift=False)
         tokens = encoded.tokens
         values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
@@ -205,24 +190,47 @@ class Conversation:
         if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
             return values

-        semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
-            [SEMANTIC_TOKEN, MEL_TOKEN]
-        )
         vq_parts = encoded.vq_parts
+        vq_parts = [part.to(values.device) for part in vq_parts]
         vq_parts = torch.cat(vq_parts, dim=1)
-        values[
+        values[0, encoded.vq_mask_tokens] = vq_parts[0] + tokenizer.semantic_begin_id
+        values[1:, encoded.vq_mask_tokens] = vq_parts
+
         return values

-    def visualize(
-
+    def visualize(
+        self: "Conversation",
+        tokenizer: FishTokenizer,
+        ignore_loss_tokens: list[str] = [],
+    ):
+        encoded = self.encode(
+            tokenizer, add_shift=False, ignore_loss_tokens=ignore_loss_tokens
+        )

-
-
+        # Colors for alternating tokens
+        colors = {
+            "blue": "\033[94m",  # Light blue
+            "cyan": "\033[96m",  # Cyan
+            "green": "\033[92m",  # Light green
+            "dark_green": "\033[32m",  # Dark green
+        }
+        blue_idx = 0
+        green_idx = 0
+
+        def print_in_blue(x):
+            nonlocal blue_idx
+            color = colors["blue"] if blue_idx % 2 == 0 else colors["cyan"]
+            print(f"{color}{x}\033[0m", end="")
+            blue_idx += 1
+
+        def print_in_green(x):
+            nonlocal green_idx
+            color = colors["green"] if green_idx % 2 == 0 else colors["dark_green"]
+            print(f"{color}{x}\033[0m", end="")
+            green_idx += 1

         for tok, lab in zip(encoded.tokens, encoded.labels):
-            val = tokenizer.decode(tok)
-            if val == "\n":
-                val = "\\n\n"
+            val = tokenizer.decode([tok])

             if lab == -100:
                 print_in_green(val)
@@ -231,6 +239,9 @@ class Conversation:

         print()

+    def append(self: "Conversation", message: Message):
+        self.messages.append(message)
+

 if __name__ == "__main__":
     message0 = Message(
@@ -248,7 +259,7 @@ if __name__ == "__main__":
         cal_loss=True,
     )
     conversation = Conversation([message0, message1])
-    tokenizer =
+    tokenizer = FishTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
    conversation.visualize(tokenizer)

     encoded = conversation.encode(tokenizer)
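Read together, the conversation.py changes swap the HuggingFace tokenizer plumbing (and the removed MelPart path) for the new FishTokenizer, and thread per-position VQ masks through encoding. A minimal sketch of the reworked API, inferred only from the diff above; the checkpoint path, the VQPart constructor, and the code shapes are illustrative assumptions:

    import torch

    from fish_speech.conversation import Conversation, Message, TextPart, VQPart
    from fish_speech.tokenizer import FishTokenizer

    # Hypothetical checkpoint directory containing a tokenizer.tiktoken file.
    tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")

    conversation = Conversation()  # messages now defaults to an empty list
    conversation.append(
        Message(
            role="user",
            parts=[TextPart(text="Hello")],
            modality="voice",  # prepends <|voice|> right after <|im_start|>user
        )
    )
    conversation.append(
        Message(
            role="assistant",
            # Assumed field name `codes`, per `part.codes` in the diff; values
            # must be valid semantic ids (0..1023), shape (num_codebooks, T).
            parts=[VQPart(codes=torch.zeros(8, 4, dtype=torch.int))],
            cal_loss=True,
        )
    )

    encoded = conversation.encode(tokenizer)
    # vq_mask_tokens / vq_mask_labels flag the positions holding semantic tokens.
    print(encoded.tokens.shape, int(encoded.vq_mask_tokens.sum()))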
xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py
@@ -16,7 +16,7 @@ from torch.nn.attention import SDPBackend, sdpa_kernel
 from torch.utils.checkpoint import checkpoint
 from transformers import AutoTokenizer

-from fish_speech.conversation import SEMANTIC_TOKEN
+from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
 from fish_speech.utils import RankedLogger

 from .lora import LoraConfig, setup_lora
@@ -61,6 +61,7 @@ class BaseModelArgs:
     # Dummy vars
     is_reward_model: bool = False
     share_codebook_embeddings: bool = True
+    scale_codebook_embeddings: bool = False

     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -164,13 +165,17 @@ class BaseTransformerForwardResult:

 class BaseTransformer(nn.Module):
     def __init__(
-        self,
+        self,
+        config: BaseModelArgs,
+        tokenizer: FishTokenizer | AutoTokenizer,
+        init_weights: bool = True,
     ) -> None:
         super().__init__()
         self.config = config
         self.tokenizer = tokenizer
-
-
+        self.semantic_token_ids = [
+            tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
+        ]

         # Slow transformer
         self.embeddings = nn.Embedding(
@@ -245,8 +250,10 @@
         vocab_embeds = [self.embeddings(x[:, 0])]
         for i in range(self.config.num_codebooks):
             emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
-
-
+            semantic_token_ids_tensor = torch.tensor(
+                self.semantic_token_ids, device=x.device
+            )
+            emb[~torch.isin(x[:, 0], semantic_token_ids_tensor)] = 0

         x = torch.stack(vocab_embeds, dim=3)
         x = x.sum(dim=3)
@@ -294,20 +301,45 @@

     def forward_generate(
         self,
-        x: Tensor,
+        inp: Tensor,
         input_pos: Optional[Tensor] = None,
+        vq_masks: Optional[Tensor] = None,  # this is not used in fact
         return_all: bool = False,
     ) -> BaseTransformerForwardResult:
         # This is used for generation, optimized for torch compile
-        assert (
-            self.max_seq_len != -1 and self.max_batch_size != -1
-        ), "Please call setup_caches before forward_generate"
+        # assert (
+        #     self.max_seq_len != -1 and self.max_batch_size != -1
+        # ), "Please call setup_caches before forward_generate"

-
+        embeds = []
+        for i in range(self.config.num_codebooks):
+            if self.config.share_codebook_embeddings:
+                _tokens = inp[:, i + 1] + i * self.config.codebook_size
+            else:
+                _tokens = inp[:, i + 1]

-
-
-
+            emb = self.codebook_embeddings(_tokens)
+            embeds.append(emb)
+
+        vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
+        # if self.config.use_codebook_mlp:
+        #     vq_embeds_sum = vq_embeds_sum / self.config.num_codebooks
+        #     vq_embeds_sum = self.codebook_mlp(vq_embeds_sum)
+
+        vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
+            inp[:, 0] <= self.tokenizer.semantic_end_id
+        )
+
+        vq_embeds_sum[~vq_masks] = 0
+        x = self.embeddings(inp[:, 0]) + vq_embeds_sum
+
+        if input_pos is None:
+            input_pos = torch.arange(inp.shape[-1], device=x.device)
+            max_seq_len = inp.shape[-1]
+        else:
+            max_seq_len = self.max_seq_len
+
+        mask = self.causal_mask[None, None, input_pos, :max_seq_len]  # (B, N, Q, K)
         freqs_cis = self.freqs_cis[input_pos]

         for layer in self.layers:
@@ -320,7 +352,9 @@
         # We got slow_out here
         slow_out = self.norm(x)

-        if self.config.tie_word_embeddings:
+        if self.config.is_reward_model:
+            token_logits = self.score_output(slow_out)
+        elif self.config.tie_word_embeddings:
             token_logits = F.linear(slow_out, self.embeddings.weight)
         else:
             token_logits = self.output(slow_out)
@@ -348,6 +382,7 @@
         max_length: int | None = None,
         lora_config: LoraConfig | None = None,
         rope_base: int | None = None,
+        is_agent: bool = False,
     ) -> "BaseTransformer":
         config = BaseModelArgs.from_pretrained(str(path))
         if max_length is not None:
@@ -366,7 +401,12 @@
             case _:
                 raise ValueError(f"Unknown model type: {config.model_type}")

-        tokenizer = AutoTokenizer.from_pretrained(str(path))
+        if is_agent:
+            tokenizer = AutoTokenizer.from_pretrained(str(path))
+        else:
+            tokenizer_path = str(path) + "/tokenizer.tiktoken"
+            tokenizer = FishTokenizer(tokenizer_path)
+
         log.info(f"Loading model from {path}, config: {config}")
         model = model_cls(config, tokenizer=tokenizer)

@@ -452,7 +492,7 @@


 class NaiveTransformer(BaseTransformer):
-    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
         super().__init__(config, init_weights=False, tokenizer=tokenizer)

         self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
@@ -498,7 +538,7 @@ class NaiveTransformer(BaseTransformer):


 class DualARTransformer(BaseTransformer):
-    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
+    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
         super().__init__(config, init_weights=False, tokenizer=tokenizer)

         # Project to fast dim if needed
@@ -654,9 +694,12 @@ class DualARTransformer(BaseTransformer):
         return codebook_logits

     def forward_generate(
-        self,
+        self,
+        x: Tensor,
+        input_pos: Optional[Tensor] = None,
+        vq_masks: Optional[Tensor] = None,
     ) -> TransformerForwardResult:
-        x = super().forward_generate(x, input_pos)
+        x = super().forward_generate(x, input_pos, vq_masks)
         x.hidden_states = self.fast_project_in(x.hidden_states)
         return x

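The llama.py rewrite of forward_generate stops depending on a single <|semantic|> marker token: the summed codebook embeddings are zeroed wherever the first-row token id falls outside the contiguous range of <|semantic:i|> ids. A toy sketch of that gating with made-up ids (real values depend on the BPE vocabulary size, since special tokens are appended after it):

    import torch

    # Stand-ins for tokenizer.semantic_begin_id / semantic_end_id (1024 contiguous ids).
    semantic_begin_id, semantic_end_id = 32000, 33023

    inp0 = torch.tensor([15, 32000, 32500, 33023, 40000])  # row 0 of inp: token ids
    vq_masks = (inp0 >= semantic_begin_id) & (inp0 <= semantic_end_id)
    # -> tensor([False,  True,  True,  True, False])

    vq_embeds_sum = torch.randn(5, 16)  # summed codebook embeddings per position
    vq_embeds_sum[~vq_masks] = 0        # text positions contribute no VQ embedding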
xinference/thirdparty/fish_speech/fish_speech/text/clean.py
@@ -1,33 +1,8 @@
 import re

 SYMBOLS_MAPPING = {
-    "\n": "",
-    "…": ".",
-    "“": "'",
-    "”": "'",
     "‘": "'",
     "’": "'",
-    "【": "",
-    "】": "",
-    "[": "",
-    "]": "",
-    "(": "",
-    ")": "",
-    "(": "",
-    ")": "",
-    "・": "",
-    "·": "",
-    "「": "'",
-    "」": "'",
-    "《": "'",
-    "》": "'",
-    "—": "",
-    "~": "",
-    "~": "",
-    ":": ",",
-    ";": ",",
-    ";": ",",
-    ":": ",",
 }

 REPLACE_SYMBOL_REGEX = re.compile(
@@ -57,6 +32,6 @@ def clean_text(text):
     text = EMOJI_REGEX.sub(r"", text)

     # Remove continuous periods (...) and commas (,,,)
-    text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text)
+    text = re.sub(r"[,]{2,}", lambda m: m.group()[0], text)

     return text
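The clean.py change shrinks the symbol-mapping table to apostrophe normalization only and narrows the duplicate-punctuation rule so that just runs of commas are collapsed. A quick check of the new rule:

    import re

    def collapse_commas(text: str) -> str:
        # Runs of two or more commas collapse to a single comma.
        return re.sub(r"[,]{2,}", lambda m: m.group()[0], text)

    print(collapse_commas("wait,,, what,,"))  # -> "wait, what,"
    print(collapse_commas("well..."))         # periods are no longer collapsed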
xinference/thirdparty/fish_speech/fish_speech/tokenizer.py
@@ -0,0 +1,152 @@
+import base64
+import json
+import logging
+from pathlib import Path
+
+import tiktoken
+
+logger = logging.getLogger(__name__)
+
+# This is a modified version of the default pattern from GPT-4o, that better handles punctuations.
+FISH_TIKTOKEN_PATTERN = "|".join(
+    [
+        r"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
+        r"\p{P}",
+        r"[^\r\n\p{L}\p{N}]?\p{L}+",
+        r"\p{N}",
+        r" ?[^\s\p{L}\p{N}]+[\r\n]*",
+        r"\s*[\r\n]+",
+        r"\s+(?!\S)",
+        r"\s+",
+    ]
+)
+TIKTOKEN_MAX_ENCODE_CHARS = 400_000
+
+BOS_TOKEN = "<|begin_of_text|>"
+EOS_TOKEN = "<|end_of_text|>"
+PAD_TOKEN = "<|pad|>"
+IM_START_TOKEN = "<|im_start|>"
+IM_END_TOKEN = "<|im_end|>"
+
+MODALITY_TEXT_TOKEN = "<|text|>"
+MODALITY_VOICE_TOKEN = "<|voice|>"
+MODALITY_INTERLEAVE_TOKEN = "<|interleave|>"
+MODALITY_TOKENS = {
+    "text": MODALITY_TEXT_TOKEN,
+    "voice": MODALITY_VOICE_TOKEN,
+    "interleave": MODALITY_INTERLEAVE_TOKEN,
+}
+
+PLACEHOLDER_TOKEN = [""] * 4
+for i in range(4):
+    PLACEHOLDER_TOKEN[i] = f"<|placeholder:{i}|>"
+
+SEMANTIC_TOKEN_TEMPLATE = "<|semantic:{i}|>"
+SEMANTIC_TOKENS = [SEMANTIC_TOKEN_TEMPLATE.format(i=i) for i in range(1024)]
+
+# Warning: when you add a new special token, you should only add it to the end of the list.
+ALL_SPECIAL_TOKENS = [
+    BOS_TOKEN,
+    EOS_TOKEN,
+    PAD_TOKEN,
+    IM_START_TOKEN,
+    IM_END_TOKEN,
+    PLACEHOLDER_TOKEN[0],
+    PLACEHOLDER_TOKEN[1],
+    PLACEHOLDER_TOKEN[2],
+    PLACEHOLDER_TOKEN[3],
+    MODALITY_TEXT_TOKEN,
+    MODALITY_VOICE_TOKEN,
+    MODALITY_INTERLEAVE_TOKEN,
+    *SEMANTIC_TOKENS,
+]
+
+
+class FishTokenizer:
+    def __init__(self, model_path: str) -> None:
+        mergeable_ranks = self.load_tiktoken_bpe(model_path)
+        special_token_begin = len(mergeable_ranks)
+        self.all_special_tokens_with_ids = {
+            token: special_token_begin + i for i, token in enumerate(ALL_SPECIAL_TOKENS)
+        }
+        self.semantic_id_to_token_id = {
+            i: self.all_special_tokens_with_ids[token]
+            for i, token in enumerate(SEMANTIC_TOKENS)
+        }
+        self.semantic_begin_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[0]]
+        self.semantic_end_id = self.all_special_tokens_with_ids[SEMANTIC_TOKENS[-1]]
+
+        self.tkt_model = tiktoken.core.Encoding(
+            name=Path(model_path).stem,
+            pat_str=FISH_TIKTOKEN_PATTERN,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=self.all_special_tokens_with_ids,
+        )
+
+    @staticmethod
+    def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
+        data = {}
+        for line in open(tiktoken_bpe_file).read().splitlines():
+            if not line:
+                continue
+            token, rank = line.split()
+            data[base64.b64decode(token)] = int(rank)
+        return data
+
+    def get_token_id(self, token: str) -> int:
+        return self.all_special_tokens_with_ids[token]
+
+    def encode(self, s: str, allowed_special: bool | set[str] = True) -> list[int]:
+        assert isinstance(s, str)
+
+        subs = []
+        for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
+            subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
+
+        if allowed_special is True:
+            allowed_special = self.tkt_model.special_tokens_set
+        elif allowed_special is False:
+            allowed_special = set()
+
+        return sum(
+            self.tkt_model.encode_batch(
+                subs, allowed_special=allowed_special, disallowed_special=set()
+            ),
+            start=[],
+        )
+
+    def decode(self, tokens: list[int]) -> str:
+        return self.tkt_model.decode(tokens)
+
+    def save_pretrained(self, path: str):
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        with open(path / "tokenizer.tiktoken", "w") as f:
+            for token, rank in self.tkt_model._mergeable_ranks.items():
+                f.write(f"{base64.b64encode(token).decode()} {rank}\n")
+
+        with open(path / "special_tokens.json", "w") as f:
+            json.dump(
+                self.all_special_tokens_with_ids,
+                f,
+                indent=2,
+                ensure_ascii=False,
+            )
+
+    @staticmethod
+    def from_pretrained(path: str):
+        return FishTokenizer(Path(path) / "tokenizer.tiktoken")
+
+
+if __name__ == "__main__":
+    tokenizer = FishTokenizer("data/mpacks/v1.4-pretrain/tokenizer.all.tiktoken")
+    tokenizer.save_pretrained("checkpoints/fish-speech-0.5B")
+    tokenizer = FishTokenizer.from_pretrained("checkpoints/fish-speech-0.5B")
+
+    print(
+        [
+            tokenizer.decode([i])
+            for i in tokenizer.encode(f"{BOS_TOKEN}你好,世界!{EOS_TOKEN}")
+        ]
+    )