backgroundremover 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- backgroundremover/__init__.py +9 -0
- backgroundremover/bg.py +297 -0
- backgroundremover/cmd/__init__.py +0 -0
- backgroundremover/cmd/cli.py +434 -0
- backgroundremover/cmd/server.py +98 -0
- backgroundremover/github.py +108 -0
- backgroundremover/u2net/__init__.py +0 -0
- backgroundremover/u2net/data_loader.py +324 -0
- backgroundremover/u2net/detect.py +174 -0
- backgroundremover/u2net/u2net.py +541 -0
- backgroundremover/utilities.py +356 -0
- backgroundremover-0.3.7.dist-info/LICENSE.txt +24 -0
- backgroundremover-0.3.7.dist-info/METADATA +650 -0
- backgroundremover-0.3.7.dist-info/RECORD +17 -0
- backgroundremover-0.3.7.dist-info/WHEEL +5 -0
- backgroundremover-0.3.7.dist-info/entry_points.txt +3 -0
- backgroundremover-0.3.7.dist-info/top_level.txt +1 -0
backgroundremover/bg.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import os
|
|
3
|
+
import typing
|
|
4
|
+
from PIL import Image, ImageOps
|
|
5
|
+
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
|
6
|
+
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
|
7
|
+
from pymatting.util.util import stack_images
|
|
8
|
+
from scipy.ndimage.morphology import binary_erosion
|
|
9
|
+
from moviepy import VideoFileClip
|
|
10
|
+
import numpy as np
|
|
11
|
+
import torch
|
|
12
|
+
import torch.nn.functional
|
|
13
|
+
import torch.nn.functional
|
|
14
|
+
from hsh.library.hash import Hasher
|
|
15
|
+
from .u2net import detect, u2net
|
|
16
|
+
from . import github
|
|
17
|
+
|
|
18
|
+
# Optional HEIC/HEIF decoding: when pillow_heif is installed, register its
# opener with Pillow; when it is missing, carry on without it.
try:
    from pillow_heif import register_heif_opener
    register_heif_opener()
except ImportError:
    pass  # HEIC support is optional


# closes https://github.com/nadermx/backgroundremover/issues/18
# closes https://github.com/nadermx/backgroundremover/issues/112
def _pick_device():
    """Return the preferred torch device: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda:0')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


# Any failure while probing the accelerators falls back to the CPU.
try:
    DEVICE = _pick_device()
except Exception as e:
    print(f"Using CPU. Setting Cuda or MPS failed: {e}")
    DEVICE = torch.device('cpu')
|
|
37
|
+
|
|
38
|
+
class Net(torch.nn.Module):
    """Wrap a U^2-Net variant: load its weights and predict 8-bit masks."""

    def __init__(self, model_name):
        """Build the network for ``model_name`` and load its weights on ``DEVICE``.

        The weight file is looked up in the environment variable associated
        with the model, defaulting to ``~/.u2net/<model_name>.pth``. If the
        file does not exist it is fetched with
        ``github.download_files_from_github``.

        Args:
            model_name: ``"u2net"``, ``"u2net_human_seg"`` or ``"u2netp"``.

        Raises:
            KeyError: ``model_name`` is not one of the supported models.
            RuntimeError: loading hit ``EOFError`` (truncated weight file).
            Exception: any other loading error is re-raised after a
                diagnostic message has been printed.
        """
        super(Net, self).__init__()

        # model name -> (architecture, env var that overrides the weight path).
        # NOTE(review): "u2net_human_seg" reads U2NET_PATH, the same variable
        # as "u2net"; kept as-is for compatibility -- confirm this is intended.
        specs = {
            "u2netp": (u2net.U2NETP, "U2NETP_PATH"),
            "u2net": (u2net.U2NET, "U2NET_PATH"),
            "u2net_human_seg": (u2net.U2NET, "U2NET_PATH"),
        }
        if model_name not in specs:
            # Same exception type the original table lookup produced, but with
            # the hint the old (unreachable) fallback branch tried to print.
            raise KeyError(
                f"Unknown model {model_name!r}. "
                "Choose between u2net, u2net_human_seg or u2netp"
            )

        architecture, env_var = specs[model_name]
        net = architecture(3, 1)
        path = os.environ.get(
            env_var,
            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
        )
        if not os.path.exists(path):
            github.download_files_from_github(path, model_name)

        try:
            # SECURITY: torch.load unpickles the checkpoint; only point the
            # *_PATH variables at weight files you trust.
            net.load_state_dict(torch.load(path, map_location=DEVICE))
            net.to(device=DEVICE, dtype=torch.float32, non_blocking=True)
            net.eval()
            self.net = net
        except EOFError as err:
            print(f"\n{'='*60}")
            print("ERROR: Model file appears to be corrupted or incomplete!")
            print(f"Path: {path}")
            print("\nThis usually happens when the model download was interrupted.")
            print("To fix this:")
            print(f" 1. Delete the corrupted file: rm {path}")
            print(" 2. Run backgroundremover again to re-download the model")
            print(f"{'='*60}\n")
            raise RuntimeError(
                f"Corrupted model file at {path}. Please delete it and re-run to download again."
            ) from err
        except Exception as e:
            print(f"\n{'='*60}")
            print(f"ERROR: Failed to load model '{model_name}'")
            print(f"Path: {path}")
            print(f"Error: {e}")
            print("\nIf the error persists:")
            print(f" 1. Try deleting the model file: rm {path}")
            print(" 2. Run backgroundremover again to re-download")
            print(" 3. Check if you have enough disk space")
            print(f"{'='*60}\n")
            raise

    def forward(self, block_input: torch.Tensor):
        """Predict one mask per image in a channels-last batch.

        Args:
            block_input: 4-D tensor that is permuted with ``(0, 3, 1, 2)``,
                i.e. laid out as (batch, height, width, channels).

        Returns:
            A ``uint8`` CPU tensor of shape (batch, height, width) holding the
            prediction min-max scaled to the 0..255 range over the whole batch.
        """
        image_data = block_input.permute(0, 3, 1, 2)
        original_shape = image_data.shape[2:]

        # The network is always fed 320x320 inputs.
        image_data = torch.nn.functional.interpolate(image_data, (320, 320), mode='bilinear')
        image_data = (image_data / 255 - 0.485) / 0.229

        out = self.net(image_data)[0][:, 0:1]

        # Min-max normalise. A perfectly uniform prediction has max == min;
        # dividing by 1 instead of 0 yields an all-zero mask rather than NaNs.
        ma = torch.max(out)
        mi = torch.min(out)
        span = ma - mi
        span = torch.where(span > 0, span, torch.ones_like(span))
        out = (out - mi) / span * 255

        out = torch.nn.functional.interpolate(out, original_shape, mode='bilinear')
        out = out[:, 0]

        # Blocking copy on purpose: a non_blocking device-to-host transfer may
        # return before the data has arrived, and callers read it right away.
        out = out.to(dtype=torch.uint8, device=torch.device('cpu')).detach()
        return out
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def alpha_matting_cutout(
    img,
    mask,
    foreground_threshold,
    background_threshold,
    erode_structure_size,
    base_size,
):
    """Cut the foreground out of ``img`` using alpha matting refined from ``mask``.

    Args:
        img: PIL image to cut out. It is no longer modified in place.
        mask: PIL mask image; resized to the working size before use.
        foreground_threshold: mask values above this are taken as foreground.
        background_threshold: mask values below this are taken as background.
        erode_structure_size: side length of the square erosion structure;
            values <= 0 leave ``structure`` as ``None``.
        base_size: maximum width/height of the working copy used for matting.

    Returns:
        A PIL image built from the estimated foreground stacked with the
        estimated alpha, resized back to the original size of ``img``.
    """
    size = img.size

    # Image.thumbnail() resizes in place; work on a copy so the caller's
    # image keeps its original resolution.
    img = img.copy()
    img.thumbnail((base_size, base_size), Image.LANCZOS)
    mask = mask.resize(img.size, Image.LANCZOS)

    img = np.asarray(img)
    mask = np.asarray(mask)

    # guess likely foreground/background
    is_foreground = mask > foreground_threshold
    is_background = mask < background_threshold

    # erode foreground/background so only confident pixels keep their label
    structure = None
    if erode_structure_size > 0:
        structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int64)

    is_foreground = binary_erosion(is_foreground, structure=structure)
    is_background = binary_erosion(is_background, structure=structure, border_value=1)

    # build trimap
    # 0 = background
    # 128 = unknown
    # 255 = foreground
    trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
    trimap[is_foreground] = 255
    trimap[is_background] = 0

    # build the cutout image (the estimators are fed values scaled to 0..1)
    img_normalized = img / 255.0
    trimap_normalized = trimap / 255.0

    alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
    foreground = estimate_foreground_ml(img_normalized, alpha)
    cutout = stack_images(foreground, alpha)

    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
    cutout = Image.fromarray(cutout)
    cutout = cutout.resize(size, Image.LANCZOS)

    return cutout
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def naive_cutout(img, mask):
    """Composite ``img`` over a blank RGBA canvas, using ``mask`` as the stencil.

    The mask is first resized (LANCZOS) to the size of ``img``.
    """
    stencil = mask.resize(img.size, Image.LANCZOS)
    blank = Image.new("RGBA", (img.size), 0)
    return Image.composite(img, blank, stencil)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def get_model(model_name):
    """Load the requested model via ``detect.load_model``.

    ``"u2netp"`` and ``"u2net_human_seg"`` are loaded under their own name;
    every other value falls back to ``"u2net"``.
    """
    chosen = model_name if model_name in ("u2netp", "u2net_human_seg") else "u2net"
    return detect.load_model(model_name=chosen)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _decode_rgb(data, error_prefix):
    """Return ``data`` (ndarray or encoded image bytes) as an RGB PIL image."""
    if isinstance(data, np.ndarray):
        return Image.fromarray(data).convert("RGB")
    try:
        decoded = Image.open(io.BytesIO(data))
        # Apply the EXIF orientation so the image is not rotated (fixes #144)
        decoded = ImageOps.exif_transpose(decoded)
        return decoded.convert("RGB")
    except Exception as e:
        # Chain the cause so the underlying decoder error stays in the traceback.
        raise ValueError(f"{error_prefix}: {e}") from e


def _flatten_onto(cutout, bg):
    """Paste an RGBA ``cutout`` onto ``bg`` through its alpha band; return ``bg``."""
    if cutout.mode == 'RGBA':
        bg.paste(cutout, mask=cutout.split()[3])
    return bg


def _png_buffer(image):
    """Encode ``image`` as PNG and return the in-memory buffer view."""
    bio = io.BytesIO()
    image.save(bio, "PNG")
    return bio.getbuffer()


def remove(
    data,
    model_name="u2net",
    alpha_matting=False,
    alpha_matting_foreground_threshold=240,
    alpha_matting_background_threshold=10,
    alpha_matting_erode_structure_size=10,
    alpha_matting_base_size=1000,
    only_mask=False,
    background_color=None,
    background_image=None,
):
    """Remove the background from an image and return the result as PNG data.

    Args:
        data: the input image, as a numpy array or as encoded image bytes.
        model_name: model passed to ``get_model``.
        alpha_matting: use ``alpha_matting_cutout`` instead of ``naive_cutout``.
        alpha_matting_foreground_threshold: forwarded to ``alpha_matting_cutout``.
        alpha_matting_background_threshold: forwarded to ``alpha_matting_cutout``.
        alpha_matting_erode_structure_size: forwarded to ``alpha_matting_cutout``.
        alpha_matting_base_size: forwarded to ``alpha_matting_cutout``.
        only_mask: return the predicted mask instead of the cutout.
        background_color: colour for a solid background placed behind the
            cutout; ignored when ``background_image`` is given.
        background_image: numpy array or encoded image bytes; it is resized
            to the cutout and placed behind it.

    Returns:
        The buffer view (``io.BytesIO.getbuffer()``) of the PNG-encoded result.

    Raises:
        ValueError: ``data`` or ``background_image`` (when given as bytes)
            cannot be decoded as an image.
    """
    model = get_model(model_name)

    img = _decode_rgb(data, "Invalid image input to `remove()`")
    mask = detect.predict(model, np.array(img)).convert("L")

    # If only_mask is True, return just the mask
    if only_mask:
        return _png_buffer(mask)

    if alpha_matting:
        cutout = alpha_matting_cutout(
            img,
            mask,
            alpha_matting_foreground_threshold,
            alpha_matting_background_threshold,
            alpha_matting_erode_structure_size,
            alpha_matting_base_size,
        )
    else:
        cutout = naive_cutout(img, mask)

    if background_image is not None:
        # A background image takes precedence over a background colour.
        bg = _decode_rgb(background_image, "Invalid background image input")
        # Resize background to match cutout size
        bg = bg.resize(cutout.size, Image.LANCZOS)
        cutout = _flatten_onto(cutout, bg)
    elif background_color is not None:
        bg = Image.new("RGB", cutout.size, background_color)
        cutout = _flatten_onto(cutout, bg)

    return _png_buffer(cutout)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def iter_frames(path):
    """Iterate over the frames of the video at ``path``, resized to height 320.

    The clip is opened immediately, so errors opening ``path`` are still
    raised by this call. Frames are produced lazily as ``uint8`` arrays, and
    the clip is closed once iteration finishes or the iterator is closed,
    instead of being left open.
    """
    clip = VideoFileClip(path)
    frames = clip.resized(height=320).iter_frames(dtype="uint8")

    def _frames_then_close():
        try:
            yield from frames
        finally:
            # Release the video reader held by the clip.
            clip.close()

    return _frames_then_close()
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
@torch.no_grad()
def remove_many(image_data: typing.List[np.ndarray], net: Net) -> np.ndarray:
    """Run ``net`` on a batch of frames and return its output as a numpy array.

    Args:
        image_data: list of arrays; they are combined with ``np.stack`` into
            one batch, which is moved to ``DEVICE`` as float32.
        net: the network to call on the batch.

    Returns:
        ``net(batch).numpy()``.
    """
    image_data = np.stack(image_data)
    image_data = torch.as_tensor(image_data, dtype=torch.float32, device=DEVICE)
    return net(image_data).numpy()
|
|
File without changes
|