bizyengine 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. bizyengine/__init__.py +35 -0
  2. bizyengine/bizy_server/__init__.py +7 -0
  3. bizyengine/bizy_server/api_client.py +763 -0
  4. bizyengine/bizy_server/errno.py +122 -0
  5. bizyengine/bizy_server/error_handler.py +3 -0
  6. bizyengine/bizy_server/execution.py +55 -0
  7. bizyengine/bizy_server/resp.py +24 -0
  8. bizyengine/bizy_server/server.py +898 -0
  9. bizyengine/bizy_server/utils.py +93 -0
  10. bizyengine/bizyair_extras/__init__.py +24 -0
  11. bizyengine/bizyair_extras/nodes_advanced_refluxcontrol.py +62 -0
  12. bizyengine/bizyair_extras/nodes_cogview4.py +31 -0
  13. bizyengine/bizyair_extras/nodes_comfyui_detail_daemon.py +180 -0
  14. bizyengine/bizyair_extras/nodes_comfyui_instantid.py +164 -0
  15. bizyengine/bizyair_extras/nodes_comfyui_layerstyle_advance.py +141 -0
  16. bizyengine/bizyair_extras/nodes_comfyui_pulid_flux.py +88 -0
  17. bizyengine/bizyair_extras/nodes_controlnet.py +50 -0
  18. bizyengine/bizyair_extras/nodes_custom_sampler.py +130 -0
  19. bizyengine/bizyair_extras/nodes_dataset.py +99 -0
  20. bizyengine/bizyair_extras/nodes_differential_diffusion.py +16 -0
  21. bizyengine/bizyair_extras/nodes_flux.py +69 -0
  22. bizyengine/bizyair_extras/nodes_image_utils.py +93 -0
  23. bizyengine/bizyair_extras/nodes_ip2p.py +20 -0
  24. bizyengine/bizyair_extras/nodes_ipadapter_plus/__init__.py +1 -0
  25. bizyengine/bizyair_extras/nodes_ipadapter_plus/nodes_ipadapter_plus.py +1598 -0
  26. bizyengine/bizyair_extras/nodes_janus_pro.py +81 -0
  27. bizyengine/bizyair_extras/nodes_kolors_mz/__init__.py +86 -0
  28. bizyengine/bizyair_extras/nodes_model_advanced.py +62 -0
  29. bizyengine/bizyair_extras/nodes_sd3.py +52 -0
  30. bizyengine/bizyair_extras/nodes_segment_anything.py +256 -0
  31. bizyengine/bizyair_extras/nodes_segment_anything_utils.py +134 -0
  32. bizyengine/bizyair_extras/nodes_testing_utils.py +139 -0
  33. bizyengine/bizyair_extras/nodes_trellis.py +199 -0
  34. bizyengine/bizyair_extras/nodes_ultimatesdupscale.py +137 -0
  35. bizyengine/bizyair_extras/nodes_upscale_model.py +32 -0
  36. bizyengine/bizyair_extras/nodes_wan_video.py +49 -0
  37. bizyengine/bizyair_extras/oauth_callback/main.py +118 -0
  38. bizyengine/core/__init__.py +8 -0
  39. bizyengine/core/commands/__init__.py +1 -0
  40. bizyengine/core/commands/base.py +27 -0
  41. bizyengine/core/commands/invoker.py +4 -0
  42. bizyengine/core/commands/processors/model_hosting_processor.py +0 -0
  43. bizyengine/core/commands/processors/prompt_processor.py +123 -0
  44. bizyengine/core/commands/servers/model_server.py +0 -0
  45. bizyengine/core/commands/servers/prompt_server.py +234 -0
  46. bizyengine/core/common/__init__.py +8 -0
  47. bizyengine/core/common/caching.py +198 -0
  48. bizyengine/core/common/client.py +262 -0
  49. bizyengine/core/common/env_var.py +101 -0
  50. bizyengine/core/common/utils.py +93 -0
  51. bizyengine/core/configs/conf.py +112 -0
  52. bizyengine/core/configs/models.json +101 -0
  53. bizyengine/core/configs/models.yaml +329 -0
  54. bizyengine/core/data_types.py +20 -0
  55. bizyengine/core/image_utils.py +288 -0
  56. bizyengine/core/nodes_base.py +159 -0
  57. bizyengine/core/nodes_io.py +97 -0
  58. bizyengine/core/path_utils/__init__.py +9 -0
  59. bizyengine/core/path_utils/path_manager.py +276 -0
  60. bizyengine/core/path_utils/utils.py +34 -0
  61. bizyengine/misc/__init__.py +0 -0
  62. bizyengine/misc/auth.py +83 -0
  63. bizyengine/misc/llm.py +431 -0
  64. bizyengine/misc/mzkolors.py +93 -0
  65. bizyengine/misc/nodes.py +1208 -0
  66. bizyengine/misc/nodes_controlnet_aux.py +491 -0
  67. bizyengine/misc/nodes_controlnet_union_sdxl.py +171 -0
  68. bizyengine/misc/route_sam.py +60 -0
  69. bizyengine/misc/segment_anything.py +276 -0
  70. bizyengine/misc/supernode.py +182 -0
  71. bizyengine/misc/utils.py +218 -0
  72. bizyengine/version.txt +1 -0
  73. bizyengine-0.4.2.dist-info/METADATA +12 -0
  74. bizyengine-0.4.2.dist-info/RECORD +76 -0
  75. bizyengine-0.4.2.dist-info/WHEEL +5 -0
  76. bizyengine-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,81 @@
1
+ from bizyengine.core import BizyAirBaseNode
2
+
3
+
4
class JanusModelLoader(BizyAirBaseNode):
    """Node declaration that selects a Janus-Pro checkpoint.

    Offers a single ``model_name`` choice and advertises two outputs: a
    model handle and a processor handle.  ``FUNCTION`` is not set in this
    class; the entry point is presumably supplied by ``BizyAirBaseNode``
    (TODO confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema: one required ``model_name`` choice list."""
        # Only the 7B checkpoint is offered; "deepseek-ai/Janus-Pro-1B" is
        # currently left out of the list.
        available_models = ["deepseek-ai/Janus-Pro-7B"]
        return {"required": {"model_name": (available_models,)}}

    RETURN_TYPES = ("BIZYAIR_JANUS_MODEL", "BIZYAIR_JANUS_PROCESSOR")
    RETURN_NAMES = ("model", "processor")
    # FUNCTION = "load_model"
    CATEGORY = "Janus-Pro"
20
+
21
+
22
class JanusImageUnderstanding(BizyAirBaseNode):
    """Node declaration for asking a Janus-Pro model a question about an image.

    Takes a model/processor pair, an image, a question and sampling
    settings, and advertises a single ``text`` string output.  ``FUNCTION``
    is not set in this class; the entry point is presumably provided by
    ``BizyAirBaseNode`` (TODO confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every input is required."""
        question_opts = {
            "multiline": True,
            "default": "Describe this image in detail.",
        }
        seed_opts = {
            "default": 666666666666666,
            "min": 0,
            "max": 0xFFFFFFFFFFFFFFFF,
        }
        required = {
            "model": ("BIZYAIR_JANUS_MODEL",),
            "processor": ("BIZYAIR_JANUS_PROCESSOR",),
            "image": ("IMAGE",),
            "question": ("STRING", question_opts),
            "seed": ("INT", seed_opts),
            "temperature": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0}),
            "top_p": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0}),
            "max_new_tokens": ("INT", {"default": 512, "min": 1, "max": 2048}),
        }
        return {"required": required}

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("text",)
    # FUNCTION = "analyze_image"
    CATEGORY = "Janus-Pro"
48
+
49
+
50
class JanusImageGeneration(BizyAirBaseNode):
    """Node declaration for generating images from a prompt with Janus-Pro.

    Takes a model/processor pair, a text prompt and sampling settings, and
    advertises a single ``images`` output of type ``IMAGE``.  ``FUNCTION``
    is not set in this class; the entry point is presumably provided by
    ``BizyAirBaseNode`` (TODO confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every input is required."""
        prompt_opts = {"multiline": True, "default": "A beautiful photo of"}
        seed_opts = {
            "default": 666666666666666,
            "min": 0,
            "max": 0xFFFFFFFFFFFFFFFF,
        }
        cfg_opts = {"default": 5.0, "min": 1.0, "max": 10.0, "step": 0.5}
        temperature_opts = {"default": 1.0, "min": 0.1, "max": 2.0, "step": 0.1}
        required = {
            "model": ("BIZYAIR_JANUS_MODEL",),
            "processor": ("BIZYAIR_JANUS_PROCESSOR",),
            "prompt": ("STRING", prompt_opts),
            "seed": ("INT", seed_opts),
            # The batch size is constrained to the range 4..9.
            "batch_size": ("INT", {"default": 4, "min": 4, "max": 9}),
            "cfg_weight": ("FLOAT", cfg_opts),
            "temperature": ("FLOAT", temperature_opts),
            "top_p": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0}),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("images",)
    # FUNCTION = "generate_images"
    CATEGORY = "Janus-Pro"
@@ -0,0 +1,86 @@
1
+ import os
2
+
3
+ from bizyengine.core import BizyAirBaseNode, BizyAirNodeIO, create_node_data
4
+ from bizyengine.core import path_utils as folder_paths
5
+ from bizyengine.core.data_types import CLIP, CONDITIONING, CONTROL_NET, MODEL
6
+
7
# Author tag used as the prefix of the Kolors nodes' display names.
AUTHOR_NAME = "MinusZone"
# Menu category shared by the Kolors nodes.  A plain literal: the value has
# no placeholders, so an f-string prefix is unnecessary.
CATEGORY_NAME = "Kolors"
9
+
10
+
11
class MZ_KolorsUNETLoaderV2(BizyAirBaseNode):
    """Node that declares loading of a Kolors UNet by file name.

    ``load_unet`` does not load weights itself: it packages the inputs into
    node data and returns a ``BizyAirNodeIO`` describing this node.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema: one required ``unet_name`` choice list."""
        unet_files = folder_paths.get_filename_list("unet")
        return {"required": {"unet_name": (unet_files,)}}

    RETURN_TYPES = (MODEL,)
    RETURN_NAMES = ("model",)

    FUNCTION = "load_unet"

    CATEGORY = CATEGORY_NAME
    NODE_DISPLAY_NAME = f"{AUTHOR_NAME} - KolorsUNETLoaderV2"

    def load_unet(self, **kwargs):
        """Build the node description for this loader.

        Args:
            **kwargs: The node inputs; ``unet_name`` must be present because
                it is used to look up the config file.

        Returns:
            A 1-tuple holding the ``BizyAirNodeIO`` for this node.

        Raises:
            KeyError: If ``unet_name`` is missing from ``kwargs``.
        """
        payload = create_node_data(
            class_type="MZ_KolorsUNETLoaderV2",
            inputs=kwargs,
            outputs={"slot_index": 0},
        )
        # The config file is derived from the chosen UNet file name.
        config_file = folder_paths.guess_config(unet_name=kwargs["unet_name"])
        node_id = self.assigned_id
        node_io = BizyAirNodeIO(node_id, {node_id: payload}, config_file=config_file)
        return (node_io,)
40
+
41
+
42
# Names of the supported weight curves / transfer modes, in display order.
WEIGHT_TYPES = [
    # Interpolation curves.
    "linear",
    "ease in",
    "ease out",
    "ease in-out",
    "reverse in-out",
    # Emphasis presets.
    "weak input",
    "weak output",
    "weak middle",
    "strong middle",
    # Style / composition modes.
    "style transfer",
    "composition",
    "strong style transfer",
    "style and composition",
    "style transfer precise",
    "composition precise",
]
59
+
60
+
61
class MZ_KolorsControlNetLoader(BizyAirBaseNode):
    """Node that declares loading of a Kolors ControlNet by file name.

    ``load_controlnet`` packages the inputs into node data and returns a
    ``BizyAirNodeIO`` describing this node; it does not load weights itself.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema: one required ``control_net_name`` choice."""
        controlnet_files = folder_paths.get_filename_list("controlnet")
        required = {
            "control_net_name": (controlnet_files,),
            # "seed": ("INT", {"default": 0, "min": 0, "max": 1000000}),
        }
        return {"required": required}

    RETURN_TYPES = (CONTROL_NET,)
    RETURN_NAMES = ("ControlNet",)
    FUNCTION = "load_controlnet"

    CATEGORY = CATEGORY_NAME
    NODE_DISPLAY_NAME = f"{AUTHOR_NAME} - KolorsControlNetLoader"

    def load_controlnet(self, **kwargs):
        """Build the node description for this loader.

        Args:
            **kwargs: The node inputs, forwarded unchanged as the node data
                inputs.

        Returns:
            A 1-tuple holding the ``BizyAirNodeIO`` for this node.
        """
        payload = create_node_data(
            class_type="MZ_KolorsControlNetLoader",
            inputs=kwargs,
            outputs={"slot_index": 0},
        )
        node_id = self.assigned_id
        return (BizyAirNodeIO(node_id, {node_id: payload}),)
@@ -0,0 +1,62 @@
1
+ # ComfyUI/comfy_extras/nodes_model_advanced.py
2
+ import nodes
3
+ from bizyengine.core import BizyAirBaseNode, BizyAirNodeIO, data_types
4
+
5
+
6
class ModelSamplingSD3(BizyAirBaseNode):
    """Node declaration taking a model and a ``shift`` value, returning a model.

    ``FUNCTION`` is not set in this class; the entry point is presumably
    provided by ``BizyAirBaseNode`` (TODO confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema: a model and a float ``shift`` in [0, 100]."""
        shift_opts = {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01}
        required = {
            "model": (data_types.MODEL,),
            "shift": ("FLOAT", shift_opts),
        }
        return {"required": required}

    RETURN_TYPES = (data_types.MODEL,)
    # FUNCTION = "patch"

    CATEGORY = "advanced/model"
23
+
24
+
25
class ModelSamplingFlux(BizyAirBaseNode):
    """Node declaration taking a model, two shift values and a resolution.

    Advertises a single model output.  ``FUNCTION`` is not set in this
    class; the entry point is presumably provided by ``BizyAirBaseNode``
    (TODO confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every input is required."""

        def shift_input(default):
            # Both shift inputs share the same range and step.
            return (
                "FLOAT",
                {"default": default, "min": 0.0, "max": 100.0, "step": 0.01},
            )

        def resolution_input():
            # A fresh options dict per call so width and height do not share
            # one object.
            return (
                "INT",
                {
                    "default": 1024,
                    "min": 16,
                    "max": nodes.MAX_RESOLUTION,
                    "step": 8,
                },
            )

        required = {
            "model": (data_types.MODEL,),
            "max_shift": shift_input(1.15),
            "base_shift": shift_input(0.5),
            "width": resolution_input(),
            "height": resolution_input(),
        }
        return {"required": required}

    RETURN_TYPES = (data_types.MODEL,)
    # FUNCTION = "patch"
    CATEGORY = "advanced/model"
@@ -0,0 +1,52 @@
1
+ # sd3.5
2
+ from bizyengine.core import BizyAirBaseNode, BizyAirNodeIO, data_types
3
+ from bizyengine.core.path_utils import path_manager as folder_paths
4
+
5
+
6
class TripleCLIPLoader(BizyAirBaseNode):
    """Node declaration that selects three CLIP files and returns one CLIP.

    ``FUNCTION`` is not set in this class; the entry point is presumably
    provided by ``BizyAirBaseNode`` (TODO confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema: ``clip_name1`` .. ``clip_name3`` choices."""
        # The file list is queried once per slot, so each input gets its own
        # list object.
        required = {
            f"clip_name{slot}": (folder_paths.get_filename_list("clip"),)
            for slot in (1, 2, 3)
        }
        return {"required": required}

    RETURN_TYPES = (data_types.CLIP,)
    # FUNCTION = "load_clip"

    CATEGORY = "advanced/loaders"
21
+
22
+
23
class ControlNetApplySD3(BizyAirBaseNode):
    """Node declaration for applying a ControlNet together with a VAE.

    Takes positive/negative conditioning, a ControlNet, a VAE, an image and
    strength/range settings; advertises updated positive and negative
    conditioning as outputs.  No ``FUNCTION`` is set in this class.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every input is required."""

        def percent_input(default):
            # start_percent / end_percent share the [0, 1] range and step.
            return (
                "FLOAT",
                {"default": default, "min": 0.0, "max": 1.0, "step": 0.001},
            )

        strength_opts = {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}
        required = {
            "positive": (data_types.CONDITIONING,),
            "negative": (data_types.CONDITIONING,),
            "control_net": (data_types.CONTROL_NET,),
            "vae": (data_types.VAE,),
            "image": ("IMAGE",),
            "strength": ("FLOAT", strength_opts),
            "start_percent": percent_input(0.0),
            "end_percent": percent_input(1.0),
        }
        return {"required": required}

    CATEGORY = "conditioning/controlnet"
    # DEPRECATED = True
    NODE_DISPLAY_NAME = "Apply Controlnet with VAE"
    RETURN_TYPES = (data_types.CONDITIONING, data_types.CONDITIONING)
    RETURN_NAMES = ("positive", "negative")
@@ -0,0 +1,256 @@
1
+ from bizyengine.core import BizyAirBaseNode
2
+
3
+ from .nodes_segment_anything_utils import *
4
+
5
+
6
class BizyAir_SAMModelLoader(BizyAirBaseNode):
    """Node declaration that selects a SAM model and outputs a predictor.

    ``FUNCTION`` is not set in this class; the entry point is presumably
    provided by ``BizyAirBaseNode`` (TODO confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema: one required ``model_name`` choice list."""
        sam_names = list_sam_model()
        return {"required": {"model_name": (sam_names,)}}

    CATEGORY = "☁️BizyAir/segment-anything"
    # FUNCTION = "main"
    RETURN_TYPES = ("SAM_PREDICTOR",)
    NODE_DISPLAY_NAME = "☁️BizyAir Load SAM Model"
19
+
20
+
21
class BizyAir_GroundingDinoModelLoader(BizyAirBaseNode):
    """Node declaration that selects a GroundingDINO model.

    ``FUNCTION`` is not set in this class; the entry point is presumably
    provided by ``BizyAirBaseNode`` (TODO confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema: one required ``model_name`` choice list."""
        dino_names = list_groundingdino_model()
        return {"required": {"model_name": (dino_names,)}}

    CATEGORY = "☁️BizyAir/segment-anything"
    # FUNCTION = "main"
    RETURN_TYPES = ("GROUNDING_DINO_MODEL",)
    NODE_DISPLAY_NAME = "☁️BizyAir Load GroundingDino Model"
34
+
35
+
36
class BizyAir_VITMatteModelLoader(BizyAirBaseNode):
    """Node declaration that selects a VITMatte variant.

    Advertises two outputs: a VITMatte model and its predictor.
    ``FUNCTION`` is not set in this class; the entry point is presumably
    provided by ``BizyAirBaseNode`` (TODO confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema: one required ``detail_method`` choice."""
        return {
            "required": {
                "detail_method": (["VITMatte", "VITMatte(local)"],),
            }
        }

    CATEGORY = "☁️BizyAir/segment-anything"
    # FUNCTION = "main"
    RETURN_TYPES = ("VitMatte_MODEL", "VitMatte_predictor")
    NODE_DISPLAY_NAME = "☁️BizyAir Load VITMatte Model"
56
+
57
+
58
class BizyAir_GroundingDinoSAMSegment(BizyAirBaseNode):
    """Node declaration for prompt-driven segmentation.

    Takes a GroundingDINO model, a SAM predictor, an image, a text prompt
    and two thresholds; advertises an image and a mask as outputs.
    ``FUNCTION`` is not set in this class; the entry point is presumably
    provided by ``BizyAirBaseNode`` (TODO confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every input is required."""

        def threshold_input():
            # box_threshold and text_threshold use identical settings; build
            # a fresh options dict for each.
            return ("FLOAT", {"default": 0.3, "min": 0, "max": 1.0, "step": 0.01})

        required = {
            "grounding_dino_model": ("GROUNDING_DINO_MODEL", {}),
            "sam_predictor": ("SAM_PREDICTOR", {}),
            "image": ("IMAGE", {}),
            "prompt": ("STRING", {}),
            "box_threshold": threshold_input(),
            "text_threshold": threshold_input(),
        }
        return {"required": required}

    CATEGORY = "☁️BizyAir/segment-anything"
    # FUNCTION = "main"
    RETURN_TYPES = ("IMAGE", "MASK")
    NODE_DISPLAY_NAME = "☁️BizyAir GroundingDinoSAMSegment"
82
+
83
+
84
class BizyAir_TrimapGenerate(BizyAirBaseNode):
    """Node declaration that turns a mask into a trimap.

    Takes a mask plus erode/dilate amounts and advertises a single ``trimap``
    output of type ``MASK``.  ``FUNCTION`` is not set in this class; the
    entry point is presumably provided by ``BizyAirBaseNode`` (TODO confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every input is required."""

        def morph_input():
            # detail_erode and detail_dilate use identical settings; build a
            # fresh options dict for each.
            return ("INT", {"default": 6, "min": 1, "max": 255, "step": 1})

        required = {
            "mask": ("MASK",),
            "detail_erode": morph_input(),
            "detail_dilate": morph_input(),
        }
        return {"required": required}

    CATEGORY = "☁️BizyAir/segment-anything"
    # FUNCTION = "main"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("trimap",)
    NODE_DISPLAY_NAME = "☁️BizyAir Trimap Generate"
106
+
107
+
108
class BizyAir_VITMattePredict(BizyAirBaseNode):
    """Node declaration for VITMatte-based matting.

    Takes an image, a trimap, a VITMatte model/predictor pair and level
    settings; advertises an image and a mask as outputs.  ``FUNCTION`` is not
    set in this class; the entry point is presumably provided by
    ``BizyAirBaseNode`` (TODO confirm).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every input is required."""
        black_point_opts = {
            "default": 0.15,
            "min": 0.01,
            "max": 0.98,
            "step": 0.01,
            "display": "slider",
        }
        white_point_opts = {
            "default": 0.99,
            "min": 0.02,
            "max": 0.99,
            "step": 0.01,
            "display": "slider",
        }
        megapixel_opts = {"default": 2.0, "min": 1, "max": 999, "step": 0.1}
        required = {
            "image": ("IMAGE", {}),
            "trimap": ("MASK",),
            "vitmatte_model": ("VitMatte_MODEL", {}),
            "vitmatte_predictor": ("VitMatte_predictor", {}),
            "black_point": ("FLOAT", black_point_opts),
            "white_point": ("FLOAT", white_point_opts),
            "max_megapixels": ("FLOAT", megapixel_opts),
        }
        return {"required": required}

    CATEGORY = "☁️BizyAir/segment-anything"
    # FUNCTION = "main"
    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("image", "mask")
    NODE_DISPLAY_NAME = "☁️BizyAir VITMatte Predict"
155
+
156
+
157
class BizyAirDetailMethodPredict(BizyAirBaseNode):
    """Node that refines mask edges and returns RGBA images plus masks.

    Unlike the other nodes in this module it sets ``FUNCTION`` and implements
    ``main`` itself.  The only supported ``detail_method`` is ``"PyMatting"``.
    """

    NODE_DISPLAY_NAME = "☁️BizyAir DetailMethod Predict"

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input schema; every input is required."""
        method_list = [
            "PyMatting",
        ]
        return {
            "required": {
                "image": ("IMAGE", {}),
                "mask": ("MASK",),
                "detail_method": (method_list,),
                "detail_erode": (
                    "INT",
                    {"default": 6, "min": 1, "max": 255, "step": 1},
                ),
                "detail_dilate": (
                    "INT",
                    {"default": 6, "min": 1, "max": 255, "step": 1},
                ),
                "black_point": (
                    "FLOAT",
                    {
                        "default": 0.15,
                        "min": 0.01,
                        "max": 0.98,
                        "step": 0.01,
                        "display": "slider",
                    },
                ),
                "white_point": (
                    "FLOAT",
                    {
                        "default": 0.99,
                        "min": 0.02,
                        "max": 0.99,
                        "step": 0.01,
                        "display": "slider",
                    },
                ),
            }
        }

    CATEGORY = "☁️BizyAir/segment-anything"
    FUNCTION = "main"
    RETURN_TYPES = (
        "IMAGE",
        "MASK",
    )
    RETURN_NAMES = (
        "image",
        "mask",
    )

    def main(
        self,
        image,
        mask,
        detail_method,
        detail_erode,
        detail_dilate,
        black_point,
        white_point,
    ):
        """Refine each mask against its image and attach it as alpha.

        Args:
            image: Batch tensor of images; indexed along its first axis and
                unpacked as a 4-D size in the empty-batch case.
            mask: Masks indexed with the same batch index as ``image``.
            detail_method: Must be ``"PyMatting"``.
            detail_erode: Added to ``detail_dilate`` to derive the edge
                detail range.
            detail_dilate: See ``detail_erode``.
            black_point: Forwarded to ``mask_edge_detail``.
            white_point: Forwarded to ``mask_edge_detail``.

        Returns:
            ``(images, masks)``, each concatenated along dim 0.  For an empty
            batch, a pair of the same all-zero uint8 mask of shape
            ``(1, height, width)`` is returned instead.

        Raises:
            ValueError: If the batch is non-empty and ``detail_method`` is
                not ``"PyMatting"``.
        """
        # Empty batch: nothing to process, so hand back placeholder masks.
        if image.shape[0] == 0:
            _, height, width, _ = image.size()
            empty_mask = torch.zeros(
                (1, height, width), dtype=torch.uint8, device="cpu"
            )
            return (empty_mask, empty_mask)

        # Fail with a clear message instead of an unbound-variable error
        # further down when an unknown method is requested.
        if detail_method != "PyMatting":
            raise ValueError(
                f"Unsupported detail_method: {detail_method!r} "
                "(expected 'PyMatting')"
            )

        # Loop-invariant: the range handed to mask_edge_detail.
        detail_range = detail_erode + detail_dilate
        edge_range = detail_range // 8 + 1

        ret_images = []
        ret_masks = []
        for i in range(image.shape[0]):
            img = torch.unsqueeze(image[i], 0)
            # Round-trip through PIL so the frame is a 3-channel RGB tensor.
            img = pil2tensor(tensor2pil(img).convert("RGB"))

            _mask = tensor2pil(
                mask_edge_detail(img, mask[i], edge_range, black_point, white_point)
            )
            # Use the refined mask as the alpha channel of the frame.
            _image = RGB2RGBA(tensor2pil(img).convert("RGB"), _mask.convert("L"))

            ret_images.append(pil2tensor(_image))
            ret_masks.append(image2mask(_mask))

        return (
            torch.cat(ret_images, dim=0),
            torch.cat(ret_masks, dim=0),
        )
@@ -0,0 +1,134 @@
1
+ import copy
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from scipy.ndimage import gaussian_filter
8
+
9
# Directory name and download registry for SAM checkpoints.  Only entries
# present in the dict are offered by list_sam_model(); the commented-out
# entries are currently disabled.
sam_model_dir_name = "sams"
sam_model_list = {
    "sam_vit_h (2.56GB)": {
        "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    },
    # "sam_vit_l (1.25GB)": {
    #     "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"
    # },
    # "sam_vit_b (375MB)": {
    #     "model_url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
    # },
    # "sam_hq_vit_h (2.57GB)": {
    #     "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth"
    # },
    # "sam_hq_vit_l (1.25GB)": {
    #     "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth"
    # },
    # "sam_hq_vit_b (379MB)": {
    #     "model_url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth"
    # },
    # "mobile_sam(39MB)": {
    #     "model_url": "https://github.com/ChaoningZhang/MobileSAM/blob/master/weights/mobile_sam.pt"
    # },
}

# Directory name and download registry for GroundingDINO checkpoints; each
# entry carries both a config URL and a weights URL.
groundingdino_model_dir_name = "grounding-dino"
groundingdino_model_list = {
    "GroundingDINO_SwinT_OGC (694MB)": {
        "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinT_OGC.cfg.py",
        "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth",
    },
    # "GroundingDINO_SwinB (938MB)": {
    #     "config_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/GroundingDINO_SwinB.cfg.py",
    #     "model_url": "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth",
    # },
}
45
+
46
+
47
def list_sam_model():
    """Return the names of the registered SAM models as a new list."""
    return [*sam_model_list]
49
+
50
+
51
def list_groundingdino_model():
    """Return the names of the registered GroundingDINO models as a new list."""
    return [*groundingdino_model_list]
53
+
54
+
55
def histogram_remap(
    image: torch.Tensor, blackpoint: float, whitepoint: float
) -> torch.Tensor:
    """Linearly stretch values so ``blackpoint`` maps to 0 and ``whitepoint`` to 1.

    Values are clipped to ``[0, 1]`` after the stretch.  The input tensor is
    not modified; a new CPU tensor is returned.

    Args:
        image: Tensor to remap.
        blackpoint: Value mapped to 0.  It is capped at ``whitepoint - 0.001``
            so the scale factor stays finite.
        whitepoint: Value mapped to 1.

    Returns:
        The remapped tensor.
    """
    # Keep the black point strictly below the white point.
    low = min(blackpoint, whitepoint - 0.001)
    gain = 1 / (whitepoint - low)
    pixels = image.cpu().numpy()
    # The arithmetic allocates a new array, so the source data stays untouched.
    stretched = (pixels - low) * gain
    return torch.from_numpy(np.clip(stretched, 0.0, 1.0))
63
+
64
+
65
def mask_edge_detail(
    image: torch.Tensor,
    mask: torch.Tensor,
    detail_range: int = 8,
    black_point: float = 0.01,
    white_point: float = 0.99,
) -> torch.Tensor:
    """Refine a mask with PyMatting's closed-form alpha estimation.

    The mask is optionally blurred, turned into a trimap with ``fix_trimap``
    and passed with each image to ``estimate_alpha_cf``.

    Args:
        image: Image tensor; iterated along its first axis.
        mask: Mask tensor; converted to a 3-channel image via PIL first.
        detail_range: Controls the blur strength; no blur is applied when it
            is not positive.
        black_point: Lower threshold passed to ``fix_trimap``.
        white_point: Upper threshold passed to ``fix_trimap``.

    Returns:
        A float32 tensor holding the estimated alpha, replicated across three
        channels.
    """
    from pymatting import estimate_alpha_cf, fix_trimap

    # Blur size derived from detail_range, bumped to the next odd number.
    blur_size = detail_range * 5 + 1
    if blur_size % 2 == 0:
        blur_size += 1

    rgb_mask = pil2tensor(tensor2pil(mask).convert("RGB"))
    # astype() already returns fresh float64 copies, so the inputs are never
    # written to.
    frames = image.cpu().numpy().astype(np.float64)
    alphas = rgb_mask.cpu().numpy().astype(np.float64)

    for idx, frame in enumerate(frames):
        trimap = alphas[idx][:, :, 0]  # convert to single channel
        if detail_range > 0:
            # trimap = cv2.GaussianBlur(trimap, (d, d), 0)
            trimap = gaussian_filter(trimap, sigma=blur_size / 2)
        trimap = fix_trimap(trimap, black_point, white_point)
        alpha = estimate_alpha_cf(
            frame,
            trimap,
            laplacian_kwargs={"epsilon": 1e-6},
            cg_kwargs={"maxiter": 500},
        )
        alphas[idx] = np.stack([alpha] * 3, axis=-1)  # convert back to rgb

    return torch.from_numpy(alphas.astype(np.float32))
91
+
92
+
93
def pil2tensor(image: Image) -> torch.Tensor:
    """Convert a PIL image to a float32 tensor with a leading batch axis.

    Pixel values are divided by 255.
    """
    pixels = np.array(image).astype(np.float32)
    scaled = pixels / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
95
+
96
+
97
def tensor2pil(t_image: torch.Tensor) -> Image:
    """Convert a tensor to a PIL image.

    Values are multiplied by 255, clipped to ``[0, 255]`` and cast to uint8.
    All size-1 axes are squeezed away first.
    """
    scaled = 255.0 * t_image.cpu().numpy().squeeze()
    pixels = np.clip(scaled, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
101
+
102
+
103
def tensor2np(tensor: torch.Tensor) -> List[np.ndarray]:
    """Convert a tensor to uint8 NumPy data.

    Values are multiplied by 255 and clipped to ``[0, 255]``.

    Returns:
        A single array when ``tensor`` has exactly three dimensions;
        otherwise a list with one array per item along the first axis.
    """

    def to_uint8(t):
        return np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8)

    # Three dimensions are treated as one image, not as a batch.
    if tensor.dim() == 3:
        return to_uint8(tensor)
    return [to_uint8(item) for item in tensor]
110
+
111
+
112
def mask2image(mask: torch.Tensor) -> Image:
    """Render a mask as a white-on-black RGBA image.

    Each item produced by ``tensor2np(mask)`` is rendered in turn, but only
    the image for the LAST item is returned.  When there are no items the
    result is never assigned and an ``UnboundLocalError`` is raised.
    """
    for mask_data in tensor2np(mask):
        stencil = Image.fromarray(mask_data).convert("L")
        white = Image.new("RGBA", stencil.size, color="white")
        black = Image.new("RGBA", stencil.size, color="black")
        # White where the stencil is set, black elsewhere.
        _image = Image.composite(white, black, stencil)
    return _image
121
+
122
+
123
def image2mask(image: Image) -> torch.Tensor:
    """Turn a PIL image into a mask tensor with a leading batch axis.

    The mask is the FIRST band of the image after conversion to RGBA (the
    red channel), scaled to ``[0, 1]``.

    Args:
        image: PIL image to convert.

    Returns:
        A float32 tensor of shape ``(1, height, width)``.
    """
    # Band 0 of the RGBA conversion is the mask source.  Convert it straight
    # to a tensor instead of merging it into a throwaway RGBA image and
    # round-tripping every pixel through a Python list.
    first_band = image.convert("RGBA").split()[0]
    values = np.array(first_band).astype(np.float32) / 255.0
    return torch.from_numpy(values).unsqueeze(0)
130
+
131
+
132
def RGB2RGBA(image: Image, mask: Image) -> Image:
    """Combine an image's RGB bands with ``mask`` as the alpha band.

    ``image`` is converted to RGB and ``mask`` to single-channel "L" before
    merging.
    """
    red, green, blue = image.convert("RGB").split()
    alpha = mask.convert("L")
    return Image.merge("RGBA", (red, green, blue, alpha))