backgroundremover 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,9 @@
1
+ """
2
+ backgroundremover
3
+
4
+ A library to remove background from videos and images
5
+ """
6
+
7
+ __version__ = "0.3.7"
8
+ __author__ = 'Johnathan Nader'
9
+ __credits__ = 'BackgroundRemoverAI.com'
@@ -0,0 +1,297 @@
1
+ import io
2
+ import os
3
+ import typing
4
+ from PIL import Image, ImageOps
5
+ from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
6
+ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
7
+ from pymatting.util.util import stack_images
8
+ from scipy.ndimage.morphology import binary_erosion
9
+ from moviepy import VideoFileClip
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional
13
+ import torch.nn.functional
14
+ from hsh.library.hash import Hasher
15
+ from .u2net import detect, u2net
16
+ from . import github
17
+
18
# Optional HEIC/HEIF support: when pillow_heif is installed its opener is
# registered with Pillow; when it is missing the registration is skipped.
try:
    from pillow_heif import register_heif_opener
    register_heif_opener()
except ImportError:
    pass  # HEIC support is optional


# closes https://github.com/nadermx/backgroundremover/issues/18
# closes https://github.com/nadermx/backgroundremover/issues/112
def _pick_device():
    """Return the torch device that inference should run on.

    Preference order: the first CUDA GPU, then Apple's MPS backend, then the
    CPU. If probing either accelerator raises, the error is reported on
    stdout and the CPU is used instead.
    """
    try:
        if torch.cuda.is_available():
            return torch.device('cuda:0')
        if torch.backends.mps.is_available():
            return torch.device('mps')
        return torch.device('cpu')
    except Exception as probe_error:
        print(f"Using CPU. Setting Cuda or MPS failed: {probe_error}")
        return torch.device('cpu')


DEVICE = _pick_device()
37
+
38
class Net(torch.nn.Module):
    """Batched U^2-Net mask predictor.

    ``Net(model_name)`` builds one of the supported U^2-Net variants. It
    fetches the weights file if it is missing, loads it onto ``DEVICE`` and
    puts the network in eval mode. Calling the module maps a batch of frames
    to a batch of 8-bit masks (see ``forward``).
    """

    def __init__(self, model_name):
        """Build the network for *model_name* and load its pretrained weights.

        Args:
            model_name: ``"u2net"``, ``"u2netp"`` or ``"u2net_human_seg"``.

        Weights are read from ``~/.u2net/<model_name>.pth`` unless an
        environment variable overrides the path: ``U2NETP_PATH`` for u2netp,
        ``U2NET_PATH`` for the other two. A missing file is fetched with
        ``github.download_files_from_github`` before loading.

        Raises:
            KeyError: *model_name* is not one of the supported models.
            RuntimeError: the weights file is truncated (``EOFError``).
            Exception: any other loading error is reported and re-raised.
        """
        super(Net, self).__init__()

        # model_name -> (network class, env var overriding the weights path).
        # Reference MD5 digests of the published weights (not checked here):
        #   u2netp          e4f636406ca4e2af789941e7f139ee2e
        #   u2net           09fb4e49b7f785c9f855baf94916840a
        #   u2net_human_seg 347c3d51b01528e5c6c071e3cff1cb55
        # NOTE: u2net and u2net_human_seg both honour U2NET_PATH, so setting
        # it points both models at the same file.
        variants = {
            "u2netp": (u2net.U2NETP, "U2NETP_PATH"),
            "u2net": (u2net.U2NET, "U2NET_PATH"),
            "u2net_human_seg": (u2net.U2NET, "U2NET_PATH"),
        }
        if model_name not in variants:
            # Same exception type the original table lookup raised, but with
            # a message listing the valid choices.
            raise KeyError(
                f"Unknown model {model_name!r}; "
                "choose between u2net, u2net_human_seg or u2netp"
            )
        net_class, path_env_var = variants[model_name]

        net = net_class(3, 1)
        path = os.environ.get(
            path_env_var,
            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
        )
        if not os.path.exists(path):
            github.download_files_from_github(path, model_name)

        try:
            # torch.load unpickles the file, so the *_PATH variables must only
            # point at trusted weights.
            state_dict = torch.load(path, map_location=torch.device(DEVICE))
            net.load_state_dict(state_dict)
            net.to(device=DEVICE, dtype=torch.float32, non_blocking=True)
            net.eval()
            self.net = net
        except EOFError as e:
            # A truncated file (e.g. an interrupted download) ends unpickling early.
            print(f"\n{'='*60}")
            print(f"ERROR: Model file appears to be corrupted or incomplete!")
            print(f"Path: {path}")
            print(f"\nThis usually happens when the model download was interrupted.")
            print(f"To fix this:")
            print(f" 1. Delete the corrupted file: rm {path}")
            print(f" 2. Run backgroundremover again to re-download the model")
            print(f"{'='*60}\n")
            raise RuntimeError(f"Corrupted model file at {path}. Please delete it and re-run to download again.") from e
        except Exception as e:
            print(f"\n{'='*60}")
            print(f"ERROR: Failed to load model '{model_name}'")
            print(f"Path: {path}")
            print(f"Error: {e}")
            print(f"\nIf the error persists:")
            print(f" 1. Try deleting the model file: rm {path}")
            print(f" 2. Run backgroundremover again to re-download")
            print(f" 3. Check if you have enough disk space")
            print(f"{'='*60}\n")
            raise

    def forward(self, block_input: torch.Tensor):
        """Predict 8-bit masks for a batch of frames.

        Args:
            block_input: float tensor laid out as (batch, height, width,
                channels); the permute below moves channels to dim 1. Values
                are assumed to be 0-255 intensities (TODO confirm with callers).

        Returns:
            uint8 CPU tensor of shape (batch, height, width).
        """
        image_data = block_input.permute(0, 3, 1, 2)
        original_shape = image_data.shape[2:]
        # The network is always run at a fixed 320x320 input resolution.
        image_data = torch.nn.functional.interpolate(
            image_data, (320, 320), mode="bilinear", align_corners=False
        )
        # The same mean/std is applied to every channel.
        image_data = (image_data / 255 - 0.485) / 0.229
        # First network output, first channel only (kept 4-D for interpolate).
        out = self.net(image_data)[0][:, 0:1]

        # Min-max rescale to 0..255. NOTE: min/max are taken over the whole
        # batch, not per frame.
        ma = torch.max(out)
        mi = torch.min(out)
        span = ma - mi
        if span > 0:
            out = (out - mi) / span * 255
        else:
            # Flat prediction: avoid 0/0 = NaN, whose uint8 cast is undefined.
            out = torch.zeros_like(out)

        out = torch.nn.functional.interpolate(
            out, original_shape, mode="bilinear", align_corners=False
        )
        out = out[:, 0]
        # Blocking device->host copy: with non_blocking=True the CPU tensor can
        # be returned before the data has arrived, and callers read it
        # immediately (remove_many calls .numpy() on the result).
        out = out.to(dtype=torch.uint8, device=torch.device("cpu")).detach()
        return out
140
+
141
+
142
def alpha_matting_cutout(
    img,
    mask,
    foreground_threshold,
    background_threshold,
    erode_structure_size,
    base_size,
):
    """Cut *img* out using closed-form alpha matting guided by *mask*.

    The mask is turned into a trimap of confident foreground, confident
    background and an unknown band. pymatting then estimates alpha and
    foreground colours on a downscaled copy of the image, and the result is
    resized back to the original size.

    Args:
        img: PIL image to cut out. It is not modified.
        mask: PIL mask; presumably mode "L" (0-255), as produced in ``remove``.
        foreground_threshold: mask values above this seed the foreground.
        background_threshold: mask values below this seed the background.
        erode_structure_size: side of the square erosion structuring element;
            0 or less uses ``binary_erosion``'s default structure.
        base_size: matting runs on a copy shrunk to fit within
            ``base_size`` x ``base_size``.

    Returns:
        PIL image the size of *img*, built from the stacked foreground and
        alpha (presumably RGBA).
    """
    size = img.size

    # thumbnail() resizes in place; work on a copy so the caller's image is untouched.
    img = img.copy()
    img.thumbnail((base_size, base_size), Image.LANCZOS)
    mask = mask.resize(img.size, Image.LANCZOS)

    img = np.asarray(img)
    mask = np.asarray(mask)

    # guess likely foreground/background from the coarse mask
    is_foreground = mask > foreground_threshold
    is_background = mask < background_threshold

    # Erode both regions so a band around the mask edge stays "unknown" for
    # the solver. border_value=1 stops background that touches the image
    # border from being eroded away.
    structure = None
    if erode_structure_size > 0:
        structure = np.ones((erode_structure_size, erode_structure_size), dtype=np.int64)

    is_foreground = binary_erosion(is_foreground, structure=structure)
    is_background = binary_erosion(is_background, structure=structure, border_value=1)

    # build trimap
    # 0 = background
    # 128 = unknown
    # 255 = foreground
    trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
    trimap[is_foreground] = 255
    trimap[is_background] = 0

    # Image and trimap are scaled to [0, 1] before matting.
    img_normalized = img / 255.0
    trimap_normalized = trimap / 255.0

    alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
    foreground = estimate_foreground_ml(img_normalized, alpha)
    cutout = stack_images(foreground, alpha)

    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
    cutout = Image.fromarray(cutout)
    cutout = cutout.resize(size, Image.LANCZOS)

    return cutout
191
+
192
+
193
def naive_cutout(img, mask):
    """Composite *img* over a fully transparent canvas, using *mask* as alpha.

    The mask is first resized (LANCZOS) to the image size. The canvas is
    RGBA, so pixels where the mask is 0 stay transparent.
    """
    transparent = Image.new("RGBA", img.size, 0)
    resized_mask = mask.resize(img.size, Image.LANCZOS)
    return Image.composite(img, transparent, resized_mask)
197
+
198
+
199
def get_model(model_name):
    """Load a U^2-Net variant through ``detect.load_model``.

    ``"u2netp"`` and ``"u2net_human_seg"`` are loaded as named; every other
    value, including unrecognised names, falls back to ``"u2net"``.
    """
    for candidate in ("u2netp", "u2net_human_seg"):
        if model_name == candidate:
            return detect.load_model(model_name=candidate)
    return detect.load_model(model_name="u2net")
206
+
207
+
208
def _load_rgb_image(data, error_prefix):
    """Decode *data* (a numpy array or encoded image bytes) into an RGB PIL image.

    For encoded bytes, the EXIF orientation is applied first so photos are
    not returned rotated (fixes #144).

    Raises:
        ValueError: ``"<error_prefix>: <cause>"`` when the bytes cannot be
            opened or decoded; the underlying exception is chained.
    """
    if isinstance(data, np.ndarray):
        return Image.fromarray(data).convert("RGB")
    try:
        img = Image.open(io.BytesIO(data))
        img = ImageOps.exif_transpose(img)
        return img.convert("RGB")
    except Exception as e:
        raise ValueError(f"{error_prefix}: {e}") from e


def _composite_over(cutout, bg):
    """Paste *cutout* onto the RGB image *bg* (same size) and return *bg*.

    Transparent cutout pixels let *bg* show through. A cutout without an
    alpha channel is treated as fully opaque rather than being discarded.
    """
    if cutout.mode != "RGBA":
        cutout = cutout.convert("RGBA")
    bg.paste(cutout, mask=cutout.getchannel("A"))
    return bg


def _to_png_buffer(image):
    """Encode *image* as PNG and return a memoryview over the encoded bytes."""
    bio = io.BytesIO()
    image.save(bio, "PNG")
    return bio.getbuffer()


def remove(
    data,
    model_name="u2net",
    alpha_matting=False,
    alpha_matting_foreground_threshold=240,
    alpha_matting_background_threshold=10,
    alpha_matting_erode_structure_size=10,
    alpha_matting_base_size=1000,
    only_mask=False,
    background_color=None,
    background_image=None,
):
    """Remove the background from a single image.

    Args:
        data: encoded image bytes, or a numpy array accepted by ``Image.fromarray``.
        model_name: model passed to ``get_model``.
        alpha_matting: refine the edge with ``alpha_matting_cutout`` instead
            of a plain mask composite.
        alpha_matting_*: forwarded to ``alpha_matting_cutout``.
        only_mask: return the predicted grayscale mask instead of a cutout.
        background_color: colour (anything ``Image.new`` accepts) placed
            behind the subject; ignored when *background_image* is given.
        background_image: encoded bytes or numpy array; resized to the cutout
            size and placed behind the subject.

    Returns:
        memoryview over PNG bytes. The content is the "L" mask when
        *only_mask* is set, an RGB image when a background is applied, and
        otherwise the cutout.

    Raises:
        ValueError: *data* or *background_image* bytes cannot be decoded.
    """
    # Decode the input before loading the model, so undecodable input is
    # rejected without paying for the model load.
    img = _load_rgb_image(data, "Invalid image input to `remove()`")
    model = get_model(model_name)

    mask = detect.predict(model, np.array(img)).convert("L")

    if only_mask:
        return _to_png_buffer(mask)

    if alpha_matting:
        cutout = alpha_matting_cutout(
            img,
            mask,
            alpha_matting_foreground_threshold,
            alpha_matting_background_threshold,
            alpha_matting_erode_structure_size,
            alpha_matting_base_size,
        )
    else:
        cutout = naive_cutout(img, mask)

    # A background image takes precedence over a background colour.
    if background_image is not None:
        bg = _load_rgb_image(background_image, "Invalid background image input")
        bg = bg.resize(cutout.size, Image.LANCZOS)
        cutout = _composite_over(cutout, bg)
    elif background_color is not None:
        cutout = _composite_over(cutout, Image.new("RGB", cutout.size, background_color))

    return _to_png_buffer(cutout)
287
+
288
+
289
def _frames_then_close(clip, frames):
    """Yield every item of *frames*, then close *clip* when iteration stops."""
    try:
        yield from frames
    finally:
        clip.close()


def iter_frames(path):
    """Iterate over the frames of the video at *path*, resized to 320 px high.

    Frames come from moviepy's ``iter_frames(dtype="uint8")``. The file is
    opened eagerly, as before, so open errors surface at call time. The
    ``VideoFileClip`` is closed once the returned iterator is exhausted,
    closed or garbage-collected, instead of keeping its reader open.
    """
    clip = VideoFileClip(path)
    try:
        frames = clip.resized(height=320).iter_frames(dtype="uint8")
    except Exception:
        clip.close()
        raise
    return _frames_then_close(clip, frames)
291
+
292
+
293
@torch.no_grad()
def remove_many(image_data: typing.List[np.ndarray], net: Net):
    """Run *net* on a batch of frames and return its masks as a numpy array.

    Args:
        image_data: equally shaped frames, stacked along a new first axis.
            Presumably each is (height, width, channels) as ``Net.forward``
            expects — confirm with callers.
        net: a loaded ``Net``; gradients are disabled for the call.

    Returns:
        numpy array built from ``net``'s output, one mask per frame.
    """
    batch = np.stack(image_data)
    batch = torch.as_tensor(batch, dtype=torch.float32, device=DEVICE)
    return net(batch).numpy()
File without changes