xinference-1.4.1-py3-none-any.whl → xinference-1.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference has been flagged as potentially problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +50 -1
- xinference/client/restful/restful_client.py +82 -2
- xinference/constants.py +3 -0
- xinference/core/chat_interface.py +297 -83
- xinference/core/model.py +1 -0
- xinference/core/progress_tracker.py +16 -8
- xinference/core/supervisor.py +45 -1
- xinference/core/worker.py +262 -37
- xinference/deploy/cmdline.py +33 -1
- xinference/model/audio/core.py +11 -1
- xinference/model/audio/megatts.py +105 -0
- xinference/model/audio/model_spec.json +24 -1
- xinference/model/audio/model_spec_modelscope.json +26 -1
- xinference/model/core.py +14 -0
- xinference/model/embedding/core.py +6 -1
- xinference/model/flexible/core.py +6 -1
- xinference/model/image/core.py +6 -1
- xinference/model/image/model_spec.json +17 -1
- xinference/model/image/model_spec_modelscope.json +17 -1
- xinference/model/llm/__init__.py +0 -4
- xinference/model/llm/core.py +4 -0
- xinference/model/llm/llama_cpp/core.py +40 -16
- xinference/model/llm/llm_family.json +413 -84
- xinference/model/llm/llm_family.py +24 -1
- xinference/model/llm/llm_family_modelscope.json +447 -0
- xinference/model/llm/mlx/core.py +16 -2
- xinference/model/llm/transformers/__init__.py +14 -0
- xinference/model/llm/transformers/core.py +30 -6
- xinference/model/llm/transformers/gemma3.py +17 -2
- xinference/model/llm/transformers/intern_vl.py +28 -18
- xinference/model/llm/transformers/minicpmv26.py +21 -2
- xinference/model/llm/transformers/qwen-omni.py +308 -0
- xinference/model/llm/transformers/qwen2_audio.py +1 -1
- xinference/model/llm/transformers/qwen2_vl.py +20 -4
- xinference/model/llm/utils.py +11 -1
- xinference/model/llm/vllm/core.py +35 -0
- xinference/model/llm/vllm/distributed_executor.py +8 -2
- xinference/model/rerank/core.py +6 -1
- xinference/model/utils.py +118 -1
- xinference/model/video/core.py +6 -1
- xinference/thirdparty/megatts3/__init__.py +0 -0
- xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
- xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
- xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
- xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
- xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
- xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
- xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
- xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
- xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
- xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
- xinference/types.py +10 -0
- xinference/utils.py +54 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
- xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
- xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
- xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
- xinference/web/ui/src/locales/en.json +2 -1
- xinference/web/ui/src/locales/zh.json +2 -1
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/METADATA +127 -114
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/RECORD +96 -60
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
- xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
- xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
- xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
- /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
- {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
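Functionally, the headline additions in this release are MegaTTS 3 text-to-speech support (xinference/model/audio/megatts.py plus the vendored xinference/thirdparty/megatts3 package and new audio model specs) and a Qwen-Omni transformers model. Below is a minimal sketch of driving the new audio model through the existing Python client; the model name "MegaTTS3" and its exact options are assumptions here and should be checked against the updated model_spec.json.

from xinference.client import Client

client = Client("http://127.0.0.1:9997")

# Launch the newly added TTS model; the model name is an assumption,
# check the registered name in xinference/model/audio/model_spec.json.
model_uid = client.launch_model(model_name="MegaTTS3", model_type="audio")
model = client.get_model(model_uid)

# The audio handle's speech() call returns the synthesized audio bytes.
audio = model.speech("Hello from Xinference 1.5.0.")
with open("output.mp3", "wb") as f:
    f.write(audio)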
xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py (new file)
@@ -0,0 +1,342 @@
+# Copyright 2025 ByteDance and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+import torch
+import torch.nn.functional as F
+
+
+def make_positions(tensor, padding_idx):
+    """Replace non-padding symbols with their position numbers.
+
+    Position numbers begin at padding_idx+1. Padding symbols are ignored.
+    """
+    # The series of casts and type-conversions here are carefully
+    # balanced to both work with ONNX export and XLA. In particular XLA
+    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+    # how to handle the dtype kwarg in cumsum.
+    mask = tensor.ne(padding_idx).int()
+    return (
+        torch.cumsum(mask, dim=1).type_as(mask) * mask
+    ).long() + padding_idx
+
+
+def softmax(x, dim):
+    return F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
+    if maxlen is None:
+        maxlen = lengths.max()
+    mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
+    mask.type(dtype)
+    return mask
+
+
+def weights_nonzero_speech(target):
+    # target : B x T x mel
+    # Assign weight 1.0 to all labels except for padding (id=0).
+    dim = target.size(-1)
+    return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
+
+
+INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
+
+
+def _get_full_incremental_state_key(module_instance, key):
+    module_name = module_instance.__class__.__name__
+
+    # assign a unique ID to each module instance, so that incremental state is
+    # not shared across module instances
+    if not hasattr(module_instance, '_instance_id'):
+        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
+        module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
+
+    return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)
+
+
+def get_incremental_state(module, incremental_state, key):
+    """Helper for getting incremental state for an nn.Module."""
+    full_key = _get_full_incremental_state_key(module, key)
+    if incremental_state is None or full_key not in incremental_state:
+        return None
+    return incremental_state[full_key]
+
+
+def set_incremental_state(module, incremental_state, key, value):
+    """Helper for setting incremental state for an nn.Module."""
+    if incremental_state is not None:
+        full_key = _get_full_incremental_state_key(module, key)
+        incremental_state[full_key] = value
+
+
+def fill_with_neg_inf(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(float('-inf')).type_as(t)
+
+
+def fill_with_neg_inf2(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(-1e8).type_as(t)
+
+
+def select_attn(attn_logits, type='best'):
+    """
+
+    :param attn_logits: [n_layers, B, n_head, T_sp, T_txt]
+    :return:
+    """
+    encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
+    # [n_layers * n_head, B, T_sp, T_txt]
+    encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
+    if type == 'best':
+        indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
+        encdec_attn = encdec_attn.gather(
+            0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
+        return encdec_attn
+    elif type == 'mean':
+        return encdec_attn.mean(0)
+
+
+def make_pad_mask(lengths, xs=None, length_dim=-1):
+    """Make mask tensor containing indices of padded part.
+    Args:
+        lengths (LongTensor or List): Batch of lengths (B,).
+        xs (Tensor, optional): The reference tensor.
+            If set, masks will be the same shape as this tensor.
+        length_dim (int, optional): Dimension indicator of the above tensor.
+            See the example.
+    Returns:
+        Tensor: Mask tensor containing indices of padded part.
+            dtype=torch.uint8 in PyTorch 1.2-
+            dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+    Examples:
+        With only lengths.
+        >>> lengths = [5, 3, 2]
+        >>> make_non_pad_mask(lengths)
+        masks = [[0, 0, 0, 0 ,0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+        With the reference tensor.
+        >>> xs = torch.zeros((3, 2, 4))
+        >>> make_pad_mask(lengths, xs)
+        tensor([[[0, 0, 0, 0],
+                 [0, 0, 0, 0]],
+                [[0, 0, 0, 1],
+                 [0, 0, 0, 1]],
+                [[0, 0, 1, 1],
+                 [0, 0, 1, 1]]], dtype=torch.uint8)
+        >>> xs = torch.zeros((3, 2, 6))
+        >>> make_pad_mask(lengths, xs)
+        tensor([[[0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1]],
+                [[0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1]],
+                [[0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+        With the reference tensor and dimension indicator.
+        >>> xs = torch.zeros((3, 6, 6))
+        >>> make_pad_mask(lengths, xs, 1)
+        tensor([[[0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [1, 1, 1, 1, 1, 1]],
+                [[0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1]],
+                [[0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
+        >>> make_pad_mask(lengths, xs, 2)
+        tensor([[[0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1]],
+                [[0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1]],
+                [[0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+    """
+    if length_dim == 0:
+        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+    if not isinstance(lengths, list):
+        lengths = lengths.tolist()
+    bs = int(len(lengths))
+    if xs is None:
+        maxlen = int(max(lengths))
+    else:
+        maxlen = xs.size(length_dim)
+
+    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+
+    if xs is not None:
+        assert xs.size(0) == bs, (xs.size(0), bs)
+
+        if length_dim < 0:
+            length_dim = xs.dim() + length_dim
+        # ind = (:, None, ..., None, :, , None, ..., None)
+        ind = tuple(
+            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+        )
+        mask = mask[ind].expand_as(xs).to(xs.device)
+    return mask
+
+
+def make_non_pad_mask(lengths, xs=None, length_dim=-1):
+    """Make mask tensor containing indices of non-padded part.
+    Args:
+        lengths (LongTensor or List): Batch of lengths (B,).
+        xs (Tensor, optional): The reference tensor.
+            If set, masks will be the same shape as this tensor.
+        length_dim (int, optional): Dimension indicator of the above tensor.
+            See the example.
+    Returns:
+        ByteTensor: mask tensor containing indices of padded part.
+            dtype=torch.uint8 in PyTorch 1.2-
+            dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+    Examples:
+        With only lengths.
+        >>> lengths = [5, 3, 2]
+        >>> make_non_pad_mask(lengths)
+        masks = [[1, 1, 1, 1 ,1],
+                 [1, 1, 1, 0, 0],
+                 [1, 1, 0, 0, 0]]
+        With the reference tensor.
+        >>> xs = torch.zeros((3, 2, 4))
+        >>> make_non_pad_mask(lengths, xs)
+        tensor([[[1, 1, 1, 1],
+                 [1, 1, 1, 1]],
+                [[1, 1, 1, 0],
+                 [1, 1, 1, 0]],
+                [[1, 1, 0, 0],
+                 [1, 1, 0, 0]]], dtype=torch.uint8)
+        >>> xs = torch.zeros((3, 2, 6))
+        >>> make_non_pad_mask(lengths, xs)
+        tensor([[[1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0]],
+                [[1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0]],
+                [[1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+        With the reference tensor and dimension indicator.
+        >>> xs = torch.zeros((3, 6, 6))
+        >>> make_non_pad_mask(lengths, xs, 1)
+        tensor([[[1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [0, 0, 0, 0, 0, 0]],
+                [[1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0]],
+                [[1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
+        >>> make_non_pad_mask(lengths, xs, 2)
+        tensor([[[1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0]],
+                [[1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0]],
+                [[1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+    """
+    return ~make_pad_mask(lengths, xs, length_dim)
+
+
+def get_mask_from_lengths(lengths):
+    max_len = torch.max(lengths).item()
+    ids = torch.arange(0, max_len).to(lengths.device)
+    mask = (ids < lengths.unsqueeze(1)).bool()
+    return mask
+
+
+def group_hidden_by_segs(h, seg_ids, max_len):
+    """
+
+    :param h: [B, T, H]
+    :param seg_ids: [B, T]
+    :return: h_ph: [B, T_ph, H]
+    """
+    B, T, H = h.shape
+    h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
+    all_ones = h.new_ones(h.shape[:2])
+    cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
+    h_gby_segs = h_gby_segs[:, 1:]
+    cnt_gby_segs = cnt_gby_segs[:, 1:]
+    h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
+    return h_gby_segs, cnt_gby_segs
+
+def expand_by_repeat_times(source_encoding, lengths):
+    """
+    source_encoding: [T, C]
+    lengths, list of int, [T,], how many times each token should repeat
+    return:
+        expanded_encoding: [T_expand, C]
+    """
+    hid_dim = source_encoding.shape[1]
+    out2source = []
+    for i, length in enumerate(lengths):
+        out2source += [i for _ in range(length)]
+    out2source = torch.LongTensor(out2source).to(source_encoding.device)
+    out2source_ = out2source[:, None].repeat([1, hid_dim])
+    expanded_encoding = torch.gather(source_encoding, 0, out2source_)  # [B, T, H]
+    return expanded_encoding
+
+
+def expand_word2ph(word_encoding, ph2word):
+    word_encoding = F.pad(word_encoding,[0,0,1,0])
+    ph2word_ = ph2word[:, :, None].repeat([1, 1, word_encoding.shape[-1]])
+    out = torch.gather(word_encoding, 1, ph2word_)  # [B, T, H]
+    return out
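For orientation, a small sketch of how the padding-mask and segment-pooling helpers above behave; the import path is inferred from the file list (seq_utils.py under ar_dur/commons) and the tensors are purely illustrative.

import torch
from xinference.thirdparty.megatts3.tts.modules.ar_dur.commons.seq_utils import (
    group_hidden_by_segs,
    make_non_pad_mask,
    make_pad_mask,
)

# Padding masks from a batch of sequence lengths: True marks padded positions.
lengths = [5, 3, 2]
pad = make_pad_mask(lengths)          # shape [3, 5], bool
non_pad = make_non_pad_mask(lengths)  # elementwise negation of pad

# Mean-pool frame-level states into segment-level states.
# seg_ids are 1-based; index 0 is reserved (its row is dropped), so padded
# frames with seg_id 0 never contribute to a real segment.
h = torch.randn(1, 5, 4)                   # [B, T, H]
seg_ids = torch.tensor([[1, 1, 2, 2, 3]])  # [B, T]
h_seg, cnt = group_hidden_by_segs(h, seg_ids, max_len=3)
print(pad.shape, h_seg.shape, cnt)         # [3, 5], [1, 3, 4], tensor([[2., 2., 1.]])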