pymss 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. pymss/__init__.py +2 -0
  2. pymss/logger.py +93 -0
  3. pymss/modules/__init__.py +0 -0
  4. pymss/modules/bandit/__init__.py +0 -0
  5. pymss/modules/bandit/core/__init__.py +744 -0
  6. pymss/modules/bandit/core/data/__init__.py +2 -0
  7. pymss/modules/bandit/core/data/_types.py +18 -0
  8. pymss/modules/bandit/core/data/augmentation.py +107 -0
  9. pymss/modules/bandit/core/data/augmented.py +35 -0
  10. pymss/modules/bandit/core/data/base.py +69 -0
  11. pymss/modules/bandit/core/data/dnr/__init__.py +0 -0
  12. pymss/modules/bandit/core/data/dnr/datamodule.py +74 -0
  13. pymss/modules/bandit/core/data/dnr/dataset.py +392 -0
  14. pymss/modules/bandit/core/data/dnr/preprocess.py +54 -0
  15. pymss/modules/bandit/core/data/musdb/__init__.py +0 -0
  16. pymss/modules/bandit/core/data/musdb/datamodule.py +77 -0
  17. pymss/modules/bandit/core/data/musdb/dataset.py +280 -0
  18. pymss/modules/bandit/core/data/musdb/preprocess.py +238 -0
  19. pymss/modules/bandit/core/loss/__init__.py +2 -0
  20. pymss/modules/bandit/core/loss/_complex.py +34 -0
  21. pymss/modules/bandit/core/loss/_multistem.py +45 -0
  22. pymss/modules/bandit/core/loss/_timefreq.py +113 -0
  23. pymss/modules/bandit/core/loss/snr.py +146 -0
  24. pymss/modules/bandit/core/metrics/__init__.py +9 -0
  25. pymss/modules/bandit/core/metrics/_squim.py +383 -0
  26. pymss/modules/bandit/core/metrics/snr.py +150 -0
  27. pymss/modules/bandit/core/model/__init__.py +3 -0
  28. pymss/modules/bandit/core/model/_spectral.py +58 -0
  29. pymss/modules/bandit/core/model/bsrnn/__init__.py +23 -0
  30. pymss/modules/bandit/core/model/bsrnn/bandsplit.py +139 -0
  31. pymss/modules/bandit/core/model/bsrnn/core.py +661 -0
  32. pymss/modules/bandit/core/model/bsrnn/maskestim.py +347 -0
  33. pymss/modules/bandit/core/model/bsrnn/tfmodel.py +317 -0
  34. pymss/modules/bandit/core/model/bsrnn/utils.py +583 -0
  35. pymss/modules/bandit/core/model/bsrnn/wrapper.py +882 -0
  36. pymss/modules/bandit/core/utils/__init__.py +0 -0
  37. pymss/modules/bandit/core/utils/audio.py +463 -0
  38. pymss/modules/bandit/model_from_config.py +31 -0
  39. pymss/modules/bandit_v2/__init__.py +0 -0
  40. pymss/modules/bandit_v2/bandit.py +367 -0
  41. pymss/modules/bandit_v2/bandsplit.py +130 -0
  42. pymss/modules/bandit_v2/film.py +25 -0
  43. pymss/modules/bandit_v2/maskestim.py +281 -0
  44. pymss/modules/bandit_v2/tfmodel.py +145 -0
  45. pymss/modules/bandit_v2/utils.py +523 -0
  46. pymss/modules/bs_roformer/__init__.py +2 -0
  47. pymss/modules/bs_roformer/attend.py +126 -0
  48. pymss/modules/bs_roformer/bs_roformer.py +621 -0
  49. pymss/modules/bs_roformer/mel_band_roformer.py +668 -0
  50. pymss/modules/demucs4ht.py +713 -0
  51. pymss/modules/ex_bi_mamba2.py +303 -0
  52. pymss/modules/look2hear/__init__.py +49 -0
  53. pymss/modules/look2hear/apollo.py +324 -0
  54. pymss/modules/look2hear/base_model.py +100 -0
  55. pymss/modules/mdx23c_tfc_tdf_v3.py +242 -0
  56. pymss/modules/scnet/__init__.py +1 -0
  57. pymss/modules/scnet/scnet.py +373 -0
  58. pymss/modules/scnet/separation.py +113 -0
  59. pymss/modules/scnet_unofficial/__init__.py +1 -0
  60. pymss/modules/scnet_unofficial/modules/__init__.py +3 -0
  61. pymss/modules/scnet_unofficial/modules/dualpath_rnn.py +228 -0
  62. pymss/modules/scnet_unofficial/modules/sd_encoder.py +285 -0
  63. pymss/modules/scnet_unofficial/modules/su_decoder.py +241 -0
  64. pymss/modules/scnet_unofficial/scnet.py +249 -0
  65. pymss/modules/scnet_unofficial/utils.py +135 -0
  66. pymss/modules/segm_models.py +255 -0
  67. pymss/modules/torchseg_models.py +255 -0
  68. pymss/modules/ts_bs_mamba2.py +319 -0
  69. pymss/modules/upernet_swin_transformers.py +228 -0
  70. pymss/separator.py +316 -0
  71. pymss/utils.py +230 -0
  72. pymss-1.0.dist-info/LICENSE +21 -0
  73. pymss-1.0.dist-info/METADATA +128 -0
  74. pymss-1.0.dist-info/RECORD +76 -0
  75. pymss-1.0.dist-info/WHEEL +5 -0
  76. pymss-1.0.dist-info/top_level.txt +1 -0
pymss/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .separator import MSSeparator
2
+ from .logger import get_separation_logger
pymss/logger.py ADDED
@@ -0,0 +1,93 @@
1
+ import logging
2
+ import os
3
+ from colorama import Fore, Style, init
4
+ from datetime import datetime
5
+
6
# On-disk log settings: how many *.log files to keep, and the directory
# (resolved relative to the current working directory) that holds them.
MAX_LOG = 100
LOG_DIR = "logs"

# colorama initialisation; autoreset=True makes colorama emit a style reset
# after each write, so colour codes do not bleed into later output.
init(autoreset=True)
10
+
11
class ColorFormatter(logging.Formatter):
    """Console formatter that colours the level name and shortens the source path.

    The level name is wrapped in a colorama colour code. ``pathname`` is shown
    relative to the current working directory when possible. Both changes are
    applied to a private copy of the record. Other handlers receiving the same
    record (e.g. a file handler) therefore still see the original values.
    """

    # Colour for each standard level name; any other level name stays uncoloured.
    LEVEL_COLORS = {
        "DEBUG": Fore.BLUE,
        "INFO": Fore.GREEN,
        "WARNING": Fore.YELLOW,
        "ERROR": Fore.RED,
        "CRITICAL": Fore.MAGENTA,
    }

    def format(self, record):
        """Return *record* formatted with a coloured level name.

        Args:
            record: The ``logging.LogRecord`` to render. It is not modified.

        Returns:
            The formatted log line.
        """
        # A LogRecord is shared by every handler it is dispatched to, so work
        # on a copy instead of rewriting attributes in place.
        shown = logging.makeLogRecord(record.__dict__)

        try:
            shown.pathname = os.path.relpath(record.pathname)
        except ValueError:
            # relpath raises on Windows when the source file and the CWD are
            # on different drives; keep the full path in that case.
            pass

        color = self.LEVEL_COLORS.get(record.levelname)
        if color is not None:
            # Colour only the levelname field. Replacing the level word in the
            # finished line would also recolour it inside the message or path.
            shown.levelname = f"{color}{record.levelname}{Style.RESET_ALL}"

        return super().format(shown)
28
+
29
+
30
def manage_log_files(log_dir, max_log):
    """Delete the oldest ``*.log`` entries in *log_dir* so at most *max_log* remain.

    Entries are ordered by the timestamp in their name. The part before the
    first ``.`` is parsed as ``YYYY-MM-DD`` or ``YYYY-MM-DD_HH-MM-SS``. Names
    matching neither pattern sort as oldest and are removed first.

    Removal is best effort: an entry that cannot be deleted (``OSError``) is
    skipped. It still counts toward the number of entries removed.

    Args:
        log_dir: Directory to prune.
        max_log: Number of ``*.log`` entries to keep.
    """
    log_files = [f for f in os.listdir(log_dir) if f.endswith(".log")]

    def parse_date(filename):
        # Timestamp encoded in the file name; unparseable names sort first.
        stem = filename.split(".")[0]
        for fmt in ("%Y-%m-%d", "%Y-%m-%d_%H-%M-%S"):
            try:
                return datetime.strptime(stem, fmt)
            except ValueError:
                continue
        return datetime.min

    log_files.sort(key=parse_date)

    # Oldest-first surplus beyond the newest `max_log` entries.
    excess = max(len(log_files) - max_log, 0)
    for oldest_file in log_files[:excess]:
        try:
            os.remove(os.path.join(log_dir, oldest_file))
        except OSError:
            # Best effort: a locked or vanished file must not break logger setup.
            pass
48
+
49
+
50
def set_log_level(logger, level):
    """Change the threshold of *logger*'s console handler only.

    Only the handler stored in ``logger.console_handler`` is updated. The
    logger's own level and any other handlers are left untouched.
    AttributeError is raised if *logger* has no ``console_handler`` attribute.

    Args:
        logger: Logger carrying a ``console_handler`` attribute, as returned by
            ``get_separation_logger``.
        level: New level for the console handler, e.g. ``logging.WARNING``.
    """
    console = logger.console_handler
    console.setLevel(level)
52
+
53
+
54
def get_separation_logger(console_level=logging.INFO, enable_file_log=False, max_log=MAX_LOG):
    """Return the shared ``"logger"`` logger, configuring it on first use.

    On the first call this attaches a coloured console handler at
    *console_level*. Optionally it also attaches a DEBUG-level file handler
    writing to a new timestamped file in ``LOG_DIR``. Later calls return the
    already configured logger and ignore their arguments. Use
    ``set_log_level`` to change the console threshold afterwards.

    NOTE: records still propagate to ancestor loggers (``propagate`` is not
    changed). A handler on the root logger will therefore also receive them.

    Args:
        console_level: Minimum level printed to the console.
        enable_file_log: If True, also write every record (DEBUG and up) to
            ``LOG_DIR/<YYYY-MM-DD_HH-MM-SS>.log`` and prune old ``*.log``
            files.
        max_log: Number of ``*.log`` files to keep in ``LOG_DIR`` when file
            logging is enabled.

    Returns:
        The configured ``logging.Logger``. Its console handler is exposed as
        ``logger.console_handler``.
    """
    logger = logging.getLogger("logger")

    # Detect *our own* earlier configuration. Logger.hasHandlers() also checks
    # ancestor loggers, so a root handler (e.g. from logging.basicConfig) used
    # to skip setup entirely. The logger was then left without
    # `console_handler` and at the root's level.
    if getattr(logger, "console_handler", None) is not None:
        return logger

    # The logger passes everything; each handler applies its own threshold.
    logger.setLevel(logging.DEBUG)

    log_format = "%(asctime)s.%(msecs)03d [%(levelname)s] [%(pathname)s:%(lineno)d] %(message)s"
    date_format = "%H:%M:%S"

    console_handler = logging.StreamHandler()
    console_handler.setLevel(console_level)
    console_handler.setFormatter(ColorFormatter(fmt=log_format, datefmt=date_format))
    logger.addHandler(console_handler)
    # Publish the handler right away. If the optional file setup below fails,
    # a later call must not attach a second console handler.
    logger.console_handler = console_handler

    if enable_file_log:
        os.makedirs(LOG_DIR, exist_ok=True)
        log_filename = datetime.now().strftime("%Y-%m-%d_%H-%M-%S.log")
        file_path = os.path.join(LOG_DIR, log_filename)

        file_handler = logging.FileHandler(file_path, mode='a', encoding='utf-8')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(logging.Formatter(fmt=log_format, datefmt=date_format))
        logger.addHandler(file_handler)

        # Prune after the new file is opened, so it counts toward max_log.
        manage_log_files(LOG_DIR, max_log)

    return logger
85
+
86
+
87
if __name__ == "__main__":
    # Manual smoke test: emit one message per standard level so the console
    # colours can be checked by eye. With console_level=INFO, the console
    # handler filters out the DEBUG message.
    demo = get_separation_logger(enable_file_log=False, console_level=logging.INFO)

    demo.debug("This is a debug message.")
    demo.info("This is an info message.")
    demo.warning("This is a warning message.")
    demo.error("This is an error message.")
    demo.critical("This is a critical message.")
File without changes
File without changes