bithuman-1.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of bithuman might be problematic.

Files changed (44)
  1. bithuman/__init__.py +13 -0
  2. bithuman/_version.py +1 -0
  3. bithuman/api.py +164 -0
  4. bithuman/audio/__init__.py +19 -0
  5. bithuman/audio/audio.py +396 -0
  6. bithuman/audio/hparams.py +108 -0
  7. bithuman/audio/utils.py +255 -0
  8. bithuman/config.py +88 -0
  9. bithuman/engine/__init__.py +15 -0
  10. bithuman/engine/auth.py +335 -0
  11. bithuman/engine/compression.py +257 -0
  12. bithuman/engine/enums.py +16 -0
  13. bithuman/engine/image_ops.py +192 -0
  14. bithuman/engine/inference.py +108 -0
  15. bithuman/engine/knn.py +58 -0
  16. bithuman/engine/video_data.py +391 -0
  17. bithuman/engine/video_reader.py +168 -0
  18. bithuman/lib/__init__.py +1 -0
  19. bithuman/lib/audio_encoder.onnx +45631 -28
  20. bithuman/lib/generator.py +763 -0
  21. bithuman/lib/pth2h5.py +106 -0
  22. bithuman/plugins/__init__.py +0 -0
  23. bithuman/plugins/stt.py +185 -0
  24. bithuman/runtime.py +1004 -0
  25. bithuman/runtime_async.py +469 -0
  26. bithuman/service/__init__.py +9 -0
  27. bithuman/service/client.py +788 -0
  28. bithuman/service/messages.py +210 -0
  29. bithuman/service/server.py +759 -0
  30. bithuman/utils/__init__.py +43 -0
  31. bithuman/utils/agent.py +359 -0
  32. bithuman/utils/fps_controller.py +90 -0
  33. bithuman/utils/image.py +41 -0
  34. bithuman/utils/unzip.py +38 -0
  35. bithuman/video_graph/__init__.py +16 -0
  36. bithuman/video_graph/action_trigger.py +83 -0
  37. bithuman/video_graph/driver_video.py +482 -0
  38. bithuman/video_graph/navigator.py +736 -0
  39. bithuman/video_graph/trigger.py +90 -0
  40. bithuman/video_graph/video_script.py +344 -0
  41. bithuman-1.0.2.dist-info/METADATA +37 -0
  42. bithuman-1.0.2.dist-info/RECORD +44 -0
  43. bithuman-1.0.2.dist-info/WHEEL +5 -0
  44. bithuman-1.0.2.dist-info/top_level.txt +1 -0
bithuman/__init__.py ADDED
@@ -0,0 +1,13 @@
+ from ._version import __version__
+ from .api import AudioChunk, VideoControl, VideoFrame
+ from .runtime import Bithuman
+ from .runtime_async import AsyncBithuman
+
+ __all__ = [
+     "__version__",
+     "Bithuman",
+     "AsyncBithuman",
+     "AudioChunk",
+     "VideoControl",
+     "VideoFrame",
+ ]
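
The re-exports above define the package's public surface. As a quick orientation, a minimal import sketch, assuming the wheel is installed in the current environment:

    from bithuman import (
        AsyncBithuman,
        AudioChunk,
        Bithuman,
        VideoControl,
        VideoFrame,
        __version__,
    )

    print(__version__)  # "1.0.2"
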
bithuman/_version.py ADDED
@@ -0,0 +1 @@
+ __version__ = "1.0.2"
bithuman/api.py ADDED
@@ -0,0 +1,164 @@
+ """API for Bithuman Runtime."""
+
+ from __future__ import annotations
+
+ import uuid
+ from dataclasses import dataclass, field
+ from enum import Enum
+ from typing import Hashable, List, Optional
+
+ import numpy as np
+ from dataclasses_json import dataclass_json
+
+ INT16_MAX = 2**15 - 1
+
+
+ @dataclass
+ class AudioChunk:
+     """Data class to store audio data.
+
+     Attributes:
+         data (np.ndarray): Audio data in int16 format with 1 channel, shape (n_samples,).
+         sample_rate (int): Sample rate of the audio data.
+         last_chunk (bool): Whether this is the last chunk of the speech.
+     """
+
+     data: np.ndarray
+     sample_rate: int
+     last_chunk: bool = True
+
+     def __post_init__(self) -> None:
+         if isinstance(self.data, bytes):
+             self.data = np.frombuffer(self.data, dtype=np.int16)
+
+         if self.data.dtype in [np.float32, np.float64]:
+             self.data = (self.data * INT16_MAX).astype(np.int16)
+
+         if not isinstance(self.data, np.ndarray) or self.data.dtype != np.int16:
+             raise ValueError("data must be in int16 numpy array format")
+
+     @property
+     def bytes(self) -> bytes:
+         return self.data.tobytes()
+
+     @property
+     def array(self) -> np.ndarray:
+         return self.data
+
+     @property
+     def duration(self) -> float:
+         return len(self.data) / self.sample_rate
+
+     @classmethod
+     def from_bytes(
+         cls,
+         audio_bytes: bytes,
+         sample_rate: int,
+         last_chunk: bool = True,
+     ) -> "AudioChunk":
+         return cls(
+             data=np.frombuffer(audio_bytes, dtype=np.int16),
+             sample_rate=sample_rate,
+             last_chunk=last_chunk,
+         )
+
+
+ @dataclass
+ class VideoControl:
+     """Dataclass for video control information."""
+
+     audio: Optional[AudioChunk] = None
+     text: Optional[str] = None
+     target_video: Optional[str] = None
+     action: Optional[str | List[str]] = None
+     emotion_preds: Optional[List["EmotionPrediction"]] = None
+     message_id: str = field(default_factory=lambda: str(uuid.uuid4()))
+     end_of_speech: bool = False  # mark the end of the speech
+     force_action: bool = False  # force the action to be played
+     stop_on_user_speech: Optional[bool] = None  # Override stop_on_user_speech from video config
+     stop_on_agent_speech: Optional[bool] = None  # Override stop_on_agent_speech from video config
+
+     @property
+     def is_idle(self) -> bool:
+         """Check if the audio or control is idle."""
+         return (
+             self.audio is None
+             and self.target_video is None
+             and not self.action
+             and not self.emotion_preds
+         )
+
+     @property
+     def is_speaking(self) -> bool:
+         """Check if the audio or control is speaking."""
+         return self.audio is not None
+
+     @classmethod
+     def from_audio(
+         cls, audio: bytes | np.ndarray, sample_rate: int, last_chunk: bool = True
+     ) -> "VideoControl":
+         if isinstance(audio, bytes):
+             audio_chunk = AudioChunk.from_bytes(audio, sample_rate, last_chunk)
+         elif isinstance(audio, np.ndarray):
+             audio_chunk = AudioChunk(
+                 data=audio, sample_rate=sample_rate, last_chunk=last_chunk
+             )
+         else:
+             raise ValueError("audio must be bytes or numpy array")
+         return cls(audio=audio_chunk)
+
+
+ @dataclass
+ class VideoFrame:
+     """
+     Dataclass for frame information.
+
+     Attributes:
+         bgr_image: image data in uint8 BGR format, shape (H, W, 3).
+         audio_chunk: audio chunk if the frame is talking,
+             chunked to the duration of this video frame based on FPS.
+         frame_index: frame index generated for this video control.
+         source_message_id: message id of the video control.
+     """
+
+     bgr_image: Optional[np.ndarray] = None
+     audio_chunk: Optional[AudioChunk] = None
+     frame_index: Optional[int] = None
+     source_message_id: Optional[Hashable] = None
+     end_of_speech: bool = False  # mark the end of the speech
+
+     @property
+     def has_image(self) -> bool:
+         return self.bgr_image is not None
+
+     @property
+     def rgb_image(self) -> Optional[np.ndarray]:
+         if self.bgr_image is None:
+             return None
+         return self.bgr_image[:, :, ::-1]
+
+
+ class Emotion(str, Enum):
+     """Enumeration representing different emotions."""
+
+     ANGER = "anger"
+     DISGUST = "disgust"
+     FEAR = "fear"
+     JOY = "joy"
+     NEUTRAL = "neutral"
+     SADNESS = "sadness"
+     SURPRISE = "surprise"
+
+
+ @dataclass_json
+ @dataclass
+ class EmotionPrediction:
+     """Dataclass for emotion prediction.
+
+     Attributes:
+         emotion: The emotion.
+         score: The score for the emotion.
+     """
+
+     emotion: Emotion
+     score: float
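
To see how these dataclasses compose, here is a minimal sketch based only on the definitions above; the 16 kHz mono sample rate is an assumption for illustration, matching the rate used elsewhere in this package:

    import numpy as np

    from bithuman.api import AudioChunk, VideoControl

    # One second of silence as int16 PCM (16 kHz mono is assumed here).
    pcm = np.zeros(16000, dtype=np.int16)

    chunk = AudioChunk(data=pcm, sample_rate=16000, last_chunk=True)
    print(chunk.duration)  # 1.0 seconds

    # Wrap the audio in a control message for the runtime.
    control = VideoControl.from_audio(pcm, sample_rate=16000)
    print(control.is_speaking, control.is_idle)  # True False
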
bithuman/audio/__init__.py ADDED
@@ -0,0 +1,19 @@
+ from .audio import get_mel_chunks
+ from .utils import (
+     AudioStreamBatcher,
+     float32_to_int16,
+     int16_to_float32,
+     load_audio,
+     resample,
+     write_video_with_audio,
+ )
+
+ __all__ = [
+     "get_mel_chunks",
+     "load_audio",
+     "resample",
+     "write_video_with_audio",
+     "AudioStreamBatcher",
+     "float32_to_int16",
+     "int16_to_float32",
+ ]
bithuman/audio/audio.py ADDED
@@ -0,0 +1,396 @@
+ from __future__ import annotations
+
+ import librosa
+ import librosa.filters
+ import numba
+ import numpy as np
+ from scipy import signal
+ from scipy.io import wavfile
+
+ from .hparams import hparams as hp
+
+
+ def get_mel_chunks(
+     audio_np: np.ndarray, fps: int, mel_step_size: int = 16
+ ) -> list[np.ndarray]:
+     """
+     Get mel chunks from audio.
+
+     Args:
+         audio_np: The audio data as a numpy array, `np.float32` at 16 kHz with 1 channel.
+         fps: The frames per second of the video.
+         mel_step_size: The step size for mel spectrogram chunks.
+
+     Returns:
+         The mel chunks.
+     """
+     # Convert the audio waveform to a mel spectrogram
+
+     mel_chunks = []
+     if audio_np is None or len(audio_np) == 0:
+         return mel_chunks
+
+     mel = melspectrogram(audio_np)
+
+     # Check for NaN values in the mel spectrogram, which could indicate issues with audio processing
+     if np.isnan(mel.reshape(-1)).sum() > 0:
+         raise ValueError(
+             "Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again"
+         )
+
+     # Split the mel spectrogram into smaller chunks suitable for processing
+     mel_idx_multiplier = 80.0 / fps  # Calculate index multiplier based on frame rate
+     i = 0
+     while True:
+         start_idx = int(i * mel_idx_multiplier)
+         if start_idx + mel_step_size > len(mel[0]):  # Handle last chunk
+             mel_chunks.append(mel[:, len(mel[0]) - mel_step_size :])
+             break
+         mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
+         i += 1
+
+     return mel_chunks
+
+
+ def save_wav(wav, path, sr):
+     """
+     Save a wav file.
+
+     Args:
+         wav (np.ndarray): The wav file to save, as a numpy array.
+         path (str): The path to save the wav file to.
+         sr (int): The sample rate to use when saving the wav file.
+     """
+     wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+     wavfile.write(path, sr, wav.astype(np.int16))
+
+
+ def preemphasis(wav, k, preemphasize=True):
+     """
+     Apply preemphasis to a wav file.
+
+     Args:
+         wav (np.ndarray): The wav file to apply preemphasis to, as a numpy array.
+         k (float): The preemphasis coefficient.
+         preemphasize (bool, optional): Whether to apply preemphasis. Defaults to True.
+
+     Returns:
+         np.ndarray: The wav file with preemphasis applied.
+     """
+     if preemphasize:
+         return signal.lfilter([1, -k], [1], wav)
+     return wav
+
+
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
+     """
+     Apply inverse preemphasis to a wav file.
+
+     Args:
+         wav (np.ndarray): The wav file to apply inverse preemphasis to, as a numpy array.
+         k (float): The preemphasis coefficient.
+         inv_preemphasize (bool, optional): Whether to apply inverse preemphasis. Defaults to True.
+
+     Returns:
+         np.ndarray: The wav file with inverse preemphasis applied.
+     """
+     if inv_preemphasize:
+         return signal.lfilter([1], [1, -k], wav)
+     return wav
+
+
+ def get_hop_size():
+     """
+     Get the hop size.
+
+     Returns:
+         int: The hop size.
+     """
+     hop_size = hp.hop_size
+     if hop_size is None:
+         assert hp.frame_shift_ms is not None
+         hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
+     return hop_size
+
+
+ def linearspectrogram(wav):
+     """
+     Compute the linear spectrogram of a wav file.
+
+     Args:
+         wav (np.ndarray): The wav file to compute the spectrogram of, as a numpy array.
+
+     Returns:
+         np.ndarray: The linear spectrogram of the wav file.
+     """
+     D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+     S = _amp_to_db(np.abs(D)) - hp.ref_level_db
+
+     if hp.signal_normalization:
+         return _normalize(S)
+     return S
+
+
+ def melspectrogram(wav):
+     """
+     Compute the mel spectrogram of a wav file.
+
+     Args:
+         wav (np.ndarray): The wav file to compute the spectrogram of, as a numpy array.
+
+     Returns:
+         np.ndarray: The mel spectrogram of the wav file.
+     """
+     D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+     S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
+
+     if hp.signal_normalization:
+         return _normalize(S)
+     return S
+
+
+ def _lws_processor():
+     # Third-Party Libraries
+     import lws
+
+     return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
+
+
+ def _stft(y):
+     """
+     Compute the short-time Fourier transform of a signal.
+
+     Args:
+         y (np.ndarray): The signal to compute the transform of, as a numpy array.
+
+     Returns:
+         np.ndarray: The short-time Fourier transform of the signal.
+     """
+     if hp.use_lws:
+         return _lws_processor().stft(y).T
+     else:
+         return librosa.stft(
+             y=y,
+             n_fft=hp.n_fft,
+             hop_length=get_hop_size(),
+             win_length=hp.win_size,
+         )
+
+
+ ##########################################################
+ # These are only correct when using lws! (This was messing with WaveNet quality for a long time.)
+ def num_frames(length, fsize, fshift):
+     """
+     Compute the number of time frames of a spectrogram.
+
+     Args:
+         length: length of the signal
+         fsize: frame size
+         fshift: frame shift
+     """
+     pad = fsize - fshift
+     if length % fshift == 0:
+         M = (length + pad * 2 - fsize) // fshift + 1
+     else:
+         M = (length + pad * 2 - fsize) // fshift + 2
+     return M
+
+
+ def pad_lr(x, fsize, fshift):
+     """
+     Compute left and right padding.
+
+     Args:
+         x: input signal
+         fsize: frame size
+         fshift: frame shift
+     """
+     M = num_frames(len(x), fsize, fshift)
+     pad = fsize - fshift
+     T = len(x) + 2 * pad
+     r = (M - 1) * fshift + fsize - T
+     return pad, pad + r
+
+
+ ##########################################################
+ # Librosa correct padding
+ def librosa_pad_lr(x, fsize, fshift):
+     """
+     Compute left and right padding for librosa.
+
+     Args:
+         x: input signal
+         fsize: frame size
+         fshift: frame shift
+     """
+     return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+
+
+ # Conversions
+ _mel_basis = None
+ _mel_basis_cache = None
+ _output_cache = None
+
+
+ @numba.jit(nopython=True, cache=True)
+ def _dot_mel(mel_basis, spectogram):
+     return np.dot(mel_basis, spectogram)
+
+
+ def _linear_to_mel_numba(spectogram):
+     """
+     Convert a linear spectrogram to a mel spectrogram using optimized matrix multiplication.
+
+     Args:
+         spectogram (np.ndarray): The linear spectrogram to convert, shape around [401, 22].
+
+     Returns:
+         np.ndarray: The mel spectrogram, shape [80, 22].
+     """
+     global _mel_basis
+     if _mel_basis is None:
+         _mel_basis = _build_mel_basis()
+         _mel_basis = np.ascontiguousarray(_mel_basis, dtype=np.float32)
+
+     spectogram = np.ascontiguousarray(spectogram, dtype=np.float32)
+     return _dot_mel(_mel_basis, spectogram)
+
+
+ @numba.jit(nopython=True, cache=True, parallel=False, fastmath=True)
+ def _dot_mel_hybrid(mel_basis, spectogram, out):
+     """Hybrid approach using vectorized operations with controlled CPU usage."""
+     # Process in smaller chunks to reduce CPU pressure
+     chunk_size = 16  # Adjust this value to balance speed/CPU usage
+
+     for j in range(0, spectogram.shape[1], chunk_size):
+         end_j = min(j + chunk_size, spectogram.shape[1])
+         # Use numpy's dot for each chunk
+         out[:, j:end_j] = mel_basis.dot(np.ascontiguousarray(spectogram[:, j:end_j]))
+
+     return out
+
+
+ def _linear_to_mel_hybrid(spectogram):
+     """Hybrid implementation balancing speed and CPU usage."""
+     global _mel_basis_cache, _output_cache
+
+     if _mel_basis_cache is None:
+         _mel_basis_cache = _build_mel_basis()
+         _mel_basis_cache = np.ascontiguousarray(_mel_basis_cache, dtype=np.float32)
+
+     spectogram = np.ascontiguousarray(spectogram, dtype=np.float32)
+
+     # Reuse the output array across calls
+     if _output_cache is None or _output_cache.shape != (
+         _mel_basis_cache.shape[0],
+         spectogram.shape[1],
+     ):
+         _output_cache = np.empty(
+             (_mel_basis_cache.shape[0], spectogram.shape[1]), dtype=np.float32
+         )
+
+     return _dot_mel_hybrid(_mel_basis_cache, spectogram, _output_cache)
+
+
+ _linear_to_mel = _linear_to_mel_hybrid
+
+
+ def _build_mel_basis():
+     """
+     Build the mel basis for converting linear spectrograms to mel spectrograms.
+
+     Returns:
+         np.ndarray: The mel basis.
+     """
+     assert hp.fmax <= hp.sample_rate // 2
+     return librosa.filters.mel(
+         sr=hp.sample_rate,
+         n_fft=hp.n_fft,
+         n_mels=hp.num_mels,
+         fmin=hp.fmin,
+         fmax=hp.fmax,
+     )
+
+
+ def _amp_to_db(x):
+     """
+     Convert amplitude to decibels.
+
+     Args:
+         x (np.ndarray): The amplitude to convert, as a numpy array.
+
+     Returns:
+         np.ndarray: The decibels.
+     """
+     min_level = np.float32(np.exp(hp.min_level_db / 20 * np.log(10)))
+     return np.float32(20) * np.log10(np.maximum(min_level, x)).astype(np.float32)
+
+
+ def _db_to_amp(x):
+     return np.power(10.0, (x) * 0.05)
+
+
+ def _normalize(S):
+     """
+     Normalize a spectrogram.
+
+     Args:
+         S (np.ndarray): The spectrogram to normalize, as a numpy array.
+
+     Returns:
+         np.ndarray: The normalized spectrogram.
+     """
+     if hp.allow_clipping_in_normalization:
+         if hp.symmetric_mels:
+             return np.clip(
+                 (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db))
+                 - hp.max_abs_value,
+                 -hp.max_abs_value,
+                 hp.max_abs_value,
+             )
+         else:
+             return np.clip(
+                 hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)),
+                 0,
+                 hp.max_abs_value,
+             )
+
+     assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
+     if hp.symmetric_mels:
+         return (2 * hp.max_abs_value) * (
+             (S - hp.min_level_db) / (-hp.min_level_db)
+         ) - hp.max_abs_value
+     else:
+         return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
+
+
+ def _denormalize(D):
+     """
+     Denormalize a spectrogram.
+
+     Args:
+         D: the spectrogram to denormalize
+
+     Returns:
+         The denormalized spectrogram.
+     """
+     if hp.allow_clipping_in_normalization:
+         if hp.symmetric_mels:
+             return (
+                 (np.clip(D, -hp.max_abs_value, hp.max_abs_value) + hp.max_abs_value)
+                 * -hp.min_level_db
+                 / (2 * hp.max_abs_value)
+             ) + hp.min_level_db
+         else:
+             return (
+                 np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value
+             ) + hp.min_level_db
+
+     if hp.symmetric_mels:
+         return (
+             (D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)
+         ) + hp.min_level_db
+     else:
+         return (D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db
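
As a usage note for get_mel_chunks above, a short sketch with a synthetic input (the 440 Hz tone and 25 fps are assumptions for illustration; the function expects float32 mono audio at hp.sample_rate, i.e. 16 kHz):

    import numpy as np

    from bithuman.audio import get_mel_chunks

    # Half a second of a 440 Hz tone as float32 at 16 kHz.
    t = np.linspace(0, 0.5, 8000, endpoint=False, dtype=np.float32)
    wave = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

    # Roughly one mel chunk per video frame at 25 fps; each chunk has shape
    # (num_mels, mel_step_size) = (80, 16) with the default hyperparameters.
    chunks = get_mel_chunks(wave, fps=25)
    print(len(chunks), chunks[0].shape)
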
bithuman/audio/hparams.py ADDED
@@ -0,0 +1,108 @@
+ """hparams module."""
+
+
+ class HParams:
+     """
+     A class for managing hyperparameters.
+
+     This class provides an interface for getting and setting hyperparameters.
+     Hyperparameters are stored in a dictionary.
+     """
+
+     def __init__(self, **kwargs):
+         """
+         Initialize an HParams instance.
+
+         Args:
+             **kwargs: The initial hyperparameters, as keyword arguments.
+         """
+         self.data = {}
+
+         for key, value in kwargs.items():
+             self.data[key] = value
+
+     def __getattr__(self, key):
+         """
+         Get a hyperparameter.
+
+         Args:
+             key (str): The name of the hyperparameter.
+
+         Returns:
+             The value of the hyperparameter.
+
+         Raises:
+             AttributeError: If the hyperparameter does not exist.
+         """
+         if key not in self.data:
+             raise AttributeError("'HParams' object has no attribute %s" % key)
+         return self.data[key]
+
+     def set_hparam(self, key, value):
+         """
+         Set a hyperparameter.
+
+         Args:
+             key (str): The name of the hyperparameter.
+             value: The new value of the hyperparameter.
+         """
+         self.data[key] = value
+
+
+ # Default hyperparameters
+ hparams = HParams(
+     num_mels=80,  # Number of mel-spectrogram channels and local conditioning dimensionality
+     # network
+     rescale=True,  # Whether to rescale audio prior to preprocessing
+     rescaling_max=0.9,  # Rescaling value
+     # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
+     # It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
+     # Does not work if n_fft is not a multiple of hop_size!!
+     use_lws=False,
+     n_fft=800,  # Extra window size is filled with 0 paddings to match this parameter
+     hop_size=200,  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
+     win_size=800,  # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
+     sample_rate=16000,  # 16000Hz (corresponding to librispeech) (sox --i <filename>)
+     frame_shift_ms=None,  # Can replace the hop_size parameter. (Recommended: 12.5)
+     # Mel and Linear spectrograms normalization/scaling and clipping
+     signal_normalization=True,
+     # Whether to normalize mel spectrograms to some predefined range (following below parameters)
+     allow_clipping_in_normalization=True,  # Only relevant if mel_normalization = True
+     symmetric_mels=True,
+     # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
+     # faster and cleaner convergence)
+     max_abs_value=4.0,
+     # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
+     # be too big to avoid gradient explosion,
+     # not too small for fast convergence)
+     # Contribution by @begeekmyfriend
+     # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
+     # levels. Also allows for better G&L phase reconstruction)
+     preemphasize=True,  # whether to apply filter
+     preemphasis=0.97,  # filter coefficient.
+     # Limits
+     min_level_db=-100,
+     ref_level_db=20,
+     fmin=55,
+     # Set this to 55 if your speaker is male! If female, 95 should help taking off noise. (To
+     # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+     fmax=7600,  # To be increased/reduced depending on data.
+     ###################### Our training parameters #################################
+     img_size=96,
+     fps=25,
+     batch_size=16,
+     initial_learning_rate=1e-4,
+     nepochs=200000000000000000,
+     ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
+     num_workers=16,
+     checkpoint_interval=3000,
+     eval_interval=3000,
+     save_optimizer_state=True,
+     syncnet_wt=0.0,  # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
+     syncnet_batch_size=64,
+     syncnet_lr=1e-4,
+     syncnet_eval_interval=10000,
+     syncnet_checkpoint_interval=10000,
+     disc_wt=0.07,
+     disc_initial_learning_rate=1e-4,
+ )
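
A short sketch of the get/set interface on the default hparams instance defined above:

    from bithuman.audio.hparams import hparams

    # Attribute reads go through __getattr__ into the underlying dict.
    print(hparams.sample_rate)  # 16000
    print(hparams.num_mels)    # 80

    # set_hparam writes back into the same dict, so later reads see the new value.
    hparams.set_hparam("fps", 30)
    print(hparams.fps)  # 30

    # Missing keys raise AttributeError rather than returning None.
    try:
        hparams.not_a_param
    except AttributeError as exc:
        print(exc)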