xinference 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in that registry.

Potentially problematic release.
Files changed (132)
  1. xinference/_compat.py +1 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +54 -1
  4. xinference/client/restful/restful_client.py +82 -2
  5. xinference/constants.py +3 -0
  6. xinference/core/chat_interface.py +297 -83
  7. xinference/core/model.py +24 -3
  8. xinference/core/progress_tracker.py +16 -8
  9. xinference/core/supervisor.py +51 -1
  10. xinference/core/worker.py +315 -47
  11. xinference/deploy/cmdline.py +33 -1
  12. xinference/model/audio/core.py +11 -1
  13. xinference/model/audio/megatts.py +105 -0
  14. xinference/model/audio/model_spec.json +24 -1
  15. xinference/model/audio/model_spec_modelscope.json +26 -1
  16. xinference/model/core.py +14 -0
  17. xinference/model/embedding/core.py +6 -1
  18. xinference/model/flexible/core.py +6 -1
  19. xinference/model/image/core.py +6 -1
  20. xinference/model/image/model_spec.json +17 -1
  21. xinference/model/image/model_spec_modelscope.json +17 -1
  22. xinference/model/llm/__init__.py +4 -6
  23. xinference/model/llm/core.py +5 -0
  24. xinference/model/llm/llama_cpp/core.py +46 -17
  25. xinference/model/llm/llm_family.json +530 -85
  26. xinference/model/llm/llm_family.py +24 -1
  27. xinference/model/llm/llm_family_modelscope.json +572 -1
  28. xinference/model/llm/mlx/core.py +16 -2
  29. xinference/model/llm/reasoning_parser.py +3 -3
  30. xinference/model/llm/sglang/core.py +111 -13
  31. xinference/model/llm/transformers/__init__.py +14 -0
  32. xinference/model/llm/transformers/core.py +31 -6
  33. xinference/model/llm/transformers/deepseek_vl.py +1 -1
  34. xinference/model/llm/transformers/deepseek_vl2.py +287 -0
  35. xinference/model/llm/transformers/gemma3.py +17 -2
  36. xinference/model/llm/transformers/intern_vl.py +28 -18
  37. xinference/model/llm/transformers/minicpmv26.py +21 -2
  38. xinference/model/llm/transformers/qwen-omni.py +308 -0
  39. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  40. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  41. xinference/model/llm/utils.py +37 -15
  42. xinference/model/llm/vllm/core.py +184 -8
  43. xinference/model/llm/vllm/distributed_executor.py +320 -0
  44. xinference/model/rerank/core.py +22 -12
  45. xinference/model/utils.py +118 -1
  46. xinference/model/video/core.py +6 -1
  47. xinference/thirdparty/deepseek_vl2/__init__.py +31 -0
  48. xinference/thirdparty/deepseek_vl2/models/__init__.py +26 -0
  49. xinference/thirdparty/deepseek_vl2/models/configuration_deepseek.py +210 -0
  50. xinference/thirdparty/deepseek_vl2/models/conversation.py +310 -0
  51. xinference/thirdparty/deepseek_vl2/models/modeling_deepseek.py +1975 -0
  52. xinference/thirdparty/deepseek_vl2/models/modeling_deepseek_vl_v2.py +697 -0
  53. xinference/thirdparty/deepseek_vl2/models/processing_deepseek_vl_v2.py +675 -0
  54. xinference/thirdparty/deepseek_vl2/models/siglip_vit.py +661 -0
  55. xinference/thirdparty/deepseek_vl2/serve/__init__.py +0 -0
  56. xinference/thirdparty/deepseek_vl2/serve/app_modules/__init__.py +0 -0
  57. xinference/thirdparty/deepseek_vl2/serve/app_modules/gradio_utils.py +83 -0
  58. xinference/thirdparty/deepseek_vl2/serve/app_modules/overwrites.py +81 -0
  59. xinference/thirdparty/deepseek_vl2/serve/app_modules/presets.py +115 -0
  60. xinference/thirdparty/deepseek_vl2/serve/app_modules/utils.py +333 -0
  61. xinference/thirdparty/deepseek_vl2/serve/assets/Kelpy-Codos.js +100 -0
  62. xinference/thirdparty/deepseek_vl2/serve/assets/avatar.png +0 -0
  63. xinference/thirdparty/deepseek_vl2/serve/assets/custom.css +355 -0
  64. xinference/thirdparty/deepseek_vl2/serve/assets/custom.js +22 -0
  65. xinference/thirdparty/deepseek_vl2/serve/assets/favicon.ico +0 -0
  66. xinference/thirdparty/deepseek_vl2/serve/assets/simsun.ttc +0 -0
  67. xinference/thirdparty/deepseek_vl2/serve/inference.py +197 -0
  68. xinference/thirdparty/deepseek_vl2/utils/__init__.py +18 -0
  69. xinference/thirdparty/deepseek_vl2/utils/io.py +80 -0
  70. xinference/thirdparty/megatts3/__init__.py +0 -0
  71. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  72. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  73. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  74. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  75. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  76. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  77. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  78. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  79. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  80. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  81. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  82. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  83. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  84. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  85. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  86. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  87. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  88. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  89. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  90. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  91. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  92. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  93. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  94. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  95. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  96. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  97. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  98. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  99. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  100. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  101. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  102. xinference/types.py +10 -0
  103. xinference/utils.py +54 -0
  104. xinference/web/ui/build/asset-manifest.json +6 -6
  105. xinference/web/ui/build/index.html +1 -1
  106. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  107. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  108. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  109. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  110. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  111. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  112. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  113. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  114. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  115. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  116. xinference/web/ui/src/locales/en.json +2 -1
  117. xinference/web/ui/src/locales/zh.json +2 -1
  118. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/METADATA +128 -115
  119. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/RECORD +124 -63
  120. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
  121. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  122. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  123. xinference/web/ui/build/static/js/main.3cea968e.js +0 -3
  124. xinference/web/ui/build/static/js/main.3cea968e.js.map +0 -1
  125. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  126. xinference/web/ui/node_modules/.cache/babel-loader/7f59e45e3f268ab8a4788b6fb024cf8dab088736dff22f5a3a39c122a83ab930.json +0 -1
  127. xinference/web/ui/node_modules/.cache/babel-loader/dcd60488509450bfff37bfff56de2c096d51de17dd00ec60d4db49c8b483ada1.json +0 -1
  128. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  129. /xinference/web/ui/build/static/js/{main.3cea968e.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  130. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
  131. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
  132. {xinference-1.4.0.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py
@@ -0,0 +1,342 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import defaultdict
+ import torch
+ import torch.nn.functional as F
+
+
+ def make_positions(tensor, padding_idx):
+     """Replace non-padding symbols with their position numbers.
+
+     Position numbers begin at padding_idx+1. Padding symbols are ignored.
+     """
+     # The series of casts and type-conversions here are carefully
+     # balanced to both work with ONNX export and XLA. In particular XLA
+     # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+     # how to handle the dtype kwarg in cumsum.
+     mask = tensor.ne(padding_idx).int()
+     return (
+         torch.cumsum(mask, dim=1).type_as(mask) * mask
+     ).long() + padding_idx
+
+
+ def softmax(x, dim):
+     return F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+ def sequence_mask(lengths, maxlen=None, dtype=torch.bool):
+     if maxlen is None:
+         maxlen = lengths.max()
+     mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
+     mask.type(dtype)
+     return mask
+
+
+ def weights_nonzero_speech(target):
+     # target : B x T x mel
+     # Assign weight 1.0 to all labels except for padding (id=0).
+     dim = target.size(-1)
+     return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
+
+
+ INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
+
+
+ def _get_full_incremental_state_key(module_instance, key):
+     module_name = module_instance.__class__.__name__
+
+     # assign a unique ID to each module instance, so that incremental state is
+     # not shared across module instances
+     if not hasattr(module_instance, '_instance_id'):
+         INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
+         module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
+
+     return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)
+
+
+ def get_incremental_state(module, incremental_state, key):
+     """Helper for getting incremental state for an nn.Module."""
+     full_key = _get_full_incremental_state_key(module, key)
+     if incremental_state is None or full_key not in incremental_state:
+         return None
+     return incremental_state[full_key]
+
+
+ def set_incremental_state(module, incremental_state, key, value):
+     """Helper for setting incremental state for an nn.Module."""
+     if incremental_state is not None:
+         full_key = _get_full_incremental_state_key(module, key)
+         incremental_state[full_key] = value
+
+
+ def fill_with_neg_inf(t):
+     """FP16-compatible function that fills a tensor with -inf."""
+     return t.float().fill_(float('-inf')).type_as(t)
+
+
+ def fill_with_neg_inf2(t):
+     """FP16-compatible function that fills a tensor with -inf."""
+     return t.float().fill_(-1e8).type_as(t)
+
+
+ def select_attn(attn_logits, type='best'):
+     """
+
+     :param attn_logits: [n_layers, B, n_head, T_sp, T_txt]
+     :return:
+     """
+     encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
+     # [n_layers * n_head, B, T_sp, T_txt]
+     encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
+     if type == 'best':
+         indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
+         encdec_attn = encdec_attn.gather(
+             0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
+         return encdec_attn
+     elif type == 'mean':
+         return encdec_attn.mean(0)
+
+
+ def make_pad_mask(lengths, xs=None, length_dim=-1):
+     """Make mask tensor containing indices of padded part.
+     Args:
+         lengths (LongTensor or List): Batch of lengths (B,).
+         xs (Tensor, optional): The reference tensor.
+             If set, masks will be the same shape as this tensor.
+         length_dim (int, optional): Dimension indicator of the above tensor.
+             See the example.
+     Returns:
+         Tensor: Mask tensor containing indices of padded part.
+             dtype=torch.uint8 in PyTorch 1.2-
+             dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+     Examples:
+         With only lengths.
+         >>> lengths = [5, 3, 2]
+         >>> make_non_pad_mask(lengths)
+         masks = [[0, 0, 0, 0 ,0],
+                  [0, 0, 0, 1, 1],
+                  [0, 0, 1, 1, 1]]
+         With the reference tensor.
+         >>> xs = torch.zeros((3, 2, 4))
+         >>> make_pad_mask(lengths, xs)
+         tensor([[[0, 0, 0, 0],
+                  [0, 0, 0, 0]],
+                 [[0, 0, 0, 1],
+                  [0, 0, 0, 1]],
+                 [[0, 0, 1, 1],
+                  [0, 0, 1, 1]]], dtype=torch.uint8)
+         >>> xs = torch.zeros((3, 2, 6))
+         >>> make_pad_mask(lengths, xs)
+         tensor([[[0, 0, 0, 0, 0, 1],
+                  [0, 0, 0, 0, 0, 1]],
+                 [[0, 0, 0, 1, 1, 1],
+                  [0, 0, 0, 1, 1, 1]],
+                 [[0, 0, 1, 1, 1, 1],
+                  [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+         With the reference tensor and dimension indicator.
+         >>> xs = torch.zeros((3, 6, 6))
+         >>> make_pad_mask(lengths, xs, 1)
+         tensor([[[0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [1, 1, 1, 1, 1, 1]],
+                 [[0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1]],
+                 [[0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
+         >>> make_pad_mask(lengths, xs, 2)
+         tensor([[[0, 0, 0, 0, 0, 1],
+                  [0, 0, 0, 0, 0, 1],
+                  [0, 0, 0, 0, 0, 1],
+                  [0, 0, 0, 0, 0, 1],
+                  [0, 0, 0, 0, 0, 1],
+                  [0, 0, 0, 0, 0, 1]],
+                 [[0, 0, 0, 1, 1, 1],
+                  [0, 0, 0, 1, 1, 1],
+                  [0, 0, 0, 1, 1, 1],
+                  [0, 0, 0, 1, 1, 1],
+                  [0, 0, 0, 1, 1, 1],
+                  [0, 0, 0, 1, 1, 1]],
+                 [[0, 0, 1, 1, 1, 1],
+                  [0, 0, 1, 1, 1, 1],
+                  [0, 0, 1, 1, 1, 1],
+                  [0, 0, 1, 1, 1, 1],
+                  [0, 0, 1, 1, 1, 1],
+                  [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+     """
+     if length_dim == 0:
+         raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+     if not isinstance(lengths, list):
+         lengths = lengths.tolist()
+     bs = int(len(lengths))
+     if xs is None:
+         maxlen = int(max(lengths))
+     else:
+         maxlen = xs.size(length_dim)
+
+     seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+     seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+     seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+     mask = seq_range_expand >= seq_length_expand
+
+     if xs is not None:
+         assert xs.size(0) == bs, (xs.size(0), bs)
+
+         if length_dim < 0:
+             length_dim = xs.dim() + length_dim
+         # ind = (:, None, ..., None, :, , None, ..., None)
+         ind = tuple(
+             slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+         )
+         mask = mask[ind].expand_as(xs).to(xs.device)
+     return mask
+
+
+ def make_non_pad_mask(lengths, xs=None, length_dim=-1):
+     """Make mask tensor containing indices of non-padded part.
+     Args:
+         lengths (LongTensor or List): Batch of lengths (B,).
+         xs (Tensor, optional): The reference tensor.
+             If set, masks will be the same shape as this tensor.
+         length_dim (int, optional): Dimension indicator of the above tensor.
+             See the example.
+     Returns:
+         ByteTensor: mask tensor containing indices of padded part.
+             dtype=torch.uint8 in PyTorch 1.2-
+             dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+     Examples:
+         With only lengths.
+         >>> lengths = [5, 3, 2]
+         >>> make_non_pad_mask(lengths)
+         masks = [[1, 1, 1, 1 ,1],
+                  [1, 1, 1, 0, 0],
+                  [1, 1, 0, 0, 0]]
+         With the reference tensor.
+         >>> xs = torch.zeros((3, 2, 4))
+         >>> make_non_pad_mask(lengths, xs)
+         tensor([[[1, 1, 1, 1],
+                  [1, 1, 1, 1]],
+                 [[1, 1, 1, 0],
+                  [1, 1, 1, 0]],
+                 [[1, 1, 0, 0],
+                  [1, 1, 0, 0]]], dtype=torch.uint8)
+         >>> xs = torch.zeros((3, 2, 6))
+         >>> make_non_pad_mask(lengths, xs)
+         tensor([[[1, 1, 1, 1, 1, 0],
+                  [1, 1, 1, 1, 1, 0]],
+                 [[1, 1, 1, 0, 0, 0],
+                  [1, 1, 1, 0, 0, 0]],
+                 [[1, 1, 0, 0, 0, 0],
+                  [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+         With the reference tensor and dimension indicator.
+         >>> xs = torch.zeros((3, 6, 6))
+         >>> make_non_pad_mask(lengths, xs, 1)
+         tensor([[[1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [0, 0, 0, 0, 0, 0]],
+                 [[1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0]],
+                 [[1, 1, 1, 1, 1, 1],
+                  [1, 1, 1, 1, 1, 1],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0],
+                  [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
+         >>> make_non_pad_mask(lengths, xs, 2)
+         tensor([[[1, 1, 1, 1, 1, 0],
+                  [1, 1, 1, 1, 1, 0],
+                  [1, 1, 1, 1, 1, 0],
+                  [1, 1, 1, 1, 1, 0],
+                  [1, 1, 1, 1, 1, 0],
+                  [1, 1, 1, 1, 1, 0]],
+                 [[1, 1, 1, 0, 0, 0],
+                  [1, 1, 1, 0, 0, 0],
+                  [1, 1, 1, 0, 0, 0],
+                  [1, 1, 1, 0, 0, 0],
+                  [1, 1, 1, 0, 0, 0],
+                  [1, 1, 1, 0, 0, 0]],
+                 [[1, 1, 0, 0, 0, 0],
+                  [1, 1, 0, 0, 0, 0],
+                  [1, 1, 0, 0, 0, 0],
+                  [1, 1, 0, 0, 0, 0],
+                  [1, 1, 0, 0, 0, 0],
+                  [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+     """
+     return ~make_pad_mask(lengths, xs, length_dim)
+
+
+ def get_mask_from_lengths(lengths):
+     max_len = torch.max(lengths).item()
+     ids = torch.arange(0, max_len).to(lengths.device)
+     mask = (ids < lengths.unsqueeze(1)).bool()
+     return mask
+
+
+ def group_hidden_by_segs(h, seg_ids, max_len):
+     """
+
+     :param h: [B, T, H]
+     :param seg_ids: [B, T]
+     :return: h_ph: [B, T_ph, H]
+     """
+     B, T, H = h.shape
+     h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
+     all_ones = h.new_ones(h.shape[:2])
+     cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
+     h_gby_segs = h_gby_segs[:, 1:]
+     cnt_gby_segs = cnt_gby_segs[:, 1:]
+     h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
+     return h_gby_segs, cnt_gby_segs
+
+ def expand_by_repeat_times(source_encoding, lengths):
+     """
+     source_encoding: [T, C]
+     lengths, list of int, [T,], how many times each token should repeat
+     return:
+         expanded_encoding: [T_expand, C]
+     """
+     hid_dim = source_encoding.shape[1]
+     out2source = []
+     for i, length in enumerate(lengths):
+         out2source += [i for _ in range(length)]
+     out2source = torch.LongTensor(out2source).to(source_encoding.device)
+     out2source_ = out2source[:, None].repeat([1, hid_dim])
+     expanded_encoding = torch.gather(source_encoding, 0, out2source_) # [B, T, H]
+     return expanded_encoding
+
+
+ def expand_word2ph(word_encoding, ph2word):
+     word_encoding = F.pad(word_encoding,[0,0,1,0])
+     ph2word_ = ph2word[:, :, None].repeat([1, 1, word_encoding.shape[-1]])
+     out = torch.gather(word_encoding, 1, ph2word_) # [B, T, H]
+     return out
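
As a quick orientation to the vendored module above, which is mostly small tensor utilities, the sketch below exercises the padding-mask helpers on a toy batch. It is illustrative only and not part of the release diff; the import path is an assumption inferred from the file listing, so adjust it to however the helpers are exposed in your environment.

# Illustrative sketch only -- not part of the 1.5.0 diff. The import path is
# an assumption based on the file listing above.
import torch

from xinference.thirdparty.megatts3.tts.modules.ar_dur.commons.seq_utils import (
    make_non_pad_mask,
    make_pad_mask,
    make_positions,
)

# Three sequences of lengths 5, 3 and 2, padded to the batch maximum of 5.
lengths = torch.tensor([5, 3, 2])

pad_mask = make_pad_mask(lengths)  # True at padded positions
print(pad_mask.int())
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)

print(make_non_pad_mask(lengths).int())  # logical inverse of make_pad_mask
# tensor([[1, 1, 1, 1, 1],
#         [1, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0]], dtype=torch.int32)

# make_positions numbers non-padding tokens starting at padding_idx + 1 and
# leaves padding positions at padding_idx.
tokens = torch.tensor([[7, 8, 9, 0, 0]])  # 0 is the padding id here
print(make_positions(tokens, padding_idx=0))
# tensor([[1, 2, 3, 0, 0]])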