flaxdiff 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +92 -46
- {flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/top_level.txt +0 -0
    
flaxdiff/data/online_loader.py
CHANGED

@@ -13,7 +13,7 @@ from typing import Any, Dict, List, Tuple
 import numpy as np
 from functools import partial
 
-from datasets import load_dataset, concatenate_datasets, Dataset
+from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
 from datasets.utils.file_utils import get_datasets_user_agent
 from concurrent.futures import ThreadPoolExecutor
 import io
@@ -43,17 +43,35 @@ def fetch_single_image(image_url, timeout=None, retries=0):
             image = None
     return image
 
+
+def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
+    image = A.longest_max_size(image, max(
+        image_shape), interpolation=interpolation)
+    image = A.pad(
+        image,
+        min_height=image_shape[0],
+        min_width=image_shape[1],
+        border_mode=cv2.BORDER_CONSTANT,
+        value=[255, 255, 255],
+    )
+    return image
+
+
 def map_sample(
-    url, caption,
+    url, caption,
     image_shape=(256, 256),
+    timeout=15,
+    retries=3,
     upscale_interpolation=cv2.INTER_LANCZOS4,
     downscale_interpolation=cv2.INTER_AREA,
+    image_processor=default_image_processor,
 ):
     try:
-        …
+        # Assuming fetch_single_image is defined elsewhere
+        image = fetch_single_image(url, timeout=timeout, retries=retries)
         if image is None:
             return
-
+
         image = np.array(image)
         original_height, original_width = image.shape[:2]
         # check if the image is too small
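The new default_image_processor (added in the hunk above) letterboxes each image: resize so the longest side fits max(image_shape) via A.longest_max_size, then pad to the target shape with white borders. Since map_sample now receives this function through its image_processor parameter, callers can substitute their own transform. A minimal sketch of an alternative with the same (image, image_shape, interpolation=...) signature; the center-crop strategy and the name center_crop_processor are illustrative, not part of flaxdiff:

import cv2

def center_crop_processor(image, image_shape, interpolation=cv2.INTER_AREA):
    # Resize so the SHORT side reaches the target, then center-crop;
    # contrast with the default, which letterboxes onto a white canvas.
    h, w = image.shape[:2]
    scale = max(image_shape) / min(h, w)
    image = cv2.resize(image, (int(w * scale), int(h * scale)),
                       interpolation=interpolation)
    h, w = image.shape[:2]
    top = (h - image_shape[0]) // 2
    left = (w - image_shape[1]) // 2
    return image[top:top + image_shape[0], left:left + image_shape[1]]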
@@ -68,14 +86,10 @@ def map_sample(
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         downscale = max(original_width, original_height) > max(image_shape)
         interpolation = downscale_interpolation if downscale else upscale_interpolation
-        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
-        image = A.pad(
-            image,
-            min_height=image_shape[0],
-            min_width=image_shape[1],
-            border_mode=cv2.BORDER_CONSTANT,
-            value=[255, 255, 255],
-        )
+
+        image = image_processor(
+            image, image_shape, interpolation=interpolation)
+
         data_queue.put({
             "url": url,
             "caption": caption,
@@ -89,49 +103,69 @@ def map_sample(
             "caption": caption,
             "error": str(e)
         })
-…
+
+
+def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3, image_processor=default_image_processor):
+    try:
+        map_sample_fn = partial(map_sample, image_shape=image_shape,
+                                timeout=timeout, retries=retries, image_processor=image_processor)
+        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            executor.map(map_sample_fn, batch["url"], batch['caption'])
+    except Exception as e:
+        error_queue.put({
+            "batch": batch,
+            "error": str(e)
+        })
+
+
+def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
+                          num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
+    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
+                           timeout=timeout, retries=retries, image_processor=image_processor)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
         iteration = 0
         while True:
             # Repeat forever
-            shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
-
+            shards = [dataset[i*shard_len:(i+1)*shard_len]
+                      for i in range(num_workers)]
+            print(f"mapping {len(shards)} shards")
             pool.map(map_batch_fn, shards)
             iteration += 1
-
+            print(f"Shuffling dataset with seed {iteration}")
+            dataset = dataset.shuffle(seed=iteration)
+
+
 class ImageBatchIterator:
-    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), …
+    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
+                 num_workers: int = 8, num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
-        loader = partial(parallel_image_loader, num_threads=num_threads, …
+        loader = partial(parallel_image_loader, num_threads=num_threads,
+                         image_shape=image_shape, num_workers=num_workers,
+                         timeout=timeout, retries=retries, image_processor=image_processor)
         self.thread = threading.Thread(target=loader, args=(dataset,))
         self.thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
         def fetcher(_):
             return data_queue.get()
         with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
             batch = list(executor.map(fetcher, range(self.batch_size)))
         return batch
-
+
     def __del__(self):
         self.thread.join()
-
+
     def __len__(self):
         return len(self.dataset) // self.batch_size
 
+
 def default_collate(batch):
     urls = [sample["url"] for sample in batch]
     captions = [sample["caption"] for sample in batch]
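This hunk threads the new knobs through the whole producer pipeline: parallel_image_loader slices the dataset into num_workers shards and maps them over a multiprocessing.Pool, each map_batch worker fans its shard out to a ThreadPoolExecutor of num_threads fetchers, and every successfully processed sample lands in the module-level data_queue that ImageBatchIterator.__next__ drains batch_size items at a time. The data_queue and error_queue globals are declared outside this diff; for pool workers to share them they would have to be process-safe queues, roughly as sketched here (an assumption about code not shown, which also presumes fork-style process inheritance):

import multiprocessing

# Assumed module-level plumbing (not visible in this diff):
data_queue = multiprocessing.Queue()   # map_sample pushes processed samples
error_queue = multiprocessing.Queue()  # map_sample / map_batch push failures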
@@ -141,7 +175,8 @@ def default_collate(batch):
         "caption": captions,
         "image": images,
     }
-
+
+
 def dataMapper(map: Dict[str, Any]):
     def _map(sample) -> Dict[str, Any]:
         return {
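default_collate, shown in context above, keeps urls, captions, and images as plain Python lists. Because every image leaving the processor has been padded to the same image_shape, a stacking collate is a natural drop-in for training loops that want arrays; this numpy variant is a sketch, not part of the package:

import numpy as np

def numpy_collate(batch):
    # Stack images into one (batch, H, W, C) array; safe because the
    # image processor maps every sample to the same image_shape.
    return {
        "url": [sample["url"] for sample in batch],
        "caption": [sample["caption"] for sample in batch],
        "image": np.stack([sample["image"] for sample in batch]),
    }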
@@ -150,16 +185,17 @@ def dataMapper(map: Dict[str, Any]):
         }
     return _map
 
+
 class OnlineStreamingDataLoader():
     def __init__(
-        self,
-        dataset,
-        batch_size=64,
+        self,
+        dataset,
+        batch_size=64,
         image_shape=(256, 256),
-        num_workers=16,
+        num_workers=16,
         num_threads=512,
         default_split="all",
-        pre_map_maker=dataMapper,
+        pre_map_maker=dataMapper,
         pre_map_def={
             "url": "URL",
             "caption": "TEXT",
@@ -168,41 +204,51 @@ class OnlineStreamingDataLoader():
         global_process_index=0,
         prefetch=1000,
         collate_fn=default_collate,
+        timeout=15,
+        retries=3,
+        image_processor=default_image_processor,
     ):
         if isinstance(dataset, str):
             dataset_path = dataset
             print("Loading dataset from path")
-            dataset = load_dataset(dataset_path, split=default_split)
+            if "gs://" in dataset:
+                dataset = load_from_disk(dataset_path)
+            else:
+                dataset = load_dataset(dataset_path, split=default_split)
         elif isinstance(dataset, list):
             if isinstance(dataset[0], str):
                 print("Loading multiple datasets from paths")
-                dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
+                dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
+                    dataset_path, split=default_split) for dataset_path in dataset]
             print("Concatenating multiple datasets")
             dataset = concatenate_datasets(dataset)
-
-        self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
+            dataset = dataset.shuffle(seed=0)
+        # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
+        self.dataset = dataset.shard(
+            num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, …
+        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
+                                           timeout=timeout, retries=retries, image_processor=image_processor)
         self.collate_fn = collate_fn
         self.batch_size = batch_size
-
+
         # Launch a thread to load batches in the background
         self.batch_queue = queue.Queue(prefetch)
-
+
         def batch_loader():
             for batch in self.iterator:
                 self.batch_queue.put(batch)
-
+
         self.loader_thread = threading.Thread(target=batch_loader)
         self.loader_thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
         return self.collate_fn(self.batch_queue.get())
         # return self.collate_fn(next(self.iterator))
-
+
     def __len__(self):
         return len(self.dataset)
-
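Taken together, 0.1.17 surfaces three constructor-level controls that 0.1.15 lacked: per-fetch timeout and retries, a pluggable image_processor, and gs:// dataset paths routed through load_from_disk instead of load_dataset. A hypothetical end-to-end usage; the dataset name and sizes are placeholders:

from flaxdiff.data.online_loader import (
    OnlineStreamingDataLoader, default_image_processor)

loader = OnlineStreamingDataLoader(
    "laion/laion-coco",        # placeholder; a "gs://..." path loads via load_from_disk
    batch_size=64,
    image_shape=(256, 256),
    num_workers=16,
    num_threads=512,
    timeout=15,                # new in 0.1.17: per-image fetch timeout (seconds)
    retries=3,                 # new in 0.1.17: fetch retries before skipping a sample
    image_processor=default_image_processor,
)

batch = next(loader)           # {"url": [...], "caption": [...], "image": [...]}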
{flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/RECORD
CHANGED

@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
-flaxdiff/data/online_loader.py,sha256=…
+flaxdiff/data/online_loader.py,sha256=BM4Le-4BUo8MJpRzGIA2nMHKm4-WynQ2BOdiQz0JCDs,8791
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.15.dist-info/METADATA,sha256=…
-flaxdiff-0.1.15.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-flaxdiff-0.1.15.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.15.dist-info/RECORD,,
+flaxdiff-0.1.17.dist-info/METADATA,sha256=2Nr_T2yg3XHFt2jBuUXo8FxLYM8si-DBLdW_PBKxzc4,22083
+flaxdiff-0.1.17.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.17.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.17.dist-info/RECORD,,
{flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/WHEEL
File without changes

{flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/top_level.txt
File without changes