flaxdiff 0.1.21__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +33 -20
- {flaxdiff-0.1.21.dist-info → flaxdiff-0.1.22.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.21.dist-info → flaxdiff-0.1.22.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.21.dist-info → flaxdiff-0.1.22.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.21.dist-info → flaxdiff-0.1.22.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -21,6 +21,7 @@ import urllib
|
|
21
21
|
|
22
22
|
import PIL.Image
|
23
23
|
import cv2
|
24
|
+
import traceback
|
24
25
|
|
25
26
|
USER_AGENT = get_datasets_user_agent()
|
26
27
|
|
@@ -43,7 +44,27 @@ def fetch_single_image(image_url, timeout=None, retries=0):
|
|
43
44
|
return image
|
44
45
|
|
45
46
|
|
46
|
-
def default_image_processor(
|
47
|
+
def default_image_processor(
|
48
|
+
image, image_shape,
|
49
|
+
min_image_shape=(128, 128),
|
50
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
51
|
+
downscale_interpolation=cv2.INTER_AREA,
|
52
|
+
):
|
53
|
+
image = np.array(image)
|
54
|
+
original_height, original_width = image.shape[:2]
|
55
|
+
# check if the image is too small
|
56
|
+
if min(original_height, original_width) < min(min_image_shape):
|
57
|
+
return None, original_height, original_width
|
58
|
+
# check if wrong aspect ratio
|
59
|
+
if max(original_height, original_width) / min(original_height, original_width) > 2.4:
|
60
|
+
return None, original_height, original_width
|
61
|
+
# check if the variance is too low
|
62
|
+
if np.std(image) < 1e-5:
|
63
|
+
return None, original_height, original_width
|
64
|
+
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
65
|
+
downscale = max(original_width, original_height) > max(image_shape)
|
66
|
+
interpolation = downscale_interpolation if downscale else upscale_interpolation
|
67
|
+
|
47
68
|
image = A.longest_max_size(image, max(
|
48
69
|
image_shape), interpolation=interpolation)
|
49
70
|
image = A.pad(
|
@@ -53,7 +74,7 @@ def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC):
|
|
53
74
|
border_mode=cv2.BORDER_CONSTANT,
|
54
75
|
value=[255, 255, 255],
|
55
76
|
)
|
56
|
-
return image
|
77
|
+
return image, original_height, original_width
|
57
78
|
|
58
79
|
|
59
80
|
def map_sample(
|
@@ -72,23 +93,13 @@ def map_sample(
|
|
72
93
|
if image is None:
|
73
94
|
return
|
74
95
|
|
75
|
-
image =
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
if max(original_height, original_width) / min(original_height, original_width) > 2.4:
|
82
|
-
return
|
83
|
-
# check if the variance is too low
|
84
|
-
if np.std(image) < 1e-4:
|
96
|
+
image, original_height, original_width = image_processor(
|
97
|
+
image, image_shape, min_image_shape=min_image_shape,
|
98
|
+
upscale_interpolation=upscale_interpolation,
|
99
|
+
downscale_interpolation=downscale_interpolation,)
|
100
|
+
|
101
|
+
if image is None:
|
85
102
|
return
|
86
|
-
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
87
|
-
downscale = max(original_width, original_height) > max(image_shape)
|
88
|
-
interpolation = downscale_interpolation if downscale else upscale_interpolation
|
89
|
-
|
90
|
-
image = image_processor(
|
91
|
-
image, image_shape, interpolation=interpolation)
|
92
103
|
|
93
104
|
data_queue.put({
|
94
105
|
"url": url,
|
@@ -98,7 +109,8 @@ def map_sample(
|
|
98
109
|
"original_width": original_width,
|
99
110
|
})
|
100
111
|
except Exception as e:
|
101
|
-
print(f"Error
|
112
|
+
print(f"Error maping sample {url}", e)
|
113
|
+
traceback.print_exc()
|
102
114
|
# error_queue.put_nowait({
|
103
115
|
# "url": url,
|
104
116
|
# "caption": caption,
|
@@ -122,7 +134,8 @@ def map_batch(
|
|
122
134
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
123
135
|
executor.map(map_sample_fn, batch["url"], batch['caption'])
|
124
136
|
except Exception as e:
|
125
|
-
print(f"Error
|
137
|
+
print(f"Error maping batch", e)
|
138
|
+
traceback.print_exc()
|
126
139
|
# error_queue.put_nowait({
|
127
140
|
# "batch": batch,
|
128
141
|
# "error": str(e)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=LIK_O1C3yDPvvAEOWvsJrVeBopVqjg2IOMTbiSIvH6M,11025
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
7
|
flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.22.dist-info/METADATA,sha256=Sv8OtwO7oEcPUc7GfytcqLNC6GP8ZFA7BZ4-X5QqUj8,22083
|
38
|
+
flaxdiff-0.1.22.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
39
|
+
flaxdiff-0.1.22.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.22.dist-info/RECORD,,
|
File without changes
|
File without changes
|