flaxdiff 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +66 -32
- {flaxdiff-0.1.18.dist-info → flaxdiff-0.1.20.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.18.dist-info → flaxdiff-0.1.20.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.18.dist-info → flaxdiff-0.1.20.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.18.dist-info → flaxdiff-0.1.20.dist-info}/top_level.txt +0 -0
    
        flaxdiff/data/online_loader.py
    CHANGED
    
    | @@ -25,7 +25,6 @@ import cv2 | |
| 25 25 | 
             
            USER_AGENT = get_datasets_user_agent()
         | 
| 26 26 |  | 
| 27 27 | 
             
            data_queue = Queue(16*2000)
         | 
| 28 | 
            -
            error_queue = Queue()
         | 
| 29 28 |  | 
| 30 29 |  | 
| 31 30 | 
             
            def fetch_single_image(image_url, timeout=None, retries=0):
         | 
| @@ -44,7 +43,7 @@ def fetch_single_image(image_url, timeout=None, retries=0): | |
| 44 43 | 
             
                return image
         | 
| 45 44 |  | 
| 46 45 |  | 
| 47 | 
            -
            def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
         | 
| 46 | 
            +
            def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC):
         | 
| 48 47 | 
             
                image = A.longest_max_size(image, max(
         | 
| 49 48 | 
             
                    image_shape), interpolation=interpolation)
         | 
| 50 49 | 
             
                image = A.pad(
         | 
| @@ -60,9 +59,10 @@ def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4 | |
| 60 59 | 
             
            def map_sample(
         | 
| 61 60 | 
             
                url, caption,
         | 
| 62 61 | 
             
                image_shape=(256, 256),
         | 
| 62 | 
            +
                min_image_shape=(128, 128),
         | 
| 63 63 | 
             
                timeout=15,
         | 
| 64 64 | 
             
                retries=3,
         | 
| 65 | 
            -
                upscale_interpolation=cv2. | 
| 65 | 
            +
                upscale_interpolation=cv2.INTER_CUBIC,
         | 
| 66 66 | 
             
                downscale_interpolation=cv2.INTER_AREA,
         | 
| 67 67 | 
             
                image_processor=default_image_processor,
         | 
| 68 68 | 
             
            ):
         | 
| @@ -75,10 +75,10 @@ def map_sample( | |
| 75 75 | 
             
                    image = np.array(image)
         | 
| 76 76 | 
             
                    original_height, original_width = image.shape[:2]
         | 
| 77 77 | 
             
                    # check if the image is too small
         | 
| 78 | 
            -
                    if min(original_height, original_width) < min( | 
| 78 | 
            +
                    if min(original_height, original_width) < min(min_image_shape):
         | 
| 79 79 | 
             
                        return
         | 
| 80 80 | 
             
                    # check if wrong aspect ratio
         | 
| 81 | 
            -
                    if max(original_height, original_width) / min(original_height, original_width) > 2:
         | 
| 81 | 
            +
                    if max(original_height, original_width) / min(original_height, original_width) > 2.4:
         | 
| 82 82 | 
             
                        return
         | 
| 83 83 | 
             
                    # check if the variance is too low
         | 
| 84 84 | 
             
                    if np.std(image) < 1e-4:
         | 
| @@ -98,30 +98,48 @@ def map_sample( | |
| 98 98 | 
             
                        "original_width": original_width,
         | 
| 99 99 | 
             
                    })
         | 
| 100 100 | 
             
                except Exception as e:
         | 
| 101 | 
            -
                    error_queue.put_nowait({
         | 
| 102 | 
            -
             | 
| 103 | 
            -
             | 
| 104 | 
            -
             | 
| 105 | 
            -
                    })
         | 
| 106 | 
            -
             | 
| 107 | 
            -
             | 
| 108 | 
            -
             | 
| 101 | 
            +
                    # error_queue.put_nowait({
         | 
| 102 | 
            +
                    #     "url": url,
         | 
| 103 | 
            +
                    #     "caption": caption,
         | 
| 104 | 
            +
                    #     "error": str(e)
         | 
| 105 | 
            +
                    # })
         | 
| 106 | 
            +
                    pass
         | 
| 107 | 
            +
             | 
| 108 | 
            +
             | 
| 109 | 
            +
            def map_batch(
         | 
| 110 | 
            +
                batch, num_threads=256, image_shape=(256, 256), 
         | 
| 111 | 
            +
                min_image_shape=(128, 128),
         | 
| 112 | 
            +
                timeout=15, retries=3, image_processor=default_image_processor,
         | 
| 113 | 
            +
                upscale_interpolation=cv2.INTER_CUBIC,
         | 
| 114 | 
            +
                downscale_interpolation=cv2.INTER_AREA,
         | 
| 115 | 
            +
            ):
         | 
| 109 116 | 
             
                try:
         | 
| 110 | 
            -
                    map_sample_fn = partial(map_sample, image_shape=image_shape,
         | 
| 111 | 
            -
                                            timeout=timeout, retries=retries, image_processor=image_processor | 
| 117 | 
            +
                    map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
         | 
| 118 | 
            +
                                            timeout=timeout, retries=retries, image_processor=image_processor,
         | 
| 119 | 
            +
                                            upscale_interpolation=upscale_interpolation,
         | 
| 120 | 
            +
                                            downscale_interpolation=downscale_interpolation)
         | 
| 112 121 | 
             
                    with ThreadPoolExecutor(max_workers=num_threads) as executor:
         | 
| 113 122 | 
             
                        executor.map(map_sample_fn, batch["url"], batch['caption'])
         | 
| 114 123 | 
             
                except Exception as e:
         | 
| 115 | 
            -
                    error_queue. | 
| 116 | 
            -
             | 
| 117 | 
            -
             | 
| 118 | 
            -
                    })
         | 
| 119 | 
            -
             | 
| 120 | 
            -
             | 
| 121 | 
            -
             | 
| 122 | 
            -
             | 
| 123 | 
            -
                 | 
| 124 | 
            -
             | 
| 124 | 
            +
                    # error_queue.put_nowait({
         | 
| 125 | 
            +
                    #     "batch": batch,
         | 
| 126 | 
            +
                    #     "error": str(e)
         | 
| 127 | 
            +
                    # })
         | 
| 128 | 
            +
                    pass
         | 
| 129 | 
            +
             | 
| 130 | 
            +
             | 
| 131 | 
            +
            def parallel_image_loader(
         | 
| 132 | 
            +
                dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), 
         | 
| 133 | 
            +
                min_image_shape=(128, 128),
         | 
| 134 | 
            +
                num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
         | 
| 135 | 
            +
                upscale_interpolation=cv2.INTER_CUBIC,
         | 
| 136 | 
            +
                downscale_interpolation=cv2.INTER_AREA,
         | 
| 137 | 
            +
            ):
         | 
| 138 | 
            +
                map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape, 
         | 
| 139 | 
            +
                                       min_image_shape=min_image_shape,
         | 
| 140 | 
            +
                                       timeout=timeout, retries=retries, image_processor=image_processor,
         | 
| 141 | 
            +
                                       upscale_interpolation=upscale_interpolation,
         | 
| 142 | 
            +
                                       downscale_interpolation=downscale_interpolation)
         | 
| 125 143 | 
             
                shard_len = len(dataset) // num_workers
         | 
| 126 144 | 
             
                print(f"Local Shard lengths: {shard_len}")
         | 
| 127 145 | 
             
                with multiprocessing.Pool(num_workers) as pool:
         | 
| @@ -136,19 +154,29 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(2 | |
| 136 154 | 
             
                        print(f"Shuffling dataset with seed {iteration}")
         | 
| 137 155 | 
             
                        dataset = dataset.shuffle(seed=iteration)
         | 
| 138 156 | 
             
                        # Clear the error queue
         | 
| 139 | 
            -
                        while not error_queue.empty():
         | 
| 140 | 
            -
             | 
| 157 | 
            +
                        # while not error_queue.empty():
         | 
| 158 | 
            +
                        #     error_queue.get_nowait()
         | 
| 141 159 |  | 
| 142 160 |  | 
| 143 161 | 
             
            class ImageBatchIterator:
         | 
| 144 | 
            -
                def __init__( | 
| 145 | 
            -
             | 
| 162 | 
            +
                def __init__(
         | 
| 163 | 
            +
                    self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), 
         | 
| 164 | 
            +
                    min_image_shape=(128, 128),
         | 
| 165 | 
            +
                    num_workers: int = 8, num_threads=256, timeout=15, retries=3, 
         | 
| 166 | 
            +
                    image_processor=default_image_processor,
         | 
| 167 | 
            +
                    upscale_interpolation=cv2.INTER_CUBIC,
         | 
| 168 | 
            +
                    downscale_interpolation=cv2.INTER_AREA,
         | 
| 169 | 
            +
                ):
         | 
| 146 170 | 
             
                    self.dataset = dataset
         | 
| 147 171 | 
             
                    self.num_workers = num_workers
         | 
| 148 172 | 
             
                    self.batch_size = batch_size
         | 
| 149 173 | 
             
                    loader = partial(parallel_image_loader, num_threads=num_threads,
         | 
| 150 | 
            -
                                     image_shape=image_shape, | 
| 151 | 
            -
                                      | 
| 174 | 
            +
                                     image_shape=image_shape,
         | 
| 175 | 
            +
                                     min_image_shape=min_image_shape, 
         | 
| 176 | 
            +
                                     num_workers=num_workers, 
         | 
| 177 | 
            +
                                     timeout=timeout, retries=retries, image_processor=image_processor,
         | 
| 178 | 
            +
                                     upscale_interpolation=upscale_interpolation,
         | 
| 179 | 
            +
                                     downscale_interpolation=downscale_interpolation)
         | 
| 152 180 | 
             
                    self.thread = threading.Thread(target=loader, args=(dataset,))
         | 
| 153 181 | 
             
                    self.thread.start()
         | 
| 154 182 |  | 
| @@ -195,6 +223,7 @@ class OnlineStreamingDataLoader(): | |
| 195 223 | 
             
                    dataset,
         | 
| 196 224 | 
             
                    batch_size=64,
         | 
| 197 225 | 
             
                    image_shape=(256, 256),
         | 
| 226 | 
            +
                    min_image_shape=(128, 128),
         | 
| 198 227 | 
             
                    num_workers=16,
         | 
| 199 228 | 
             
                    num_threads=512,
         | 
| 200 229 | 
             
                    default_split="all",
         | 
| @@ -210,6 +239,8 @@ class OnlineStreamingDataLoader(): | |
| 210 239 | 
             
                    timeout=15,
         | 
| 211 240 | 
             
                    retries=3,
         | 
| 212 241 | 
             
                    image_processor=default_image_processor,
         | 
| 242 | 
            +
                    upscale_interpolation=cv2.INTER_CUBIC,
         | 
| 243 | 
            +
                    downscale_interpolation=cv2.INTER_AREA,
         | 
| 213 244 | 
             
                ):
         | 
| 214 245 | 
             
                    if isinstance(dataset, str):
         | 
| 215 246 | 
             
                        dataset_path = dataset
         | 
| @@ -231,8 +262,11 @@ class OnlineStreamingDataLoader(): | |
| 231 262 | 
             
                        num_shards=global_process_count, index=global_process_index)
         | 
| 232 263 | 
             
                    print(f"Dataset length: {len(dataset)}")
         | 
| 233 264 | 
             
                    self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
         | 
| 265 | 
            +
                                                       min_image_shape=min_image_shape,
         | 
| 234 266 | 
             
                                                       num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
         | 
| 235 | 
            -
             | 
| 267 | 
            +
                                                        timeout=timeout, retries=retries, image_processor=image_processor,
         | 
| 268 | 
            +
                                                         upscale_interpolation=upscale_interpolation,
         | 
| 269 | 
            +
                                                         downscale_interpolation=downscale_interpolation)
         | 
| 236 270 | 
             
                    self.batch_size = batch_size
         | 
| 237 271 |  | 
| 238 272 | 
             
                    # Launch a thread to load batches in the background
         | 
| @@ -1,7 +1,7 @@ | |
| 1 1 | 
             
            flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 2 2 | 
             
            flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
         | 
| 3 3 | 
             
            flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
         | 
| 4 | 
            -
            flaxdiff/data/online_loader.py,sha256= | 
| 4 | 
            +
            flaxdiff/data/online_loader.py,sha256=XVT_kT7v9CQVaQgunTL48KxgPgwQ-bhIi8RN-Q1qbYc,10451
         | 
| 5 5 | 
             
            flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
         | 
| 6 6 | 
             
            flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
         | 
| 7 7 | 
             
            flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
         | 
| @@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk, | |
| 34 34 | 
             
            flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
         | 
| 35 35 | 
             
            flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
         | 
| 36 36 | 
             
            flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
         | 
| 37 | 
            -
            flaxdiff-0.1. | 
| 38 | 
            -
            flaxdiff-0.1. | 
| 39 | 
            -
            flaxdiff-0.1. | 
| 40 | 
            -
            flaxdiff-0.1. | 
| 37 | 
            +
            flaxdiff-0.1.20.dist-info/METADATA,sha256=ls0rUYnHBWdChfQ7meO2nlHSqGVEPn2JzZTOTagt2H8,22083
         | 
| 38 | 
            +
            flaxdiff-0.1.20.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
         | 
| 39 | 
            +
            flaxdiff-0.1.20.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
         | 
| 40 | 
            +
            flaxdiff-0.1.20.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         |