bizyengine 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. bizyengine/__init__.py +35 -0
  2. bizyengine/bizy_server/__init__.py +7 -0
  3. bizyengine/bizy_server/api_client.py +763 -0
  4. bizyengine/bizy_server/errno.py +122 -0
  5. bizyengine/bizy_server/error_handler.py +3 -0
  6. bizyengine/bizy_server/execution.py +55 -0
  7. bizyengine/bizy_server/resp.py +24 -0
  8. bizyengine/bizy_server/server.py +898 -0
  9. bizyengine/bizy_server/utils.py +93 -0
  10. bizyengine/bizyair_extras/__init__.py +24 -0
  11. bizyengine/bizyair_extras/nodes_advanced_refluxcontrol.py +62 -0
  12. bizyengine/bizyair_extras/nodes_cogview4.py +31 -0
  13. bizyengine/bizyair_extras/nodes_comfyui_detail_daemon.py +180 -0
  14. bizyengine/bizyair_extras/nodes_comfyui_instantid.py +164 -0
  15. bizyengine/bizyair_extras/nodes_comfyui_layerstyle_advance.py +141 -0
  16. bizyengine/bizyair_extras/nodes_comfyui_pulid_flux.py +88 -0
  17. bizyengine/bizyair_extras/nodes_controlnet.py +50 -0
  18. bizyengine/bizyair_extras/nodes_custom_sampler.py +130 -0
  19. bizyengine/bizyair_extras/nodes_dataset.py +99 -0
  20. bizyengine/bizyair_extras/nodes_differential_diffusion.py +16 -0
  21. bizyengine/bizyair_extras/nodes_flux.py +69 -0
  22. bizyengine/bizyair_extras/nodes_image_utils.py +93 -0
  23. bizyengine/bizyair_extras/nodes_ip2p.py +20 -0
  24. bizyengine/bizyair_extras/nodes_ipadapter_plus/__init__.py +1 -0
  25. bizyengine/bizyair_extras/nodes_ipadapter_plus/nodes_ipadapter_plus.py +1598 -0
  26. bizyengine/bizyair_extras/nodes_janus_pro.py +81 -0
  27. bizyengine/bizyair_extras/nodes_kolors_mz/__init__.py +86 -0
  28. bizyengine/bizyair_extras/nodes_model_advanced.py +62 -0
  29. bizyengine/bizyair_extras/nodes_sd3.py +52 -0
  30. bizyengine/bizyair_extras/nodes_segment_anything.py +256 -0
  31. bizyengine/bizyair_extras/nodes_segment_anything_utils.py +134 -0
  32. bizyengine/bizyair_extras/nodes_testing_utils.py +139 -0
  33. bizyengine/bizyair_extras/nodes_trellis.py +199 -0
  34. bizyengine/bizyair_extras/nodes_ultimatesdupscale.py +137 -0
  35. bizyengine/bizyair_extras/nodes_upscale_model.py +32 -0
  36. bizyengine/bizyair_extras/nodes_wan_video.py +49 -0
  37. bizyengine/bizyair_extras/oauth_callback/main.py +118 -0
  38. bizyengine/core/__init__.py +8 -0
  39. bizyengine/core/commands/__init__.py +1 -0
  40. bizyengine/core/commands/base.py +27 -0
  41. bizyengine/core/commands/invoker.py +4 -0
  42. bizyengine/core/commands/processors/model_hosting_processor.py +0 -0
  43. bizyengine/core/commands/processors/prompt_processor.py +123 -0
  44. bizyengine/core/commands/servers/model_server.py +0 -0
  45. bizyengine/core/commands/servers/prompt_server.py +234 -0
  46. bizyengine/core/common/__init__.py +8 -0
  47. bizyengine/core/common/caching.py +198 -0
  48. bizyengine/core/common/client.py +262 -0
  49. bizyengine/core/common/env_var.py +101 -0
  50. bizyengine/core/common/utils.py +93 -0
  51. bizyengine/core/configs/conf.py +112 -0
  52. bizyengine/core/configs/models.json +101 -0
  53. bizyengine/core/configs/models.yaml +329 -0
  54. bizyengine/core/data_types.py +20 -0
  55. bizyengine/core/image_utils.py +288 -0
  56. bizyengine/core/nodes_base.py +159 -0
  57. bizyengine/core/nodes_io.py +97 -0
  58. bizyengine/core/path_utils/__init__.py +9 -0
  59. bizyengine/core/path_utils/path_manager.py +276 -0
  60. bizyengine/core/path_utils/utils.py +34 -0
  61. bizyengine/misc/__init__.py +0 -0
  62. bizyengine/misc/auth.py +83 -0
  63. bizyengine/misc/llm.py +431 -0
  64. bizyengine/misc/mzkolors.py +93 -0
  65. bizyengine/misc/nodes.py +1208 -0
  66. bizyengine/misc/nodes_controlnet_aux.py +491 -0
  67. bizyengine/misc/nodes_controlnet_union_sdxl.py +171 -0
  68. bizyengine/misc/route_sam.py +60 -0
  69. bizyengine/misc/segment_anything.py +276 -0
  70. bizyengine/misc/supernode.py +182 -0
  71. bizyengine/misc/utils.py +218 -0
  72. bizyengine/version.txt +1 -0
  73. bizyengine-0.4.2.dist-info/METADATA +12 -0
  74. bizyengine-0.4.2.dist-info/RECORD +76 -0
  75. bizyengine-0.4.2.dist-info/WHEEL +5 -0
  76. bizyengine-0.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,329 @@
1
+ # Common configuration
2
+ model_version_config:
3
+ model_version_id_prefix: "BIZYAIR_MODEL_VERSION_ID:"
4
+
5
+ cache_config:
6
+ max_size: 100 # 100 items
7
+ expiration: 604800 # 7 days
8
+ cache_dir: ".bizyair_cache"
9
+ file_prefix: "bizyair_task_"
10
+ file_suffix: ".json"
11
+ use_cache: true
12
+
13
+
14
+ model_hub:
15
+ find_model:
16
+ route: /models/files
17
+
18
+ model_types:
19
+ loras: bizyair/lora
20
+ controlnet: bizyair/controlnet
21
+ # Each entry maps a local folder_name to its server_folder_name
22
+ # checkpoints: bizyair/checkpoint
23
+ # vae: bizyair/vae
24
+
25
+ task_api:
26
+ # Endpoint route (a path, not a full base URL) for task-result API calls
27
+ task_result_endpoint: /bizy_task
28
+
29
+
30
+ model_rules:
31
+ - mode_type: unet
32
+ base_model: FLUX
33
+ describe: flux1-dev
34
+ score: 3
35
+ route: /supernode/flux-dev-bizyair-comfy-ksampler-speedup
36
+ nodes:
37
+ - class_type: UNETLoader
38
+ inputs:
39
+ unet_name:
40
+ - ^flux/flux1-dev.sft$
41
+
42
+ - mode_type: unet
43
+ base_model: FLUX
44
+ describe: flux1-schnell
45
+ score: 3
46
+ route: /supernode/flux-bizyair-sdxl-comfy-ksampler
47
+ nodes:
48
+ - class_type: UNETLoader
49
+ inputs:
50
+ unet_name:
51
+ - ^flux/flux1-schnell.sft$
52
+
53
+ - mode_type: unet
54
+ base_model: Shuttle
55
+ describe: shuttle-3.1-aesthetic
56
+ score: 3
57
+ route: /supernode/bizyair-shuttle-3-1-aesthetic
58
+ nodes:
59
+ - class_type: UNETLoader
60
+ inputs:
61
+ unet_name:
62
+ - ^shuttle-3.1-aesthetic.safetensors$
63
+
64
+ - mode_type: vae
65
+ base_model: FLUX
66
+ describe: flux-vae
67
+ score: 1
68
+ route: /supernode/flux-vae-bizyair-comfy-ksampler
69
+ nodes:
70
+ - class_type: VAELoader
71
+ inputs:
72
+ vae_name:
73
+ - ^flux/ae.sft$
74
+
75
+ - mode_type: unet
76
+ base_model: FLUX_PIXELWAVE
77
+ describe: PixelWave Flux.1-dev 03 fine tuned!
78
+ score: 3
79
+ route: /supernode/bizyair-flux1-dev-fp8-pixelwave
80
+ nodes:
81
+ - class_type: UNETLoader
82
+ inputs:
83
+ unet_name:
84
+ - ^flux/pixelwave-flux1-dev.safetensors$
85
+
86
+ - mode_type: checkpoint
87
+ base_model: SD3
88
+ describe: SD3.5 Large
89
+ score: 3
90
+ route: /supernode/bizyair-comfybridge-sd3-5-large
91
+ nodes:
92
+ - class_type: CheckpointLoaderSimple
93
+ inputs:
94
+ ckpt_name:
95
+ - ^sd3.5_large.safetensors$
96
+
97
+ - mode_type: checkpoint
98
+ base_model: SD3
99
+ describe: SD3.5 Large Turbo
100
+ score: 3
101
+ route: /supernode/bizyair-comfybridge-sd3-5-turbo
102
+ nodes:
103
+ - class_type: CheckpointLoaderSimple
104
+ inputs:
105
+ ckpt_name:
106
+ - ^sd3.5_large_turbo.safetensors$
107
+
108
+ - mode_type: checkpoint
109
+ base_model: BaseModel
110
+ describe: SD1.5
111
+ score: 3
112
+ route: /supernode/bizyair-ultimate-sd-upscale-ksampler
113
+ nodes:
114
+ - class_type: CheckpointLoaderSimple
115
+ inputs:
116
+ ckpt_name:
117
+ - ^sd15/dreamshaper_8.safetensors$
118
+
119
+
120
+ - mode_type: checkpoint
121
+ base_model: SDXL
122
+ describe: SDXL
123
+ score: 3
124
+ route: /supernode/bizyair-sdxl-comfy-ksampler-v2
125
+ nodes:
126
+ - class_type: CheckpointLoaderSimple
127
+ inputs:
128
+ ckpt_name:
129
+ - ^sdxl.*
130
+
131
+ - mode_type: checkpoint
132
+ base_model: SDXL
133
+ describe: SDXL
134
+ score: 3
135
+ route: /supernode/bizyair-sdxl-comfy-ksampler-v2
136
+ nodes:
137
+ - class_type: CheckpointLoaderSimple
138
+ inputs:
139
+ ckpt_name:
140
+ - ^sdxl.*
141
+
142
+ - mode_type: unet
143
+ base_model: Kolors
144
+ describe: Kolors
145
+ score: 3
146
+ route: /supernode/kolors-bizyair-sdxl-comfy-ksampler
147
+ nodes:
148
+ - class_type: MZ_KolorsUNETLoaderV2
149
+ inputs:
150
+ unet_name:
151
+ - ^kolors.*
152
+ - class_type: VAELoader
153
+ inputs:
154
+ vae_name:
155
+ - ^sdxl/sdxl_vae.safetensors*
156
+
157
+ - mode_type: pulid
158
+ base_model: FLUX
159
+ describe: Flux Pulid
160
+ score: 5
161
+ route: /supernode/bizyair-flux-dev-comfy-pulid
162
+ nodes:
163
+ - class_type: PulidFluxModelLoader
164
+ inputs:
165
+ pulid_file:
166
+ - '.*'
167
+
168
+ - mode_type: controlnet
169
+ base_model: FLUX
170
+ describe: Flux ControlNet
171
+ score: 5
172
+ route: /supernode/flux-dev-bizyair-comfy-ksampler-fp8-v2
173
+ nodes:
174
+ - class_type: ControlNetLoader
175
+ inputs:
176
+ control_net_name:
177
+ - '.*'
178
+
179
+ - mode_type: lora
180
+ base_model: FLUX
181
+ describe: Flux Lora
182
+ score: 4
183
+ route: /supernode/flux-dev-bizyair-comfy-ksampler-fp8-v2
184
+ nodes:
185
+ - class_type: LoraLoader
186
+ inputs:
187
+ lora_name:
188
+ - '.*'
189
+
190
+ - mode_type: style_model
191
+ base_model: FLUX
192
+ describe: Flux Style Model
193
+ score: 4
194
+ route: /supernode/flux-dev-bizyair-comfy-ksampler-fp8-v2
195
+ nodes:
196
+ - class_type: StyleModelLoader
197
+ inputs:
198
+ style_model_name:
199
+ - ^flux1-redux-dev.safetensors$
200
+
201
+ - mode_type: unet
202
+ base_model: FLUX1-Fill
203
+ describe: flux1-fill
204
+ score: 3
205
+ route: /supernode/bizyair-flux1-tools-fill
206
+ nodes:
207
+ - class_type: UNETLoader
208
+ inputs:
209
+ unet_name:
210
+ - ^flux/flux1-fill-dev.safetensors$
211
+
212
+ - mode_type: vae
213
+ base_model: FLUX1-Fill
214
+ describe: flux1-fill-vae
215
+ score: 1
216
+ route: /supernode/bizyair-flux1-tools-fill
217
+ nodes:
218
+ - class_type: VAELoader
219
+ inputs:
220
+ vae_name:
221
+ - ^flux.1-fill-vae.safetensors$
222
+ - mode_type: unet
223
+ base_model: FLUX1-Depth
224
+ describe: flux1-depth
225
+ score: 3
226
+ route: /supernode/bizyair-flux1-tools-depth
227
+ nodes:
228
+ - class_type: UNETLoader
229
+ inputs:
230
+ unet_name:
231
+ - ^flux/flux1-depth-dev.safetensors$
232
+
233
+ - mode_type: vae
234
+ base_model: FLUX1-Depth
235
+ describe: flux1-depth-vae
236
+ score: 1
237
+ route: /supernode/bizyair-flux1-tools-depth
238
+ nodes:
239
+ - class_type: VAELoader
240
+ inputs:
241
+ vae_name:
242
+ - ^flux.1-depth-vae.safetensors$
243
+
244
+ - mode_type: unet
245
+ base_model: FLUX1-Canny
246
+ describe: flux1-canny
247
+ score: 3
248
+ route: /supernode/bizyair-flux1-tools-canny
249
+ nodes:
250
+ - class_type: UNETLoader
251
+ inputs:
252
+ unet_name:
253
+ - ^flux/flux1-canny-dev.safetensors$
254
+
255
+ - mode_type: vae
256
+ base_model: FLUX1-Canny
257
+ describe: flux1-canny-vae
258
+ score: 1
259
+ route: /supernode/bizyair-flux1-tools-canny
260
+ nodes:
261
+ - class_type: VAELoader
262
+ inputs:
263
+ vae_name:
264
+ - ^flux.1-canny-vae.safetensors$
265
+
266
+ - mode_type: upscale_models
267
+ base_model: UPSCALE_MODEL
268
+ describe: Upscale Model
269
+ score: 1
270
+ route: /bizy_task/bizyair-flux1-dev-fp8-async
271
+ nodes:
272
+ - class_type: UpscaleModelLoader
273
+
274
+ - mode_type: upscale_model
275
+ base_model: FLUX
276
+ describe: Flux Upscale Model
277
+ score: 6
278
+ route: /bizy_task/bizyair-flux1-dev-fp8-async
279
+ nodes:
280
+ - class_type: UltimateSDUpscale
281
+ - mode_type: sams
282
+ base_model: SAM
283
+ describe: SAM
284
+ score: 1
285
+ route: /supernode/bizyair-sam
286
+ nodes:
287
+ - class_type: 'LayerMask: SegmentAnythingUltra V2'
288
+ - class_type: SAMModelLoader
289
+ - class_type: TrimapGenerate
290
+ - class_type: VITMatteModelLoader
291
+ - class_type: DetailMethodPredict
292
+ - class_type: VitMattePredict
293
+ - mode_type: trellis
294
+ base_model: trellis
295
+ describe: trellis
296
+ score: 1
297
+ route: /bizy_task/bizyair-3d-trellis
298
+ nodes:
299
+ - class_type: 'IF_TrellisCheckpointLoader'
300
+ - class_type: IF_TrellisImageTo3D
301
+ - class_type: Trans3D2GlbFile
302
+
303
+ - mode_type: Janus
304
+ base_model: Janus
305
+ describe: Janus Model
306
+ score: 1
307
+ route: /supernode/bizyair-janus-pro-7b
308
+ nodes:
309
+ - class_type: 'JanusModelLoader'
310
+
311
+ - mode_type: CogView4_6B_Pipe
312
+ base_model: CogView4_6B_Pipe
313
+ describe: CogView4_6B_Pipe
314
+ score: 1
315
+ route: /supernode/bizyair-cogview4-6b-pipe
316
+ nodes:
317
+ - class_type: CogView4_6B_Pipe
318
+
319
+
320
+ - mode_type: Wan2.1-T2V
321
+ base_model: Wan
322
+ describe: Wan
323
+ score: 1
324
+ route: /bizy_task/bizyair-dev-wan-video
325
+ nodes:
326
+ - class_type: 'Wan_Model_Loader'
327
+ inputs:
328
+ ckpt_name:
329
+ - ^Wan2.1-T2V-1.3B$
@@ -0,0 +1,20 @@
1
# Model datatypes: BizyAir-prefixed names for ComfyUI model-like values.
# https://docs.comfy.org/essentials/custom_node_datatypes#model-datatypes
MODEL = "BIZYAIR_MODEL"
CLIP = "BIZYAIR_CLIP"
VAE = "BIZYAIR_VAE"
CONDITIONING = "BIZYAIR_CONDITIONING"
CONTROL_NET = "BIZYAIR_CONTROL_NET"
UPSCALE_MODEL = "BIZYAIR_UPSCALE_MODEL"
STYLE_MODEL = "BIZYAIR_STYLE_MODEL"

# Datatypes for InstantID / face-analysis nodes.
INSTANTID = "BIZYAIR_INSTANTID"
FACEANALYSIS = "BIZYAIR_FACEANALYSIS"
12
+
13
+
14
def is_model_datatype(datatype):
    """Return True if ``datatype`` is one of the model-like BizyAir type names.

    Only MODEL, CLIP, VAE, CONDITIONING, CONTROL_NET and STYLE_MODEL qualify.
    NOTE(review): UPSCALE_MODEL, INSTANTID and FACEANALYSIS are not in the
    set; confirm whether that omission is intended.
    """
    model_types = (MODEL, CLIP, VAE, CONDITIONING, CONTROL_NET, STYLE_MODEL)
    return datatype in model_types
16
+
17
+
18
# Reference for the IMAGE / MASK / LATENT datatypes:
# https://docs.comfy.org/essentials/custom_node_images_and_masks
def is_send_request_datatype(datatype: str) -> bool:
    """Return True if ``datatype`` is in the fixed allow-list of ComfyUI types.

    The allow-list is IMAGE, LATENT, MASK, STRING, FLOAT and INT (exact,
    case-sensitive match).
    """
    allowed = frozenset(("IMAGE", "LATENT", "MASK", "STRING", "FLOAT", "INT"))
    return datatype in allowed
@@ -0,0 +1,288 @@
1
+ import base64
2
+ import io
3
+ import json
4
+ import os
5
+ import pickle
6
+ import zlib
7
+ from enum import Enum
8
+ from functools import singledispatch
9
+ from typing import Any, List, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ from PIL import Image
14
+
15
+ from .common.env_var import BIZYAIR_DEBUG
16
+
17
# String prefixes that tag encoded payloads, so that decoding can tell a
# serialized tensor or image apart from an ordinary string.
TENSOR_MARKER = "TENSOR:"
IMAGE_MARKER = "IMAGE:"
20
+
21
+
22
class TaskStatus(Enum):
    """Task lifecycle states; each member's value is its lowercase name."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
26
+
27
+
28
def convert_image_to_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` in RGB mode.

    An image that is already RGB is returned as the very same object; any
    other mode is converted with ``Image.convert("RGB")``.
    """
    return image if image.mode == "RGB" else image.convert("RGB")
32
+
33
+
34
def encode_image_to_base64(
    image: Image.Image, format: str = "png", quality: int = 100, lossless=False
) -> str:
    """Serialize a PIL image into a base64 text string.

    The image is forced to RGB first. It is then written with ``Image.save``
    using the given container ``format``. ``quality`` and ``lossless`` are
    passed to ``Image.save`` unchanged.

    Returns:
        The encoded image bytes as a base64 string (UTF-8 decoded).
    """
    rgb_image = convert_image_to_rgb(image)
    buffer = io.BytesIO()
    with buffer:
        rgb_image.save(buffer, format=format, quality=quality, lossless=lossless)
        # getvalue() returns the whole buffer regardless of stream position.
        encoded_bytes = buffer.getvalue()
        if BIZYAIR_DEBUG:
            print(f"encode_image_to_base64: {format_bytes(len(encoded_bytes))}")
    return base64.b64encode(encoded_bytes).decode("utf-8")
45
+
46
+
47
def decode_base64_to_np(img_data: str, format: str = "png") -> np.ndarray:
    """Decode a base64-encoded image into an RGB ``numpy`` array.

    ``format`` is accepted but not used. ``Image.open`` identifies the
    container from the decoded bytes.
    """
    raw_bytes = base64.b64decode(img_data)
    if BIZYAIR_DEBUG:
        print(f"decode_base64_to_np: {format_bytes(len(raw_bytes))}")
    with io.BytesIO(raw_bytes) as stream:
        # Normalise to RGB, matching ComfyUI's image loading:
        # https://github.com/comfyanonymous/ComfyUI/blob/a178e25912b01abf436eba1cfaab316ba02d272d/nodes.py#L1511
        rgb_image = Image.open(stream).convert("RGB")
        return np.array(rgb_image)
56
+
57
+
58
def decode_base64_to_image(img_data: str) -> Image.Image:
    """Decode a base64-encoded image into a fully loaded PIL image.

    ``Image.open`` is lazy: it reads only the header and defers decoding the
    pixel data until first access. The backing ``BytesIO`` is closed when the
    ``with`` block exits, so the pixels are loaded explicitly before that
    happens. Without this, the caller receives an image whose first pixel
    access fails on a closed buffer.

    Args:
        img_data: Base64 text of an encoded image file.

    Returns:
        The decoded image. Its ``format`` attribute still reports the source
        container.
    """
    img_bytes = base64.b64decode(img_data)
    with io.BytesIO(img_bytes) as input_buffer:
        img = Image.open(input_buffer)
        # Force decoding while the buffer is still open.
        img.load()
        if BIZYAIR_DEBUG:
            format_info = img.format.upper() if img.format else "Unknown"
            print(f"decode image format: {format_info}")
    return img
66
+
67
+
68
def format_bytes(num_bytes: int) -> str:
    """Render a byte count as a short human-readable string.

    Counts below 1024 are shown as whole bytes. Larger counts are scaled by
    powers of 1024 to KB or MB and shown with two decimals.

    :param num_bytes: The number of bytes to convert.
    :return: The formatted size, e.g. ``"512 B"``, ``"1.50 KB"`` or ``"3.00 MB"``.
    """
    kib = 1024
    mib = kib * kib
    if num_bytes < kib:
        return f"{num_bytes} B"
    if num_bytes < mib:
        return f"{num_bytes / kib:.2f} KB"
    return f"{num_bytes / mib:.2f} MB"
81
+
82
+
83
def _legacy_encode_comfy_image(image: torch.Tensor, image_format="png") -> str:
    """Legacy encoder: base64-encode only the first image of ``image``.

    Values are scaled by 255, clipped to ``[0, 255]`` and cast to uint8
    before being handed to PIL. This presumes float pixels in ``[0, 1]``
    (ComfyUI IMAGE layout); TODO confirm against callers.
    """
    first_image = image.cpu().detach().numpy()[0]
    pixels = np.clip(first_image * 255.0, 0, 255).astype(np.uint8)
    return encode_image_to_base64(Image.fromarray(pixels), format=image_format)
91
+
92
+
93
def _legacy_decode_comfy_image(
    img_data: Union[List, str], image_format="png"
) -> torch.Tensor:
    """Legacy decoder: turn base64 image data into a batched float tensor.

    Args:
        img_data: One base64 image string, or a list of them. Each list item
            is decoded with the legacy path and the results are concatenated
            along the batch dimension.
        image_format: Forwarded to the per-image decoder (for list items as
            well as for a single string).

    Returns:
        A float32 tensor with a leading batch dimension. Pixel values are the
        decoded uint8 RGB values divided by 255.
    """
    if isinstance(img_data, list):
        # Forward image_format so list items are decoded exactly like a
        # single string (otherwise they would get the dispatcher's default).
        decoded_imgs = [
            decode_comfy_image(x, image_format=image_format, old_version=True)
            for x in img_data
        ]
        return torch.cat(decoded_imgs, dim=0)

    pixels = decode_base64_to_np(img_data, format=image_format)
    pixels = pixels.astype(np.float32) / 255.0
    # [None,] adds the leading batch axis expected by ComfyUI.
    return torch.from_numpy(pixels)[None,]
106
+
107
+
108
def _new_encode_comfy_image(images: torch.Tensor, image_format="WEBP", **kwargs) -> str:
    """Encode every image of a batch to base64 and return them as JSON.

    See https://docs.comfy.org/essentials/custom_node_snippets#save-an-image-batch

    Args:
        images (torch.Tensor): A batch of images. Each item along the first
            dimension is scaled by 255, clipped to ``[0, 255]``, cast to
            uint8 and converted with ``Image.fromarray``.
        image_format (str, optional): The format of the images. Defaults to "WEBP".
        **kwargs: Extra options forwarded to ``encode_image_to_base64``
            (e.g. ``lossless``).

    Returns:
        str: A JSON object mapping each batch index (serialized as a string
        key by ``json.dumps``) to its base64-encoded image.
    """
    results = {}
    for batch_number, image in enumerate(images):
        # detach() first: Tensor.numpy() raises for tensors that require grad.
        # This matches the legacy encoder's cpu().detach().numpy() sequence.
        pixels = 255.0 * image.detach().cpu().numpy()
        img = Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))
        results[batch_number] = encode_image_to_base64(
            img, format=image_format, **kwargs
        )

    return json.dumps(results)
127
+
128
+
129
def _new_decode_comfy_image(img_datas: str, image_format="WEBP") -> torch.tensor:
    """Decode a JSON batch of base64 images back into a single tensor.

    Args:
        img_datas (str): A JSON object whose values are base64-encoded
            images, decoded in the object's key order.
        image_format (str, optional): Forwarded to ``decode_base64_to_np``.
            Defaults to "WEBP".

    Returns:
        torch.Tensor: The images concatenated along a leading batch
        dimension, as float32 values (uint8 pixel / 255).
    """
    payload = json.loads(img_datas)

    def _to_batched_tensor(encoded):
        # uint8 RGB array -> float32 scaled by 1/255, with a batch axis of 1.
        pixels = np.array(decode_base64_to_np(encoded, format=image_format))
        return torch.from_numpy(pixels.astype(np.float32) / 255.0)[None,]

    batch = [_to_batched_tensor(encoded) for encoded in payload.values()]
    return torch.cat(batch, dim=0)
149
+
150
+
151
def encode_comfy_image(
    image: torch.Tensor, image_format="WEBP", old_version=False, lossless=False
) -> str:
    """Encode an image tensor with either the batch or the legacy encoder.

    By default every image in the batch is encoded to a JSON string and
    ``lossless`` is forwarded to the image encoder. With ``old_version=True``
    the legacy encoder is used instead. It encodes only the first image and
    does not receive ``lossless``.
    """
    if not old_version:
        return _new_encode_comfy_image(image, image_format, lossless=lossless)
    return _legacy_encode_comfy_image(image, image_format)
157
+
158
+
159
def decode_comfy_image(
    img_data: Union[List, str], image_format="WEBP", old_version=False
) -> torch.tensor:
    """Decode base64 image data into a tensor using the matching decoder.

    With ``old_version=True`` the legacy decoder accepts one base64 string or
    a list of them. Otherwise ``img_data`` is parsed as the JSON batch string
    of the new format.
    """
    decoder = _legacy_decode_comfy_image if old_version else _new_decode_comfy_image
    return decoder(img_data, image_format)
165
+
166
+
167
def tensor_to_base64(tensor: torch.Tensor, compress=True) -> str:
    """Serialize a tensor as base64 text.

    The pipeline is: CPU numpy array -> pickle -> optional zlib -> base64.

    Args:
        tensor: The tensor to serialize. It is moved to CPU and detached first.
        compress: When true (the default), zlib-compress the pickled bytes.

    Returns:
        The base64 payload as a UTF-8 string.
    """
    payload = pickle.dumps(tensor.cpu().detach().numpy())
    if compress:
        payload = zlib.compress(payload)
    return base64.b64encode(payload).decode("utf-8")
176
+
177
+
178
def base64_to_tensor(tensor_b64: str, compress=True) -> torch.Tensor:
    """Rebuild a tensor from base64 text produced by pickling a numpy array.

    The pipeline is: base64 -> optional zlib decompress -> unpickle ->
    ``torch.from_numpy``.

    NOTE(review): ``pickle.loads`` can execute arbitrary code. If
    ``tensor_b64`` can originate from a remote or otherwise untrusted source,
    this is unsafe. Consider a non-executable format such as
    ``np.load(..., allow_pickle=False)``.
    """
    raw_bytes = base64.b64decode(tensor_b64)
    pickled = zlib.decompress(raw_bytes) if compress else raw_bytes
    return torch.from_numpy(pickle.loads(pickled))
188
+
189
+
190
@singledispatch
def decode_data(input, old_version=False):
    """Decode ``input`` by dispatching on its runtime type.

    This base implementation handles every type without a registered
    handler.

    Raises:
        NotImplementedError: Always, naming the unsupported type.
    """
    unsupported = type(input)
    raise NotImplementedError(f"Unsupported type: {unsupported}")
193
+
194
+
195
@decode_data.register(type(None))
@decode_data.register(bool)
@decode_data.register(float)
@decode_data.register(int)
def _(input, **kwargs):
    """Return ints, floats, bools and None unchanged (nothing to decode)."""
    return input
201
+
202
+
203
@decode_data.register(dict)
def _(input, **kwargs):
    """Decode every value of a dict recursively, keeping keys and order."""
    decoded = {}
    for key, value in input.items():
        decoded[key] = decode_data(value, **kwargs)
    return decoded
206
+
207
+
208
@decode_data.register(list)
def _(input, **kwargs):
    """Decode each list element recursively, preserving order."""
    decoded = []
    for item in input:
        decoded.append(decode_data(item, **kwargs))
    return decoded
211
+
212
+
213
@decode_data.register(str)
def _(input: str, **kwargs):
    """Decode a marker-prefixed string and return plain strings unchanged.

    ``TENSOR_MARKER`` payloads are decoded with ``base64_to_tensor``.
    ``IMAGE_MARKER`` payloads are decoded with ``decode_comfy_image``, which
    receives the ``old_version`` keyword (default False).
    """
    if not input.startswith((TENSOR_MARKER, IMAGE_MARKER)):
        return input
    if input.startswith(TENSOR_MARKER):
        return base64_to_tensor(input[len(TENSOR_MARKER) :])
    image_payload = input[len(IMAGE_MARKER) :]
    return decode_comfy_image(
        image_payload, old_version=kwargs.get("old_version", False)
    )
223
+
224
+
225
@singledispatch
def encode_data(output, disable_image_marker=False, old_version=False):
    """Encode ``output`` by dispatching on its runtime type.

    This base implementation handles every type without a registered
    handler.

    Raises:
        NotImplementedError: Always, naming the unsupported type.
    """
    raise NotImplementedError("Unsupported type: " + str(type(output)))
228
+
229
+
230
@encode_data.register(dict)
def _(output, **kwargs):
    """Encode every value of a dict recursively, keeping keys and order."""
    encoded = {}
    for key, value in output.items():
        encoded[key] = encode_data(value, **kwargs)
    return encoded
233
+
234
+
235
@encode_data.register(list)
def _(output, **kwargs):
    """Encode each list element recursively, preserving order."""
    encoded = []
    for item in output:
        encoded.append(encode_data(item, **kwargs))
    return encoded
238
+
239
+
240
def is_image_tensor(tensor) -> bool:
    """Check whether ``tensor`` has the shape of a ComfyUI IMAGE batch.

    See https://docs.comfy.org/essentials/custom_node_datatypes#image

    `Args`:
        tensor: Any object. Non-tensors are rejected.

    `Returns`:
        bool: True if ``tensor`` is a ``torch.Tensor`` of rank 4 (``[B, H, W, C]``)
        whose last dimension ``C`` is 3, False otherwise.
    """
    if not isinstance(tensor, torch.Tensor):
        return False
    try:
        shape = tuple(tensor.shape)
    except Exception:
        # Best-effort: a tensor-like object that cannot report its shape is
        # treated as "not an image". Unlike the bare ``except:`` used before,
        # this does not swallow KeyboardInterrupt or SystemExit.
        return False
    return len(shape) == 4 and shape[-1] == 3
265
+
266
+
267
@encode_data.register(torch.Tensor)
def _(output, **kwargs):
    """Encode a tensor as an ``IMAGE_MARKER``- or ``TENSOR_MARKER``-prefixed string.

    IMAGE-shaped tensors (per ``is_image_tensor``) are WEBP-encoded via
    ``encode_comfy_image`` unless ``disable_image_marker`` is truthy. In that
    path ``lossless`` defaults to True. All other tensors are serialized with
    ``tensor_to_base64``.
    """
    as_image = is_image_tensor(output) and not kwargs.get("disable_image_marker", False)
    if not as_image:
        return TENSOR_MARKER + tensor_to_base64(output)
    encoded_image = encode_comfy_image(
        output,
        image_format="WEBP",
        old_version=kwargs.get("old_version", False),
        lossless=kwargs.get("lossless", True),
    )
    return IMAGE_MARKER + encoded_image
276
+
277
+
278
@encode_data.register(type(None))
@encode_data.register(bool)
@encode_data.register(float)
@encode_data.register(int)
def _(output, **kwargs):
    """Return ints, floats, bools and None unchanged (nothing to encode)."""
    return output
284
+
285
+
286
@encode_data.register(str)
def _(output, **kwargs):
    """Return strings unchanged.

    NOTE(review): strings are not escaped. A plain string that happens to
    start with ``TENSOR_MARKER`` or ``IMAGE_MARKER`` would be treated as an
    encoded payload by the string decoder. Confirm this cannot occur.
    """
    unchanged = output
    return unchanged