xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as possibly problematic by the registry scanner.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +60 -44
- xinference/model/audio/chattts.py +25 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +25 -1
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +66 -3
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +7 -6
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +3 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +13 -1
- xinference/model/llm/vllm/core.py +1 -34
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/tools/llama/generate.py
@@ -0,0 +1,698 @@
+import os
+import queue
+import threading
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Literal, Optional, Tuple, Union
+
+import click
+import hydra
+import numpy as np
+import torch
+import torch._dynamo.config
+import torch._inductor.config
+from loguru import logger
+from tqdm import tqdm
+
+from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.text import clean_text, split_text
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+    # Experimental feature to reduce compilation times, will be on by default in future
+    torch._inductor.config.fx_graph_cache = True
+
+
+from fish_speech.models.text2semantic.llama import (
+    BaseTransformer,
+    DualARTransformer,
+    NaiveTransformer,
+)
+
+
+def multinomial_sample_one_no_sync(
+    probs_sort,
+):  # Does multinomial sampling without a cuda synchronization
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    temperature: torch.Tensor = 1.0,
+    top_p: torch.Tensor = 1.0,
+    repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+    # Apply repetition penalty
+    if previous_tokens is not None:
+        previous_tokens = previous_tokens.long()
+        score = torch.gather(logits, dim=0, index=previous_tokens)
+        score = torch.where(
+            score < 0, score * repetition_penalty, score / repetition_penalty
+        )
+        logits.scatter_(dim=0, index=previous_tokens, src=score)
+
+    # Apply top-p sampling
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=0, index=sorted_indices, src=sorted_indices_to_remove
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+    logits = logits / max(temperature, 1e-5)
+
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return probs
+
+
+def sample(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    probs = logits_to_probs(
+        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
+    )
+    idx_next = multinomial_sample_one_no_sync(probs)
+    return idx_next, probs
+
+
+def decode_one_token_ar(
+    model: DualARTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    previous_tokens: torch.Tensor = None,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+    codebooks = [
+        sample(
+            x.logits,
+            previous_tokens=(
+                previous_tokens[0] if previous_tokens is not None else None
+            ),  # Disable repetition penalty for the token codebook
+            **sampling_kwargs,
+        )[0]
+    ]
+    x = x.hidden_states
+
+    # Cleanup the cache
+    for layer in model.fast_layers:
+        layer.attention.kv_cache.k_cache.fill_(0)
+        layer.attention.kv_cache.v_cache.fill_(0)
+
+    for codebook_idx in range(model.config.num_codebooks):
+        input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
+        logits = model.forward_generate_fast(x, input_pos)
+        a = sample(
+            logits,
+            previous_tokens=(
+                previous_tokens[codebook_idx + 1]
+                if previous_tokens is not None
+                else None
+            ),
+            **sampling_kwargs,
+        )[0]
+        x = model.fast_embeddings(a)
+        codebooks.append(a)
+
+    return torch.stack(codebooks, dim=0)
+
+
+def decode_one_token_naive(
+    model: NaiveTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    previous_tokens: torch.Tensor = None,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+
+    codebooks = [
+        sample(
+            x.token_logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs,
+        )[0]
+    ]
+
+    for i in range(model.config.num_codebooks):
+        codebooks.append(
+            sample(
+                x.codebook_logits[:, :, i],
+                previous_tokens=(
+                    previous_tokens[i + 1] if previous_tokens is not None else None
+                ),
+                **sampling_kwargs,
+            )[0]
+        )
+
+    return torch.stack(codebooks, dim=0)
+
+
+def decode_n_tokens(
+    model: NaiveTransformer,
+    cur_token: torch.Tensor,
+    input_pos: torch.Tensor,
+    num_new_tokens: int,
+    im_end_id: int = 4,
+    decode_one_token=decode_one_token_naive,
+    **sampling_kwargs,
+):
+    previous_tokens = torch.zeros(
+        (model.config.num_codebooks + 1, model.config.max_seq_len),
+        dtype=torch.int,
+        device=cur_token.device,
+    )
+
+    for i in tqdm(range(num_new_tokens)):
+        # We need to get windowed repeat penalty
+        win_size = 16
+        if i < win_size:
+            window = previous_tokens[:, :win_size]
+        else:
+            window = previous_tokens[:, i - win_size : i]
+
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=False, enable_mem_efficient=False, enable_math=True
+        ):  # Actually better for Inductor to codegen attention here
+            next_token = decode_one_token(
+                model=model,
+                x=cur_token,
+                input_pos=input_pos,
+                previous_tokens=window,
+                **sampling_kwargs,
+            )
+
+        input_pos += 1
+        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
+        previous_tokens[:, i : i + 1] = next_token.view(
+            model.config.num_codebooks + 1, -1
+        )
+
+        if cur_token[0, 0, -1] == im_end_id:
+            break
+
+    return previous_tokens[:, : i + 1]
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def generate(
+    *,
+    model: NaiveTransformer,
+    prompt: torch.Tensor,
+    max_new_tokens: int,
+    im_end_id: int = 4,
+    decode_one_token=decode_one_token_naive,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    """
+    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+    """
+
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    T = prompt.size(1)
+
+    if max_new_tokens:
+        if T + max_new_tokens > model.config.max_seq_len:
+            max_new_tokens = model.config.max_seq_len - T
+            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+        T_new = T + max_new_tokens
+    else:
+        T_new = model.config.max_seq_len
+        max_new_tokens = T_new - T
+
+    device, dtype = prompt.device, prompt.dtype
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
+        )
+
+    codebook_dim = 1 + model.config.num_codebooks
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
+    empty[:, :T] = prompt
+    seq = empty
+    input_pos = torch.arange(0, T, device=device)
+
+    # Use non-accelerated version for now, to avoid compilation overhead
+    prefill_decode = (
+        decode_one_token_naive
+        if isinstance(model, NaiveTransformer)
+        else decode_one_token_ar
+    )
+
+    next_token = prefill_decode(
+        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
+    )
+    seq[:, T : T + 1] = next_token
+
+    input_pos = torch.tensor([T], device=device, dtype=torch.int)
+    x = decode_n_tokens(
+        model,
+        next_token.view(1, codebook_dim, -1),
+        input_pos,
+        max_new_tokens - 1,
+        im_end_id=im_end_id,
+        decode_one_token=decode_one_token,
+        **sampling_kwargs,
+    )
+    # x = torch.cat(generated_tokens, dim=1)
+    seq = seq[:, : T + 1 + x.size(1)]
+    seq[:, T + 1 :] = x
+
+    return seq
+
+
+def encode_tokens(
+    tokenizer,
+    string,
+    device="cuda",
+    prompt_tokens=None,
+    num_codebooks=4,
+):
+    string = clean_text(string)
+    string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
+
+    new_tokens = tokenizer.encode(
+        string,
+        add_special_tokens=False,
+        max_length=10**6,
+        truncation=False,
+    )
+    tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
+
+    # Codebooks
+    zeros = (
+        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
+        * CODEBOOK_PAD_TOKEN_ID
+    )
+    prompt = torch.cat((tokens, zeros), dim=0)
+
+    if prompt_tokens is None:
+        return prompt
+
+    # Get prompt tokens
+    if prompt_tokens.ndim == 3:
+        assert (
+            prompt_tokens.shape[0] == 1
+        ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
+        prompt_tokens = prompt_tokens[0]
+
+    assert prompt_tokens.ndim == 2
+    data = prompt_tokens + 1
+
+    if prompt_tokens.shape[0] > num_codebooks:
+        logger.warning(
+            f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
+        )
+        data = data[:num_codebooks]
+
+    # Add pad token for each codebook
+    data = torch.cat(
+        (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
+        dim=1,
+    )
+
+    # Since 1.0, we use <|semantic|>
+    s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
+    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+    main_token_ids = (
+        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
+    )
+    main_token_ids[0, -1] = end_token_id
+
+    data = torch.cat((main_token_ids, data), dim=0)
+    prompt = torch.cat((prompt, data), dim=1)
+
+    return prompt
+
+
+def load_model(checkpoint_path, device, precision, compile=False):
+    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
+        checkpoint_path, load_weights=True
+    )
+
+    model = model.to(device=device, dtype=precision)
+    logger.info(f"Restored model from checkpoint")
+
+    if isinstance(model, DualARTransformer):
+        decode_one_token = decode_one_token_ar
+        logger.info("Using DualARTransformer")
+    else:
+        decode_one_token = decode_one_token_naive
+        logger.info("Using NaiveTransformer")
+
+    if compile:
+        logger.info("Compiling function...")
+        decode_one_token = torch.compile(
+            decode_one_token,
+            fullgraph=True,
+            backend="inductor" if torch.cuda.is_available() else "aot_eager",
+            mode="reduce-overhead" if torch.cuda.is_available() else None,
+        )
+
+    return model.eval(), decode_one_token
+
+
+@dataclass
+class GenerateResponse:
+    action: Literal["sample", "next"]
+    codes: Optional[torch.Tensor] = None
+    text: Optional[str] = None
+
+
+def generate_long(
+    *,
+    model,
+    device: str | torch.device,
+    decode_one_token: callable,
+    text: str,
+    num_samples: int = 1,
+    max_new_tokens: int = 0,
+    top_p: int = 0.7,
+    repetition_penalty: float = 1.5,
+    temperature: float = 0.7,
+    compile: bool = False,
+    iterative_prompt: bool = True,
+    max_length: int = 2048,
+    chunk_length: int = 150,
+    prompt_text: Optional[str | list[str]] = None,
+    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
+):
+    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
+    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
+    assert 0 < temperature < 2, "temperature must be in (0, 2)"
+
+    use_prompt = prompt_text is not None and prompt_tokens is not None
+    if use_prompt and isinstance(prompt_text, str):
+        prompt_text = [prompt_text]
+        prompt_tokens = [prompt_tokens]
+
+    assert use_prompt is False or len(prompt_text) == len(
+        prompt_tokens
+    ), "Prompt text and tokens must have the same length"
+
+    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    tokenizer = model.tokenizer
+    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+
+    encoded = []
+    texts = split_text(text, chunk_length) if iterative_prompt else [text]
+    encoded_prompts = []
+
+    if use_prompt:
+        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
+            encoded_prompts.append(
+                encode_tokens(
+                    tokenizer,
+                    string=t,
+                    device=device,
+                    prompt_tokens=c,
+                    num_codebooks=model.config.num_codebooks,
+                )
+            )
+
+    for idx, text in enumerate(texts):
+        encoded.append(
+            encode_tokens(
+                tokenizer,
+                string=text,
+                device=device,
+                num_codebooks=model.config.num_codebooks,
+            )
+        )
+        logger.info(f"Encoded text: {text}")
+
+    # Move temperature, top_p, repetition_penalty to device
+    # This is important so that changing params doesn't trigger recompile
+    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
+    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
+    repetition_penalty = torch.tensor(
+        repetition_penalty, device=device, dtype=torch.float
+    )
+
+    for sample_idx in range(num_samples):
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+
+        global_encoded = []
+        seg_idx = 0
+
+        while seg_idx < len(encoded):
+            logger.info(
+                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
+            )
+
+            seg = encoded[seg_idx]
+            global_encoded.append(seg)
+
+            lengths = reversed([seg.size(1) for seg in global_encoded])
+
+            # Pick last 2000 tokens
+            count = 0
+            for i, length in enumerate(lengths):
+                count += length
+                if count + length > max_length - 1024 - sum(
+                    t.shape[1] for t in encoded_prompts
+                ):
+                    break
+
+            if i != 0 and i % 2 == 0:
+                i -= 1
+
+            # Rotate the list, always make sure first segment is included to avoid drift
+            if i < len(global_encoded) - 2:
+                partial_encoded = global_encoded[:2] + global_encoded[-i:]
+            else:
+                partial_encoded = global_encoded
+
+            if use_prompt:
+                partial_encoded = encoded_prompts + partial_encoded
+
+            cat_encoded = torch.cat(partial_encoded, dim=1)
+            prompt_length = cat_encoded.size(1)
+
+            t0 = time.perf_counter()
+            y = generate(
+                model=model,
+                prompt=cat_encoded,
+                max_new_tokens=max_new_tokens,
+                im_end_id=im_end_id,
+                decode_one_token=decode_one_token,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+            )
+
+            if sample_idx == 0 and seg_idx == 0 and compile:
+                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+
+            t = time.perf_counter() - t0
+
+            tokens_generated = y.size(1) - prompt_length
+            tokens_sec = tokens_generated / t
+            logger.info(
+                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
+            )
+            logger.info(
+                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
+            )
+
+            if torch.cuda.is_available():
+                logger.info(
+                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
+                )
+
+            # Put the generated tokens
+            # since there is <im_end> and <eos> tokens, we remove last 2 tokens
+            codes = y[1:, prompt_length:-1].clone()
+            codes = codes - 1
+            assert (codes >= 0).all(), f"Negative code found"
+
+            decoded = y[:, prompt_length:-1].clone()
+            # But for global encoding, we should keep the <im_end> token
+
+            global_encoded.append(decoded)
+            assert (codes >= 0).all(), f"Negative code found: {codes}"
+            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
+            seg_idx += 1
+
+        # This indicates the end of the current sample
+        yield GenerateResponse(action="next")
+
+
+@dataclass
+class WrappedGenerateResponse:
+    status: Literal["success", "error"]
+    response: Optional[GenerateResponse | Exception] = None
+
+
+@dataclass
+class GenerateRequest:
+    request: dict
+    response_queue: queue.Queue
+
+
+def launch_thread_safe_queue(
+    checkpoint_path,
+    device,
+    precision,
+    compile: bool = False,
+):
+    input_queue = queue.Queue()
+    init_event = threading.Event()
+
+    def worker():
+        model, decode_one_token = load_model(
+            checkpoint_path, device, precision, compile=compile
+        )
+        init_event.set()
+
+        while True:
+            item: GenerateRequest | None = input_queue.get()
+            if item is None:
+                break
+
+            kwargs = item.request
+            response_queue = item.response_queue
+
+            try:
+                for chunk in generate_long(
+                    model=model, decode_one_token=decode_one_token, **kwargs
+                ):
+                    response_queue.put(
+                        WrappedGenerateResponse(status="success", response=chunk)
+                    )
+            except Exception as e:
+                response_queue.put(WrappedGenerateResponse(status="error", response=e))
+
+    threading.Thread(target=worker, daemon=True).start()
+    init_event.wait()
+
+    return input_queue
+
+
+@click.command()
+@click.option(
+    "--text",
+    type=str,
+    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+)
+@click.option("--prompt-text", type=str, default=None, multiple=True)
+@click.option(
+    "--prompt-tokens",
+    type=click.Path(path_type=Path, exists=True),
+    default=None,
+    multiple=True,
+)
+@click.option("--num-samples", type=int, default=1)
+@click.option("--max-new-tokens", type=int, default=0)
+@click.option("--top-p", type=float, default=0.7)
+@click.option("--repetition-penalty", type=float, default=1.2)
+@click.option("--temperature", type=float, default=0.7)
+@click.option(
+    "--checkpoint-path",
+    type=click.Path(path_type=Path, exists=True),
+    default="checkpoints/fish-speech-1.2-sft",
+)
+@click.option("--device", type=str, default="cuda")
+@click.option("--compile/--no-compile", default=False)
+@click.option("--seed", type=int, default=42)
+@click.option("--half/--no-half", default=False)
+@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
+@click.option("--chunk-length", type=int, default=100)
+def main(
+    text: str,
+    prompt_text: Optional[list[str]],
+    prompt_tokens: Optional[list[Path]],
+    num_samples: int,
+    max_new_tokens: int,
+    top_p: int,
+    repetition_penalty: float,
+    temperature: float,
+    checkpoint_path: Path,
+    device: str,
+    compile: bool,
+    seed: int,
+    half: bool,
+    iterative_prompt: bool,
+    chunk_length: int,
+) -> None:
+
+    precision = torch.half if half else torch.bfloat16
+
+    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
+        raise ValueError(
+            f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
+        )
+
+    logger.info("Loading model ...")
+    t0 = time.time()
+    model, decode_one_token = load_model(
+        checkpoint_path, device, precision, compile=compile
+    )
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
+    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+    if prompt_tokens is not None:
+        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
+
+    torch.manual_seed(seed)
+
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+
+    generator = generate_long(
+        model=model,
+        device=device,
+        decode_one_token=decode_one_token,
+        text=text,
+        num_samples=num_samples,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        temperature=temperature,
+        compile=compile,
+        iterative_prompt=iterative_prompt,
+        chunk_length=chunk_length,
+        prompt_text=prompt_text,
+        prompt_tokens=prompt_tokens,
+    )
+
+    idx = 0
+    codes = []
+
+    for response in generator:
+        if response.action == "sample":
+            codes.append(response.codes)
+            logger.info(f"Sampled text: {response.text}")
+        elif response.action == "next":
+            if codes:
+                np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
+                logger.info(f"Saved codes to codes_{idx}.npy")
+            logger.info(f"Next sample")
+            codes = []
+            idx += 1
+        else:
+            logger.error(f"Error: {response}")
+
+
+if __name__ == "__main__":
+    main()