xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/core/chat_interface.py +1 -1
- xinference/core/image_interface.py +9 -0
- xinference/core/model.py +4 -1
- xinference/core/worker.py +60 -44
- xinference/model/audio/chattts.py +25 -9
- xinference/model/audio/core.py +8 -2
- xinference/model/audio/cosyvoice.py +4 -3
- xinference/model/audio/custom.py +4 -5
- xinference/model/audio/fish_speech.py +228 -0
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/embedding/core.py +25 -1
- xinference/model/embedding/custom.py +4 -5
- xinference/model/flexible/core.py +5 -1
- xinference/model/image/custom.py +4 -5
- xinference/model/image/model_spec.json +2 -1
- xinference/model/image/model_spec_modelscope.json +2 -1
- xinference/model/image/stable_diffusion/core.py +66 -3
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +54 -9
- xinference/model/llm/llm_family.py +7 -6
- xinference/model/llm/llm_family_modelscope.json +56 -10
- xinference/model/llm/lmdeploy/__init__.py +0 -0
- xinference/model/llm/lmdeploy/core.py +557 -0
- xinference/model/llm/sglang/core.py +7 -1
- xinference/model/llm/transformers/cogvlm2.py +4 -45
- xinference/model/llm/transformers/cogvlm2_video.py +524 -0
- xinference/model/llm/transformers/core.py +3 -0
- xinference/model/llm/transformers/glm4v.py +2 -23
- xinference/model/llm/transformers/intern_vl.py +94 -11
- xinference/model/llm/transformers/minicpmv25.py +2 -23
- xinference/model/llm/transformers/minicpmv26.py +2 -22
- xinference/model/llm/transformers/yi_vl.py +2 -24
- xinference/model/llm/utils.py +13 -1
- xinference/model/llm/vllm/core.py +1 -34
- xinference/model/rerank/custom.py +4 -5
- xinference/model/utils.py +41 -1
- xinference/model/video/core.py +3 -1
- xinference/model/video/diffusers.py +41 -38
- xinference/model/video/model_spec.json +24 -1
- xinference/model/video/model_spec_modelscope.json +25 -1
- xinference/thirdparty/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
- xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
- xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
- xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
- xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
- xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/api.py +495 -0
- xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
- xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
- xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
- xinference/thirdparty/fish_speech/tools/file.py +108 -0
- xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
- xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
- xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
- xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
- xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
- xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
- xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
- xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
- xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
- xinference/thirdparty/fish_speech/tools/webui.py +619 -0
- xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
- xinference/thirdparty/matcha/__init__.py +0 -0
- xinference/thirdparty/matcha/app.py +357 -0
- xinference/thirdparty/matcha/cli.py +419 -0
- xinference/thirdparty/matcha/data/__init__.py +0 -0
- xinference/thirdparty/matcha/data/components/__init__.py +0 -0
- xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
- xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
- xinference/thirdparty/matcha/hifigan/config.py +28 -0
- xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
- xinference/thirdparty/matcha/hifigan/env.py +17 -0
- xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
- xinference/thirdparty/matcha/hifigan/models.py +368 -0
- xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
- xinference/thirdparty/matcha/models/__init__.py +0 -0
- xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
- xinference/thirdparty/matcha/models/components/__init__.py +0 -0
- xinference/thirdparty/matcha/models/components/decoder.py +443 -0
- xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
- xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
- xinference/thirdparty/matcha/models/components/transformer.py +316 -0
- xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
- xinference/thirdparty/matcha/onnx/__init__.py +0 -0
- xinference/thirdparty/matcha/onnx/export.py +181 -0
- xinference/thirdparty/matcha/onnx/infer.py +168 -0
- xinference/thirdparty/matcha/text/__init__.py +53 -0
- xinference/thirdparty/matcha/text/cleaners.py +121 -0
- xinference/thirdparty/matcha/text/numbers.py +71 -0
- xinference/thirdparty/matcha/text/symbols.py +17 -0
- xinference/thirdparty/matcha/train.py +122 -0
- xinference/thirdparty/matcha/utils/__init__.py +5 -0
- xinference/thirdparty/matcha/utils/audio.py +82 -0
- xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
- xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
- xinference/thirdparty/matcha/utils/instantiators.py +56 -0
- xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
- xinference/thirdparty/matcha/utils/model.py +90 -0
- xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
- xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
- xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
- xinference/thirdparty/matcha/utils/pylogger.py +21 -0
- xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
- xinference/thirdparty/matcha/utils/utils.py +259 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
- xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
- xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
- /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py
@@ -0,0 +1,23 @@
+from .braceexpand import braceexpand
+from .context import autocast_exclude_mps
+from .file import get_latest_checkpoint
+from .instantiators import instantiate_callbacks, instantiate_loggers
+from .logger import RankedLogger
+from .logging_utils import log_hyperparameters
+from .rich_utils import enforce_tags, print_config_tree
+from .utils import extras, get_metric_value, task_wrapper
+
+__all__ = [
+    "enforce_tags",
+    "extras",
+    "get_metric_value",
+    "RankedLogger",
+    "instantiate_callbacks",
+    "instantiate_loggers",
+    "log_hyperparameters",
+    "print_config_tree",
+    "task_wrapper",
+    "braceexpand",
+    "get_latest_checkpoint",
+    "autocast_exclude_mps",
+]
xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py
@@ -0,0 +1,217 @@
+"""
+Bash-style brace expansion
+Copied from: https://github.com/trendels/braceexpand/blob/main/src/braceexpand/__init__.py
+License: MIT
+"""
+
+import re
+import string
+from itertools import chain, product
+from typing import Iterable, Iterator, Optional
+
+__all__ = ["braceexpand", "alphabet", "UnbalancedBracesError"]
+
+
+class UnbalancedBracesError(ValueError):
+    pass
+
+
+alphabet = string.ascii_uppercase + string.ascii_lowercase
+
+int_range_re = re.compile(r"^(-?\d+)\.\.(-?\d+)(?:\.\.-?(\d+))?$")
+char_range_re = re.compile(r"^([A-Za-z])\.\.([A-Za-z])(?:\.\.-?(\d+))?$")
+escape_re = re.compile(r"\\(.)")
+
+
+def braceexpand(pattern: str, escape: bool = True) -> Iterator[str]:
+    """braceexpand(pattern) -> iterator over generated strings
+
+    Returns an iterator over the strings resulting from brace expansion
+    of pattern. This function implements Brace Expansion as described in
+    bash(1), with the following limitations:
+
+    * A pattern containing unbalanced braces will raise an
+      UnbalancedBracesError exception. In bash, unbalanced braces will either
+      be partly expanded or ignored.
+
+    * A mixed-case character range like '{Z..a}' or '{a..Z}' will not
+      include the characters '[]^_`' between 'Z' and 'a'.
+
+    When escape is True (the default), characters in pattern can be
+    prefixed with a backslash to cause them not to be interpreted as
+    special characters for brace expansion (such as '{', '}', ',').
+    To pass through a literal backslash, double it ('\\\\').
+
+    When escape is False, backslashes in pattern have no special
+    meaning and will be preserved in the output.
+
+    Examples:
+
+    >>> from braceexpand import braceexpand
+
+    # Integer range
+    >>> list(braceexpand('item{1..3}'))
+    ['item1', 'item2', 'item3']
+
+    # Character range
+    >>> list(braceexpand('{a..c}'))
+    ['a', 'b', 'c']
+
+    # Sequence
+    >>> list(braceexpand('index.html{,.backup}'))
+    ['index.html', 'index.html.backup']
+
+    # Nested patterns
+    >>> list(braceexpand('python{2.{5..7},3.{2,3}}'))
+    ['python2.5', 'python2.6', 'python2.7', 'python3.2', 'python3.3']
+
+    # Prefixing an integer with zero causes all numbers to be padded to
+    # the same width.
+    >>> list(braceexpand('{07..10}'))
+    ['07', '08', '09', '10']
+
+    # An optional increment can be specified for ranges.
+    >>> list(braceexpand('{a..g..2}'))
+    ['a', 'c', 'e', 'g']
+
+    # Ranges can go in both directions.
+    >>> list(braceexpand('{4..1}'))
+    ['4', '3', '2', '1']
+
+    # Numbers can be negative
+    >>> list(braceexpand('{2..-1}'))
+    ['2', '1', '0', '-1']
+
+    # Unbalanced braces raise an exception.
+    >>> list(braceexpand('{1{2,3}'))
+    Traceback (most recent call last):
+        ...
+    UnbalancedBracesError: Unbalanced braces: '{1{2,3}'
+
+    # By default, the backslash is the escape character.
+    >>> list(braceexpand(r'{1\\{2,3}'))
+    ['1{2', '3']
+
+    # Setting 'escape' to False disables backslash escaping.
+    >>> list(braceexpand(r'\\{1,2}', escape=False))
+    ['\\\\1', '\\\\2']
+
+    """
+    return (
+        escape_re.sub(r"\1", s) if escape else s for s in parse_pattern(pattern, escape)
+    )
+
+
+def parse_pattern(pattern: str, escape: bool) -> Iterator[str]:
+    start = 0
+    pos = 0
+    bracketdepth = 0
+    items: list[Iterable[str]] = []
+
+    # print 'pattern:', pattern
+    while pos < len(pattern):
+        if escape and pattern[pos] == "\\":
+            pos += 2
+            continue
+        elif pattern[pos] == "{":
+            if bracketdepth == 0 and pos > start:
+                # print 'literal:', pattern[start:pos]
+                items.append([pattern[start:pos]])
+                start = pos
+            bracketdepth += 1
+        elif pattern[pos] == "}":
+            bracketdepth -= 1
+            if bracketdepth == 0:
+                # print 'expression:', pattern[start+1:pos]
+                expr = pattern[start + 1 : pos]
+                item = parse_expression(expr, escape)
+                if item is None:  # not a range or sequence
+                    items.extend([["{"], parse_pattern(expr, escape), ["}"]])
+                else:
+                    items.append(item)
+                start = pos + 1  # skip the closing brace
+        pos += 1
+
+    if bracketdepth != 0:  # unbalanced braces
+        raise UnbalancedBracesError("Unbalanced braces: '%s'" % pattern)
+
+    if start < pos:
+        items.append([pattern[start:]])
+
+    return ("".join(item) for item in product(*items))
+
+
+def parse_expression(expr: str, escape: bool) -> Optional[Iterable[str]]:
+    int_range_match = int_range_re.match(expr)
+    if int_range_match:
+        return make_int_range(*int_range_match.groups())
+
+    char_range_match = char_range_re.match(expr)
+    if char_range_match:
+        return make_char_range(*char_range_match.groups())
+
+    return parse_sequence(expr, escape)
+
+
+def parse_sequence(seq: str, escape: bool) -> Optional[Iterator[str]]:
+    # sequence -> chain(*sequence_items)
+    start = 0
+    pos = 0
+    bracketdepth = 0
+    items: list[Iterable[str]] = []
+
+    # print 'sequence:', seq
+    while pos < len(seq):
+        if escape and seq[pos] == "\\":
+            pos += 2
+            continue
+        elif seq[pos] == "{":
+            bracketdepth += 1
+        elif seq[pos] == "}":
+            bracketdepth -= 1
+        elif seq[pos] == "," and bracketdepth == 0:
+            items.append(parse_pattern(seq[start:pos], escape))
+            start = pos + 1  # skip the comma
+        pos += 1
+
+    if bracketdepth != 0:
+        raise UnbalancedBracesError
+    if not items:
+        return None
+
+    # part after the last comma (may be the empty string)
+    items.append(parse_pattern(seq[start:], escape))
+    return chain(*items)
+
+
+def make_int_range(left: str, right: str, incr: Optional[str] = None) -> Iterator[str]:
+    if any([s.startswith(("0", "-0")) for s in (left, right) if s not in ("0", "-0")]):
+        padding = max(len(left), len(right))
+    else:
+        padding = 0
+    step = (int(incr) or 1) if incr else 1
+    start = int(left)
+    end = int(right)
+    r = range(start, end + 1, step) if start < end else range(start, end - 1, -step)
+    fmt = "%0{}d".format(padding)
+    return (fmt % i for i in r)
+
+
+def make_char_range(left: str, right: str, incr: Optional[str] = None) -> str:
+    step = (int(incr) or 1) if incr else 1
+    start = alphabet.index(left)
+    end = alphabet.index(right)
+    if start < end:
+        return alphabet[start : end + 1 : step]
+    else:
+        end = end or -len(alphabet)
+        return alphabet[start : end - 1 : -step]
+
+
+if __name__ == "__main__":
+    import doctest
+    import sys
+
+    failed, _ = doctest.testmod(optionflags=doctest.IGNORE_EXCEPTION_DETAIL)
+    if failed:
+        sys.exit(1)
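Since this module is stdlib-only, the vendored expander can be exercised directly. A minimal sketch (the shard pattern is illustrative) of the zero-padded range expansion described in the docstring above:

from xinference.thirdparty.fish_speech.fish_speech.utils.braceexpand import braceexpand

# Zero-padded integer range, exactly as in bash.
print(list(braceexpand("data/shard-{000..003}.tar")))
# -> ['data/shard-000.tar', 'data/shard-001.tar', 'data/shard-002.tar', 'data/shard-003.tar']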
xinference/thirdparty/fish_speech/fish_speech/utils/context.py
@@ -0,0 +1,13 @@
+from contextlib import nullcontext
+
+import torch
+
+
+def autocast_exclude_mps(
+    device_type: str, dtype: torch.dtype
+) -> nullcontext | torch.autocast:
+    return (
+        nullcontext()
+        if torch.backends.mps.is_available()
+        else torch.autocast(device_type, dtype)
+    )
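A minimal usage sketch, assuming a torch install (the `X | Y` return annotation also implies Python 3.10+): on machines where the MPS backend is available the helper yields a no-op context, since autocast support on MPS is limited; elsewhere it behaves like a plain torch.autocast region.

import torch

from xinference.thirdparty.fish_speech.fish_speech.utils.context import autocast_exclude_mps

# No-op context on MPS machines; a regular autocast region elsewhere.
with autocast_exclude_mps("cpu", torch.bfloat16):
    y = torch.randn(4, 4) @ torch.randn(4, 4)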
xinference/thirdparty/fish_speech/fish_speech/utils/file.py
@@ -0,0 +1,16 @@
+import os
+from pathlib import Path
+
+
+def get_latest_checkpoint(path: Path | str) -> Path | None:
+    # Find the latest checkpoint
+    ckpt_dir = Path(path)
+
+    if ckpt_dir.exists() is False:
+        return None
+
+    ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
+    if len(ckpts) == 0:
+        return None
+
+    return ckpts[-1]
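A usage sketch ("checkpoints" is an illustrative directory): the helper sorts *.ckpt files by modification time and returns the newest, or None when the directory is missing or empty.

from xinference.thirdparty.fish_speech.fish_speech.utils.file import get_latest_checkpoint

ckpt = get_latest_checkpoint("checkpoints")  # Path to the newest *.ckpt, or None
if ckpt is not None:
    print(f"resuming from {ckpt}")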
xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py
@@ -0,0 +1,50 @@
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+from pytorch_lightning import Callback
+from pytorch_lightning.loggers import Logger
+
+from .logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+    """Instantiates callbacks from config."""
+
+    callbacks: List[Callback] = []
+
+    if not callbacks_cfg:
+        log.warning("No callback configs found! Skipping..")
+        return callbacks
+
+    if not isinstance(callbacks_cfg, DictConfig):
+        raise TypeError("Callbacks config must be a DictConfig!")
+
+    for _, cb_conf in callbacks_cfg.items():
+        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+            log.info(f"Instantiating callback <{cb_conf._target_}>")
+            callbacks.append(hydra.utils.instantiate(cb_conf))
+
+    return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+    """Instantiates loggers from config."""
+
+    logger: List[Logger] = []
+
+    if not logger_cfg:
+        log.warning("No logger configs found! Skipping...")
+        return logger
+
+    if not isinstance(logger_cfg, DictConfig):
+        raise TypeError("Logger config must be a DictConfig!")
+
+    for _, lg_conf in logger_cfg.items():
+        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+            log.info(f"Instantiating logger <{lg_conf._target_}>")
+            logger.append(hydra.utils.instantiate(lg_conf))
+
+    return logger
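A minimal sketch of how these helpers consume a Hydra-style config, assuming pytorch_lightning and hydra-core are installed; the "checkpoint" key and the ModelCheckpoint target are illustrative:

from omegaconf import OmegaConf

from xinference.thirdparty.fish_speech.fish_speech.utils.instantiators import instantiate_callbacks

# Every sub-config carrying a "_target_" key is built via hydra.utils.instantiate.
cfg = OmegaConf.create(
    {"checkpoint": {"_target_": "pytorch_lightning.callbacks.ModelCheckpoint"}}
)
callbacks = instantiate_callbacks(cfg)  # -> [ModelCheckpoint(...)]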
xinference/thirdparty/fish_speech/fish_speech/utils/logger.py
@@ -0,0 +1,55 @@
+import logging
+from typing import Mapping, Optional
+
+from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
+
+
+class RankedLogger(logging.LoggerAdapter):
+    """A multi-GPU-friendly python command line logger."""
+
+    def __init__(
+        self,
+        name: str = __name__,
+        rank_zero_only: bool = True,
+        extra: Optional[Mapping[str, object]] = None,
+    ) -> None:
+        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
+        with their rank prefixed in the log message.
+
+        :param name: The name of the logger. Default is ``__name__``.
+        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `True`.
+        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
+        """
+        logger = logging.getLogger(name)
+        super().__init__(logger=logger, extra=extra)
+        self.rank_zero_only = rank_zero_only
+
+    def log(
+        self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs
+    ) -> None:
+        """Delegate a log call to the underlying logger, after prefixing its message with the rank
+        of the process it's being logged from. If `'rank'` is provided, then the log will only
+        occur on that rank/process.
+
+        :param level: The level to log at. Look at `logging.__init__.py` for more information.
+        :param msg: The message to log.
+        :param rank: The rank to log at.
+        :param args: Additional args to pass to the underlying logging function.
+        :param kwargs: Any additional keyword args to pass to the underlying logging function.
+        """
+        if self.isEnabledFor(level):
+            msg, kwargs = self.process(msg, kwargs)
+            current_rank = getattr(rank_zero_only, "rank", None)
+            if current_rank is None:
+                raise RuntimeError(
+                    "The `rank_zero_only.rank` needs to be set before use"
+                )
+            msg = rank_prefixed_message(msg, current_rank)
+            if self.rank_zero_only:
+                if current_rank == 0:
+                    self.logger.log(level, msg, *args, **kwargs)
+            else:
+                if rank is None:
+                    self.logger.log(level, msg, *args, **kwargs)
+                elif current_rank == rank:
+                    self.logger.log(level, msg, *args, **kwargs)
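A standalone sketch of the logger: outside a Lightning run, rank_zero_only.rank is not set automatically, so the sketch sets it by hand (Lightning normally does this at startup).

import logging

from lightning_utilities.core.rank_zero import rank_zero_only

from xinference.thirdparty.fish_speech.fish_speech.utils.logger import RankedLogger

logging.basicConfig(level=logging.INFO)
rank_zero_only.rank = 0  # normally set by Lightning; RankedLogger raises if unset
log = RankedLogger(__name__, rank_zero_only=True)
log.info("loading model...")  # emitted as "[rank: 0] loading model...", on rank 0 only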
xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py
@@ -0,0 +1,48 @@
+from lightning.pytorch.utilities import rank_zero_only
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+    """Controls which config parts are saved by lightning loggers.
+
+    Additionally saves:
+    - Number of model parameters
+    """
+
+    hparams = {}
+
+    cfg = object_dict["cfg"]
+    model = object_dict["model"]
+    trainer = object_dict["trainer"]
+
+    if not trainer.logger:
+        log.warning("Logger not found! Skipping hyperparameter logging...")
+        return
+
+    hparams["model"] = cfg["model"]
+
+    # save number of model parameters
+    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+    hparams["model/params/trainable"] = sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+    hparams["model/params/non_trainable"] = sum(
+        p.numel() for p in model.parameters() if not p.requires_grad
+    )
+
+    hparams["data"] = cfg["data"]
+    hparams["trainer"] = cfg["trainer"]
+
+    hparams["callbacks"] = cfg.get("callbacks")
+    hparams["extras"] = cfg.get("extras")
+
+    hparams["task_name"] = cfg.get("task_name")
+    hparams["tags"] = cfg.get("tags")
+    hparams["ckpt_path"] = cfg.get("ckpt_path")
+    hparams["seed"] = cfg.get("seed")
+
+    # send hparams to all loggers
+    for logger in trainer.loggers:
+        logger.log_hyperparams(hparams)
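The parameter bookkeeping above is plain numel() sums; the same counting on a toy module, independent of any trainer:

import torch.nn as nn

model = nn.Linear(10, 2)  # 10*2 weights + 2 biases = 22 parameters
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total, trainable)  # 22 22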
xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py
@@ -0,0 +1,100 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+from fish_speech.utils import logger as log
+
+
+@rank_zero_only
+def print_config_tree(
+    cfg: DictConfig,
+    print_order: Sequence[str] = (
+        "data",
+        "model",
+        "callbacks",
+        "logger",
+        "trainer",
+        "paths",
+        "extras",
+    ),
+    resolve: bool = False,
+    save_to_file: bool = False,
+) -> None:
+    """Prints content of DictConfig using Rich library and its tree structure.
+
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+        print_order (Sequence[str], optional): Determines in what order config components are printed.
+        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+        save_to_file (bool, optional): Whether to export config to the hydra output folder.
+    """  # noqa: E501
+
+    style = "dim"
+    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+    queue = []
+
+    # add fields from `print_order` to queue
+    for field in print_order:
+        (
+            queue.append(field)
+            if field in cfg
+            else log.warning(
+                f"Field '{field}' not found in config. "
+                + f"Skipping '{field}' config printing..."
+            )
+        )
+
+    # add all the other fields to queue (not specified in `print_order`)
+    for field in cfg:
+        if field not in queue:
+            queue.append(field)
+
+    # generate config tree from queue
+    for field in queue:
+        branch = tree.add(field, style=style, guide_style=style)
+
+        config_group = cfg[field]
+        if isinstance(config_group, DictConfig):
+            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+        else:
+            branch_content = str(config_group)
+
+        branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+    # print config tree
+    rich.print(tree)
+
+    # save config tree to file
+    if save_to_file:
+        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+            rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+    """Prompts user to input tags from command line if no tags are provided in config."""  # noqa: E501
+
+    if not cfg.get("tags"):
+        if "id" in HydraConfig().cfg.hydra.job:
+            raise ValueError("Specify tags before launching a multirun!")
+
+        log.warning("No tags provided in config. Prompting user to input tags...")
+        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+        tags = [t.strip() for t in tags.split(",") if t != ""]
+
+        with open_dict(cfg):
+            cfg.tags = tags
+
+        log.info(f"Tags: {cfg.tags}")
+
+    if save_to_file:
+        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+            rich.print(cfg.tags, file=file)
xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py
@@ -0,0 +1,122 @@
+import torch
+import torchaudio.functional as F
+from torch import Tensor, nn
+from torchaudio.transforms import MelScale
+
+
+class LinearSpectrogram(nn.Module):
+    def __init__(
+        self,
+        n_fft=2048,
+        win_length=2048,
+        hop_length=512,
+        center=False,
+        mode="pow2_sqrt",
+    ):
+        super().__init__()
+
+        self.n_fft = n_fft
+        self.win_length = win_length
+        self.hop_length = hop_length
+        self.center = center
+        self.mode = mode
+
+        self.register_buffer("window", torch.hann_window(win_length), persistent=False)
+
+    def forward(self, y: Tensor) -> Tensor:
+        if y.ndim == 3:
+            y = y.squeeze(1)
+
+        y = torch.nn.functional.pad(
+            y.unsqueeze(1),
+            (
+                (self.win_length - self.hop_length) // 2,
+                (self.win_length - self.hop_length + 1) // 2,
+            ),
+            mode="reflect",
+        ).squeeze(1)
+
+        spec = torch.stft(
+            y,
+            self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            onesided=True,
+            return_complex=True,
+        )
+
+        spec = torch.view_as_real(spec)
+
+        if self.mode == "pow2_sqrt":
+            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+        return spec
+
+
+class LogMelSpectrogram(nn.Module):
+    def __init__(
+        self,
+        sample_rate=44100,
+        n_fft=2048,
+        win_length=2048,
+        hop_length=512,
+        n_mels=128,
+        center=False,
+        f_min=0.0,
+        f_max=None,
+    ):
+        super().__init__()
+
+        self.sample_rate = sample_rate
+        self.n_fft = n_fft
+        self.win_length = win_length
+        self.hop_length = hop_length
+        self.center = center
+        self.n_mels = n_mels
+        self.f_min = f_min
+        self.f_max = f_max or float(sample_rate // 2)
+
+        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
+
+        fb = F.melscale_fbanks(
+            n_freqs=self.n_fft // 2 + 1,
+            f_min=self.f_min,
+            f_max=self.f_max,
+            n_mels=self.n_mels,
+            sample_rate=self.sample_rate,
+            norm="slaney",
+            mel_scale="slaney",
+        )
+        self.register_buffer(
+            "fb",
+            fb,
+            persistent=False,
+        )
+
+    def compress(self, x: Tensor) -> Tensor:
+        return torch.log(torch.clamp(x, min=1e-5))
+
+    def decompress(self, x: Tensor) -> Tensor:
+        return torch.exp(x)
+
+    def apply_mel_scale(self, x: Tensor) -> Tensor:
+        return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
+
+    def forward(
+        self, x: Tensor, return_linear: bool = False, sample_rate: int = None
+    ) -> Tensor:
+        if sample_rate is not None and sample_rate != self.sample_rate:
+            x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
+
+        linear = self.spectrogram(x)
+        x = self.apply_mel_scale(linear)
+        x = self.compress(x)
+
+        if return_linear:
+            return x, self.compress(linear)
+
+        return x
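A minimal sketch of the feature extractor, assuming torch and torchaudio are installed; with the defaults above (n_fft=2048, hop_length=512, center=False), one second of 44.1 kHz audio yields 86 frames:

import torch

from xinference.thirdparty.fish_speech.fish_speech.utils.spectrogram import LogMelSpectrogram

mel_fn = LogMelSpectrogram(sample_rate=44100, n_mels=128)
wav = torch.randn(1, 44100)  # one second of (random) audio, shape (batch, samples)
mel = mel_fn(wav)            # log-mel features, shape (1, 128, 86)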