flaxdiff 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +1 -0
- flaxdiff/data/online_loader.py +208 -0
- flaxdiff/models/autoencoder/diffusers.py +3 -3
- {flaxdiff-0.1.12.dist-info → flaxdiff-0.1.14.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.12.dist-info → flaxdiff-0.1.14.dist-info}/RECORD +7 -5
- {flaxdiff-0.1.12.dist-info → flaxdiff-0.1.14.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.12.dist-info → flaxdiff-0.1.14.dist-info}/top_level.txt +0 -0
| @@ -0,0 +1 @@ | |
| 1 | 
            +
            from .online_loader import OnlineStreamingDataLoader
         | 
| @@ -0,0 +1,208 @@ | |
| 1 | 
            +
            import multiprocessing
         | 
| 2 | 
            +
            import threading
         | 
| 3 | 
            +
            from multiprocessing import Queue
         | 
| 4 | 
            +
            # from arrayqueues.shared_arrays import ArrayQueue
         | 
| 5 | 
            +
            # from faster_fifo import Queue
         | 
| 6 | 
            +
            import time
         | 
| 7 | 
            +
            import albumentations as A
         | 
| 8 | 
            +
            import queue
         | 
| 9 | 
            +
            import cv2
         | 
| 10 | 
            +
            from functools import partial
         | 
| 11 | 
            +
            from typing import Any, Dict, List, Tuple
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
            from functools import partial
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from datasets import load_dataset, concatenate_datasets, Dataset
         | 
| 17 | 
            +
            from datasets.utils.file_utils import get_datasets_user_agent
         | 
| 18 | 
            +
            from concurrent.futures import ThreadPoolExecutor
         | 
| 19 | 
            +
            import io
         | 
| 20 | 
            +
            import urllib
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            import PIL.Image
         | 
| 23 | 
            +
            import cv2
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            USER_AGENT = get_datasets_user_agent()
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            data_queue = Queue(16*2000)
         | 
| 28 | 
            +
            error_queue = Queue(16*2000)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def fetch_single_image(image_url, timeout=None, retries=0):
         | 
| 32 | 
            +
                for _ in range(retries + 1):
         | 
| 33 | 
            +
                    try:
         | 
| 34 | 
            +
                        request = urllib.request.Request(
         | 
| 35 | 
            +
                            image_url,
         | 
| 36 | 
            +
                            data=None,
         | 
| 37 | 
            +
                            headers={"user-agent": USER_AGENT},
         | 
| 38 | 
            +
                        )
         | 
| 39 | 
            +
                        with urllib.request.urlopen(request, timeout=timeout) as req:
         | 
| 40 | 
            +
                            image = PIL.Image.open(io.BytesIO(req.read()))
         | 
| 41 | 
            +
                        break
         | 
| 42 | 
            +
                    except Exception:
         | 
| 43 | 
            +
                        image = None
         | 
| 44 | 
            +
                return image
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            def map_sample(
         | 
| 47 | 
            +
                url, caption, 
         | 
| 48 | 
            +
                image_shape=(256, 256),
         | 
| 49 | 
            +
                upscale_interpolation=cv2.INTER_LANCZOS4,
         | 
| 50 | 
            +
                downscale_interpolation=cv2.INTER_AREA,
         | 
| 51 | 
            +
            ):
         | 
| 52 | 
            +
                try:
         | 
| 53 | 
            +
                    image = fetch_single_image(url, timeout=15, retries=3)  # Assuming fetch_single_image is defined elsewhere
         | 
| 54 | 
            +
                    if image is None:
         | 
| 55 | 
            +
                        return
         | 
| 56 | 
            +
                    
         | 
| 57 | 
            +
                    image = np.array(image)
         | 
| 58 | 
            +
                    original_height, original_width = image.shape[:2]
         | 
| 59 | 
            +
                    # check if the image is too small
         | 
| 60 | 
            +
                    if min(original_height, original_width) < min(image_shape):
         | 
| 61 | 
            +
                        return
         | 
| 62 | 
            +
                    # check if wrong aspect ratio
         | 
| 63 | 
            +
                    if max(original_height, original_width) / min(original_height, original_width) > 2:
         | 
| 64 | 
            +
                        return
         | 
| 65 | 
            +
                    # check if the variance is too low
         | 
| 66 | 
            +
                    if np.std(image) < 1e-4:
         | 
| 67 | 
            +
                        return
         | 
| 68 | 
            +
                    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         | 
| 69 | 
            +
                    downscale = max(original_width, original_height) > max(image_shape)
         | 
| 70 | 
            +
                    interpolation = downscale_interpolation if downscale else upscale_interpolation
         | 
| 71 | 
            +
                    image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
         | 
| 72 | 
            +
                    image = A.pad(
         | 
| 73 | 
            +
                        image,
         | 
| 74 | 
            +
                        min_height=image_shape[0],
         | 
| 75 | 
            +
                        min_width=image_shape[1],
         | 
| 76 | 
            +
                        border_mode=cv2.BORDER_CONSTANT,
         | 
| 77 | 
            +
                        value=[255, 255, 255],
         | 
| 78 | 
            +
                    )
         | 
| 79 | 
            +
                    data_queue.put({
         | 
| 80 | 
            +
                        "url": url,
         | 
| 81 | 
            +
                        "caption": caption,
         | 
| 82 | 
            +
                        "image": image,
         | 
| 83 | 
            +
                        "original_height": original_height,
         | 
| 84 | 
            +
                        "original_width": original_width,
         | 
| 85 | 
            +
                    })
         | 
| 86 | 
            +
                except Exception as e:
         | 
| 87 | 
            +
                    error_queue.put({
         | 
| 88 | 
            +
                        "url": url,
         | 
| 89 | 
            +
                        "caption": caption,
         | 
| 90 | 
            +
                        "error": str(e)
         | 
| 91 | 
            +
                    })
         | 
| 92 | 
            +
                    
         | 
| 93 | 
            +
            def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
         | 
| 94 | 
            +
                with ThreadPoolExecutor(max_workers=num_threads) as executor:
         | 
| 95 | 
            +
                    executor.map(map_sample, batch["url"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)
         | 
| 96 | 
            +
                
         | 
| 97 | 
            +
            def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
         | 
| 98 | 
            +
                map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
         | 
| 99 | 
            +
                shard_len = len(dataset) // num_workers
         | 
| 100 | 
            +
                print(f"Local Shard lengths: {shard_len}")
         | 
| 101 | 
            +
                with multiprocessing.Pool(num_workers) as pool:
         | 
| 102 | 
            +
                    iteration = 0
         | 
| 103 | 
            +
                    while True:
         | 
| 104 | 
            +
                        # Repeat forever
         | 
| 105 | 
            +
                        dataset = dataset.shuffle(seed=iteration)
         | 
| 106 | 
            +
                        shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
         | 
| 107 | 
            +
                        pool.map(map_batch_fn, shards)
         | 
| 108 | 
            +
                        iteration += 1
         | 
| 109 | 
            +
                        
         | 
| 110 | 
            +
            class ImageBatchIterator:
         | 
| 111 | 
            +
                def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):
         | 
| 112 | 
            +
                    self.dataset = dataset
         | 
| 113 | 
            +
                    self.num_workers = num_workers
         | 
| 114 | 
            +
                    self.batch_size = batch_size
         | 
| 115 | 
            +
                    loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
         | 
| 116 | 
            +
                    self.thread = threading.Thread(target=loader, args=(dataset))
         | 
| 117 | 
            +
                    self.thread.start()
         | 
| 118 | 
            +
                    
         | 
| 119 | 
            +
                def __iter__(self):
         | 
| 120 | 
            +
                    return self
         | 
| 121 | 
            +
                
         | 
| 122 | 
            +
                def __next__(self):
         | 
| 123 | 
            +
                    def fetcher(_):
         | 
| 124 | 
            +
                        return data_queue.get()
         | 
| 125 | 
            +
                    with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
         | 
| 126 | 
            +
                        batch = list(executor.map(fetcher, range(self.batch_size)))
         | 
| 127 | 
            +
                    return batch
         | 
| 128 | 
            +
                
         | 
| 129 | 
            +
                def __del__(self):
         | 
| 130 | 
            +
                    self.thread.join()
         | 
| 131 | 
            +
                    
         | 
| 132 | 
            +
                def __len__(self):
         | 
| 133 | 
            +
                    return len(self.dataset) // self.batch_size
         | 
| 134 | 
            +
                
         | 
| 135 | 
            +
            def default_collate(batch):
         | 
| 136 | 
            +
                urls = [sample["url"] for sample in batch]
         | 
| 137 | 
            +
                captions = [sample["caption"] for sample in batch]
         | 
| 138 | 
            +
                images = np.stack([sample["image"] for sample in batch], axis=0)
         | 
| 139 | 
            +
                return {
         | 
| 140 | 
            +
                    "url": urls,
         | 
| 141 | 
            +
                    "caption": captions,
         | 
| 142 | 
            +
                    "image": images,
         | 
| 143 | 
            +
                }
         | 
| 144 | 
            +
                
         | 
| 145 | 
            +
            def dataMapper(map: Dict[str, Any]):
         | 
| 146 | 
            +
                def _map(sample) -> Dict[str, Any]:
         | 
| 147 | 
            +
                    return {
         | 
| 148 | 
            +
                        "url": sample[map["url"]],
         | 
| 149 | 
            +
                        "caption": sample[map["caption"]],
         | 
| 150 | 
            +
                    }
         | 
| 151 | 
            +
                return _map
         | 
| 152 | 
            +
             | 
| 153 | 
            +
            class OnlineStreamingDataLoader():
         | 
| 154 | 
            +
                def __init__(
         | 
| 155 | 
            +
                    self, 
         | 
| 156 | 
            +
                    dataset, 
         | 
| 157 | 
            +
                    batch_size=64, 
         | 
| 158 | 
            +
                    image_shape=(256, 256),
         | 
| 159 | 
            +
                    num_workers=16, 
         | 
| 160 | 
            +
                    num_threads=512,
         | 
| 161 | 
            +
                    default_split="all",
         | 
| 162 | 
            +
                    pre_map_maker=dataMapper, 
         | 
| 163 | 
            +
                    pre_map_def={
         | 
| 164 | 
            +
                        "url": "URL",
         | 
| 165 | 
            +
                        "caption": "TEXT",
         | 
| 166 | 
            +
                    },
         | 
| 167 | 
            +
                    global_process_count=1,
         | 
| 168 | 
            +
                    global_process_index=0,
         | 
| 169 | 
            +
                    prefetch=1000,
         | 
| 170 | 
            +
                    collate_fn=default_collate,
         | 
| 171 | 
            +
                ):
         | 
| 172 | 
            +
                    if isinstance(dataset, str):
         | 
| 173 | 
            +
                        dataset_path = dataset
         | 
| 174 | 
            +
                        print("Loading dataset from path")
         | 
| 175 | 
            +
                        dataset = load_dataset(dataset_path, split=default_split)
         | 
| 176 | 
            +
                    elif isinstance(dataset, list):
         | 
| 177 | 
            +
                        if isinstance(dataset[0], str):
         | 
| 178 | 
            +
                            print("Loading multiple datasets from paths")
         | 
| 179 | 
            +
                            dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
         | 
| 180 | 
            +
                        else:
         | 
| 181 | 
            +
                            print("Concatenating multiple datasets")
         | 
| 182 | 
            +
                            dataset = concatenate_datasets(dataset)
         | 
| 183 | 
            +
                    dataset = dataset.map(pre_map_maker(pre_map_def))
         | 
| 184 | 
            +
                    self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
         | 
| 185 | 
            +
                    print(f"Dataset length: {len(dataset)}")
         | 
| 186 | 
            +
                    self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
         | 
| 187 | 
            +
                    self.collate_fn = collate_fn
         | 
| 188 | 
            +
                    
         | 
| 189 | 
            +
                    # Launch a thread to load batches in the background
         | 
| 190 | 
            +
                    self.batch_queue = queue.Queue(prefetch)
         | 
| 191 | 
            +
                    
         | 
| 192 | 
            +
                    def batch_loader():
         | 
| 193 | 
            +
                        for batch in self.iterator:
         | 
| 194 | 
            +
                            self.batch_queue.put(batch)
         | 
| 195 | 
            +
                    
         | 
| 196 | 
            +
                    self.loader_thread = threading.Thread(target=batch_loader)
         | 
| 197 | 
            +
                    self.loader_thread.start()
         | 
| 198 | 
            +
                    
         | 
| 199 | 
            +
                def __iter__(self):
         | 
| 200 | 
            +
                    return self
         | 
| 201 | 
            +
                
         | 
| 202 | 
            +
                def __next__(self):
         | 
| 203 | 
            +
                    return self.collate_fn(self.batch_queue.get())
         | 
| 204 | 
            +
                    # return self.collate_fn(next(self.iterator))
         | 
| 205 | 
            +
                    
         | 
| 206 | 
            +
                def __len__(self):
         | 
| 207 | 
            +
                    return len(self.dataset) // self.batch_size
         | 
| 208 | 
            +
                
         | 
| @@ -11,15 +11,15 @@ All credits for the model go to the developers of Stable Diffusion VAE and all c | |
| 11 11 | 
             
            """
         | 
| 12 12 |  | 
| 13 13 | 
             
            class StableDiffusionVAE(AutoEncoder):
         | 
| 14 | 
            -
                def __init__(self, modelname = "CompVis/stable-diffusion-v1-4"):
         | 
| 14 | 
            +
                def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):
         | 
| 15 15 |  | 
| 16 16 | 
             
                    from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
         | 
| 17 17 | 
             
                    from diffusers import FlaxStableDiffusionPipeline
         | 
| 18 18 |  | 
| 19 19 | 
             
                    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
         | 
| 20 20 | 
             
                        modelname,
         | 
| 21 | 
            -
                        revision= | 
| 22 | 
            -
                        dtype= | 
| 21 | 
            +
                        revision=revision,
         | 
| 22 | 
            +
                        dtype=dtype,
         | 
| 23 23 | 
             
                    )
         | 
| 24 24 |  | 
| 25 25 | 
             
                    vae = pipeline.vae
         | 
| @@ -1,5 +1,7 @@ | |
| 1 1 | 
             
            flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 2 2 | 
             
            flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
         | 
| 3 | 
            +
            flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
         | 
| 4 | 
            +
            flaxdiff/data/online_loader.py,sha256=IjCVdeq18lF71eLX8RplJLezjKASO1cSFVe2GSYkLQ8,7283
         | 
| 3 5 | 
             
            flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
         | 
| 4 6 | 
             
            flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
         | 
| 5 7 | 
             
            flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
         | 
| @@ -8,7 +10,7 @@ flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZ | |
| 8 10 | 
             
            flaxdiff/models/simple_vit.py,sha256=xD23i1b7WEvoH4tUMsLyCe9ebDcv-PpaV0Nso38Jlb8,3887
         | 
| 9 11 | 
             
            flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
         | 
| 10 12 | 
             
            flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
         | 
| 11 | 
            -
            flaxdiff/models/autoencoder/diffusers.py,sha256= | 
| 13 | 
            +
            flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
         | 
| 12 14 | 
             
            flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
         | 
| 13 15 | 
             
            flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
         | 
| 14 16 | 
             
            flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
         | 
| @@ -32,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk, | |
| 32 34 | 
             
            flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
         | 
| 33 35 | 
             
            flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
         | 
| 34 36 | 
             
            flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
         | 
| 35 | 
            -
            flaxdiff-0.1. | 
| 36 | 
            -
            flaxdiff-0.1. | 
| 37 | 
            -
            flaxdiff-0.1. | 
| 38 | 
            -
            flaxdiff-0.1. | 
| 37 | 
            +
            flaxdiff-0.1.14.dist-info/METADATA,sha256=rcwF7cCFfgPLHn5gD7GZ_KvtG6EmiAIiel7xt8HylAo,22083
         | 
| 38 | 
            +
            flaxdiff-0.1.14.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
         | 
| 39 | 
            +
            flaxdiff-0.1.14.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
         | 
| 40 | 
            +
            flaxdiff-0.1.14.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         |