flaxdiff 0.2.6.1__tar.gz → 0.2.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/PKG-INFO +1 -1
  2. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/dataloaders.py +36 -24
  3. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/dataset_map.py +2 -2
  4. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/sources/base.py +12 -0
  5. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/sources/images.py +71 -12
  6. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/sources/videos.py +5 -0
  7. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/inference/pipeline.py +9 -4
  8. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/inference/utils.py +2 -2
  9. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/models/common.py +1 -70
  10. flaxdiff-0.2.8/flaxdiff/models/hilbert.py +617 -0
  11. flaxdiff-0.2.8/flaxdiff/models/simple_dit.py +476 -0
  12. flaxdiff-0.2.8/flaxdiff/models/simple_mmdit.py +861 -0
  13. flaxdiff-0.2.8/flaxdiff/models/simple_vit.py +347 -0
  14. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/trainer/general_diffusion_trainer.py +29 -10
  15. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/trainer/simple_trainer.py +113 -19
  16. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff.egg-info/PKG-INFO +1 -1
  17. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff.egg-info/SOURCES.txt +3 -1
  18. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/pyproject.toml +1 -1
  19. flaxdiff-0.2.6.1/flaxdiff/models/better_uvit.py +0 -380
  20. flaxdiff-0.2.6.1/flaxdiff/models/simple_vit.py +0 -186
  21. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/README.md +0 -0
  22. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/__init__.py +0 -0
  23. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/__init__.py +0 -0
  24. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/benchmark_decord.py +0 -0
  25. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/online_loader.py +0 -0
  26. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/sources/audio_utils.py +0 -0
  27. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/sources/av_example.py +0 -0
  28. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/sources/av_utils.py +0 -0
  29. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/sources/utils.py +0 -0
  30. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/data/sources/voxceleb2.py +0 -0
  31. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/inference/__init__.py +0 -0
  32. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/inputs/__init__.py +0 -0
  33. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/inputs/encoders.py +0 -0
  34. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/metrics/__init__.py +0 -0
  35. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/metrics/common.py +0 -0
  36. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/metrics/images.py +0 -0
  37. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/metrics/inception.py +0 -0
  38. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/metrics/psnr.py +0 -0
  39. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/metrics/ssim.py +0 -0
  40. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/metrics/utils.py +0 -0
  41. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/models/__init__.py +0 -0
  42. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/models/attention.py +0 -0
  43. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/models/autoencoder/__init__.py +0 -0
  44. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  45. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  46. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  47. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/models/favor_fastattn.py +0 -0
  48. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/models/general.py +0 -0
  49. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/models/simple_unet.py +0 -0
  50. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/models/unet_3d.py +0 -0
  51. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/models/unet_3d_blocks.py +0 -0
  52. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/predictors/__init__.py +0 -0
  53. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/samplers/__init__.py +0 -0
  54. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/samplers/common.py +0 -0
  55. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/samplers/ddim.py +0 -0
  56. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/samplers/ddpm.py +0 -0
  57. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/samplers/euler.py +0 -0
  58. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/samplers/heun_sampler.py +0 -0
  59. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/samplers/multistep_dpm.py +0 -0
  60. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/samplers/rk4_sampler.py +0 -0
  61. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/schedulers/__init__.py +0 -0
  62. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/schedulers/common.py +0 -0
  63. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/schedulers/continuous.py +0 -0
  64. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/schedulers/cosine.py +0 -0
  65. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/schedulers/discrete.py +0 -0
  66. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/schedulers/exp.py +0 -0
  67. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/schedulers/karras.py +0 -0
  68. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/schedulers/linear.py +0 -0
  69. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/schedulers/sqrt.py +0 -0
  70. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/trainer/__init__.py +0 -0
  71. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  72. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  73. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff/utils.py +0 -0
  74. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff.egg-info/dependency_links.txt +0 -0
  75. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff.egg-info/requires.txt +0 -0
  76. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/flaxdiff.egg-info/top_level.txt +0 -0
  77. {flaxdiff-0.2.6.1 → flaxdiff-0.2.8}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.6.1
3
+ Version: 0.2.8
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -251,6 +251,12 @@ def generate_collate_fn(media_type="image"):
251
251
  else: # Default to image
252
252
  return image_collate
253
253
 
254
+ class CaptionDeletionTransform(pygrain.MapTransform):
255
+ def map(self, element):
256
+ """Delete the caption from the element."""
257
+ if "caption" in element:
258
+ del element["caption"]
259
+ return element
254
260
 
255
261
  def get_dataset_grain(
256
262
  data_name="cc12m",
@@ -286,13 +292,14 @@ def get_dataset_grain(
286
292
  Dictionary with train dataset function and metadata.
287
293
  """
288
294
  dataset = datasetMap[data_name]
289
- data_source = dataset["source"](dataset_source)
295
+ train_source = dataset["source"](dataset_source, split="train")
296
+ # val_source = dataset["source"](dataset_source, split="val")
290
297
  augmenter = dataset["augmenter"](image_scale, method)
291
298
 
292
299
  local_batch_size = batch_size // jax.process_count()
293
300
 
294
301
  train_sampler = pygrain.IndexSampler(
295
- num_records=len(data_source) if count is None else count,
302
+ num_records=len(train_source) if count is None else count,
296
303
  shuffle=True,
297
304
  seed=seed,
298
305
  num_epochs=num_epochs,
@@ -300,7 +307,7 @@ def get_dataset_grain(
300
307
  )
301
308
 
302
309
  # val_sampler = pygrain.IndexSampler(
303
- # num_records=len(data_source) if count is None else count,
310
+ # num_records=len(val_source) if count is None else count,
304
311
  # shuffle=False,
305
312
  # seed=seed,
306
313
  # num_epochs=num_epochs,
@@ -310,11 +317,17 @@ def get_dataset_grain(
310
317
  def get_trainset():
311
318
  transformations = [
312
319
  augmenter(),
313
- pygrain.Batch(local_batch_size, drop_remainder=True),
314
320
  ]
321
+
322
+ # if filters:
323
+ # print("Adding filters to transformations")
324
+ # transformations.append(filters())
325
+
326
+ # transformations.append(CaptionDeletionTransform())
327
+ transformations.append(pygrain.Batch(local_batch_size, drop_remainder=True))
315
328
 
316
329
  loader = pygrain.DataLoader(
317
- data_source=data_source,
330
+ data_source=train_source,
318
331
  sampler=train_sampler,
319
332
  operations=transformations,
320
333
  worker_count=worker_count,
@@ -325,30 +338,29 @@ def get_dataset_grain(
325
338
  )
326
339
  return loader
327
340
 
328
- # def get_valset():
329
- # transformations = [
330
- # augmenter(),
331
- # pygrain.Batch(local_batch_size, drop_remainder=True),
332
- # ]
341
+ def get_valset():
342
+ transformations = [
343
+ augmenter(),
344
+ pygrain.Batch(local_batch_size, drop_remainder=True),
345
+ ]
333
346
 
334
- # loader = pygrain.DataLoader(
335
- # data_source=data_source,
336
- # sampler=val_sampler,
337
- # operations=transformations,
338
- # worker_count=worker_count,
339
- # read_options=pygrain.ReadOptions(
340
- # read_thread_count, read_buffer_size
341
- # ),
342
- # worker_buffer_size=worker_buffer_size,
343
- # )
344
- # return loader
345
- get_valset = get_trainset # For now, use the same function for validation
347
+ loader = pygrain.DataLoader(
348
+ data_source=train_source,
349
+ sampler=train_sampler,
350
+ operations=transformations,
351
+ worker_count=2,
352
+ read_options=pygrain.ReadOptions(
353
+ read_thread_count, read_buffer_size
354
+ ),
355
+ worker_buffer_size=2,
356
+ )
357
+ return loader
346
358
 
347
359
  return {
348
360
  "train": get_trainset,
349
- "train_len": len(data_source),
361
+ "train_len": len(train_source),
350
362
  "val": get_valset,
351
- "val_len": len(data_source),
363
+ "val_len": len(train_source),
352
364
  "local_batch_size": local_batch_size,
353
365
  "global_batch_size": batch_size,
354
366
  }
@@ -8,7 +8,7 @@ from .sources.videos import VideoTFDSSource, VideoLocalSource, AudioVideoAugment
8
8
  # ---------------------------------------------------------------------------------
9
9
 
10
10
  from .sources.images import data_source_tfds, tfds_augmenters, data_source_gcs
11
- from .sources.images import data_source_combined_gcs, gcs_augmenters
11
+ from .sources.images import data_source_combined_gcs, gcs_augmenters, gcs_filters
12
12
 
13
13
  # Configure the following for your datasets
14
14
  datasetMap = {
@@ -21,7 +21,7 @@ datasetMap = {
21
21
  "augmenter": gcs_augmenters,
22
22
  },
23
23
  "laiona_coco": {
24
- "source": data_source_gcs('datasets/laion12m+mscoco'),
24
+ "source": data_source_gcs('datasets/laion12m+mscoco_filtered-new'),
25
25
  "augmenter": gcs_augmenters,
26
26
  },
27
27
  "aesthetic_coyo": {
@@ -62,6 +62,18 @@ class DataAugmenter(ABC):
62
62
  """
63
63
  pass
64
64
 
65
+ @abstractmethod
66
+ def create_filter(self, **kwargs) -> Callable[[], pygrain.FilterTransform]:
67
+ """Create a filter function for the data.
68
+
69
+ Args:
70
+ **kwargs: Additional arguments for the filter.
71
+
72
+ Returns:
73
+ A callable that returns a pygrain.FilterTransform instance.
74
+ """
75
+ pass
76
+
65
77
  @staticmethod
66
78
  def create(augmenter_type: str, **kwargs) -> 'DataAugmenter':
67
79
  """Factory method to create a data augmenter of the specified type.
@@ -82,7 +82,7 @@ def labelizer_oxford_flowers102(path):
82
82
  class ImageTFDSSource(DataSource):
83
83
  """Data source for TensorFlow Datasets (TFDS) image datasets."""
84
84
 
85
- def __init__(self, name: str, use_tf: bool = True, split: str = "all"):
85
+ def __init__(self, name: str, use_tf: bool = True):
86
86
  """Initialize a TFDS image data source.
87
87
 
88
88
  Args:
@@ -92,9 +92,8 @@ class ImageTFDSSource(DataSource):
92
92
  """
93
93
  self.name = name
94
94
  self.use_tf = use_tf
95
- self.split = split
96
95
 
97
- def get_source(self, path_override: str) -> Any:
96
+ def get_source(self, path_override: str, split: str = "all") -> Any:
98
97
  """Get the TFDS data source.
99
98
 
100
99
  Args:
@@ -105,20 +104,22 @@ class ImageTFDSSource(DataSource):
105
104
  """
106
105
  import tensorflow_datasets as tfds
107
106
  if self.use_tf:
108
- return tfds.load(self.name, split=self.split, shuffle_files=True)
107
+ return tfds.load(self.name, split=split, shuffle_files=True)
109
108
  else:
110
- return tfds.data_source(self.name, split=self.split, try_gcs=False)
109
+ return tfds.data_source(self.name, split=split, try_gcs=False)
111
110
 
112
111
 
113
112
  class ImageTFDSAugmenter(DataAugmenter):
114
113
  """Augmenter for TFDS image datasets."""
115
114
 
116
- def __init__(self, label_path: str = "/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt"):
115
+ def __init__(self, label_path: str = None):
117
116
  """Initialize a TFDS image augmenter.
118
117
 
119
118
  Args:
120
119
  label_path: Path to the labels file for datasets like Oxford Flowers.
121
120
  """
121
+ if label_path is None:
122
+ label_path = os.path.join(os.path.expanduser("~"), "tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt")
122
123
  self.label_path = label_path
123
124
 
124
125
  def create_transform(self, image_scale: int = 256, method: Any = None) -> Callable[[], pygrain.MapTransform]:
@@ -166,7 +167,11 @@ class ImageTFDSAugmenter(DataAugmenter):
166
167
  }
167
168
 
168
169
  return TFDSTransform
169
-
170
+
171
+ def create_filter(self, image_scale: int = 256):
172
+ class FilterTransform(pygrain.FilterTransform):
173
+ def map(self, element) -> bool:
174
+ return True
170
175
  """
171
176
  Batch structure:
172
177
  {
@@ -193,7 +198,7 @@ class ImageGCSSource(DataSource):
193
198
  """
194
199
  self.source = source
195
200
 
196
- def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
201
+ def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split: str = "train") -> Any:
197
202
  """Get the GCS data source.
198
203
 
199
204
  Args:
@@ -205,6 +210,8 @@ class ImageGCSSource(DataSource):
205
210
  records_path = os.path.join(path_override, self.source)
206
211
  records = [os.path.join(records_path, i) for i in os.listdir(
207
212
  records_path) if 'array_record' in i]
213
+ if split == "val":
214
+ records = records[:1]
208
215
  return pygrain.ArrayRecordDataSource(records)
209
216
 
210
217
 
@@ -219,7 +226,7 @@ class CombinedImageGCSSource(DataSource):
219
226
  """
220
227
  self.sources = sources
221
228
 
222
- def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
229
+ def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split: str = "train") -> Any:
223
230
  """Get the combined GCS data source.
224
231
 
225
232
  Args:
@@ -233,9 +240,10 @@ class CombinedImageGCSSource(DataSource):
233
240
  for records_path in records_paths:
234
241
  records += [os.path.join(records_path, i) for i in os.listdir(
235
242
  records_path) if 'array_record' in i]
243
+ if split == "val":
244
+ records = records[:1]
236
245
  return pygrain.ArrayRecordDataSource(records)
237
246
 
238
-
239
247
  class ImageGCSAugmenter(DataAugmenter):
240
248
  """Augmenter for GCS image datasets."""
241
249
 
@@ -295,6 +303,52 @@ class ImageGCSAugmenter(DataAugmenter):
295
303
  }
296
304
 
297
305
  return GCSTransform
306
+
307
+ def create_filter(self, image_scale: int = 256):
308
+ import torch.nn.functional as F
309
+ class FilterTransform(pygrain.FilterTransform):
310
+ """
311
+ Filter transform for GCS data source.
312
+ """
313
+ def __init__(self, model=None, processor=None, method=cv2.INTER_AREA):
314
+ super().__init__()
315
+ self.image_scale = image_scale
316
+ if model is None:
317
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection, FlaxCLIPModel, CLIPModel
318
+ model_name = "openai/clip-vit-base-patch32"
319
+ model = CLIPModel.from_pretrained(model_name)
320
+ processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
321
+ self.method = method
322
+ self.model = model
323
+ self.processor = processor
324
+
325
+ # def _filter_(pixel_values, input_ids):
326
+ # image_embeds = self.model.get_image_features(pixel_values=pixel_values)
327
+ # text_embeds = self.model.get_text_features(input_ids=input_ids)
328
+ # image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
329
+ # text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
330
+ # similarity = jnp.sum(image_embeds * text_embeds, axis=-1)
331
+ # return jnp.all(similarity >= 0.25)
332
+
333
+ # self._filter_ = _filter_
334
+
335
+ def filter(self, data: Dict[str, Any]) -> bool:
336
+ images = [data['image']]
337
+ texts = [data['caption']]
338
+ inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
339
+ # result = self._filter_(
340
+ # pixel_values=inputs['pixel_values'],
341
+ # input_ids=inputs['input_ids']
342
+ # )
343
+ # return result
344
+
345
+ image_embeds = self.model.get_image_features(pixel_values=inputs['pixel_values'])
346
+ text_embeds = self.model.get_text_features(input_ids=inputs['input_ids'])
347
+ similarity = F.cosine_similarity(image_embeds, text_embeds)
348
+ # Filter out images with similarity less than 0.25
349
+ return similarity[0] >= 0.25
350
+
351
+ return FilterTransform
298
352
 
299
353
 
300
354
  # ----------------------------------------------------------------------------------
@@ -303,9 +357,9 @@ class ImageGCSAugmenter(DataAugmenter):
303
357
 
304
358
  # These functions maintain backward compatibility with existing code
305
359
 
306
- def data_source_tfds(name, use_tf=True, split="all"):
360
+ def data_source_tfds(name, use_tf=True):
307
361
  """Legacy function for TFDS data sources."""
308
- source = ImageTFDSSource(name=name, use_tf=use_tf, split=split)
362
+ source = ImageTFDSSource(name=name, use_tf=use_tf)
309
363
  return source.get_source
310
364
 
311
365
 
@@ -331,3 +385,8 @@ def gcs_augmenters(image_scale, method):
331
385
  """Legacy function for GCS augmenters."""
332
386
  augmenter = ImageGCSAugmenter()
333
387
  return augmenter.create_transform(image_scale=image_scale, method=method)
388
+
389
+ def gcs_filters(image_scale):
390
+ """Legacy function for GCS Filters."""
391
+ augmenter = ImageGCSAugmenter()
392
+ return augmenter.create_filter(image_scale=image_scale)
@@ -216,6 +216,11 @@ class AudioVideoAugmenter(DataAugmenter):
216
216
 
217
217
  return AudioVideoTransform
218
218
 
219
+
220
+ def create_filter(self, image_scale: int = 256):
221
+ class FilterTransform(pygrain.FilterTransform):
222
+ def map(self, element) -> bool:
223
+ return True
219
224
 
220
225
  # ----------------------------------------------------------------------------------
221
226
  # Helper functions for video datasets
@@ -25,6 +25,7 @@ from flaxdiff.inference.utils import parse_config, load_from_wandb_run, load_fro
25
25
  @dataclass
26
26
  class InferencePipeline:
27
27
  """Inference pipeline for a general model."""
28
+ name: str = None
28
29
  model: nn.Module = None
29
30
  state: SimpleTrainState = None
30
31
  best_state: SimpleTrainState = None
@@ -44,6 +45,7 @@ class DiffusionInferencePipeline(InferencePipeline):
44
45
  This pipeline handles loading models from wandb and generating samples using the
45
46
  DiffusionSampler from FlaxDiff.
46
47
  """
48
+ artifact: Any = None
47
49
  state: TrainState = None
48
50
  best_state: TrainState = None
49
51
  rngstate: Optional[RandomMarkovState] = None
@@ -51,7 +53,6 @@ class DiffusionInferencePipeline(InferencePipeline):
51
53
  model_output_transform: DiffusionPredictionTransform = None
52
54
  autoencoder: AutoEncoder = None
53
55
  input_config: DiffusionInputConfig = None
54
- wandb_run = None
55
56
  samplers: Dict[Type[DiffusionSampler], Dict[float, DiffusionSampler]] = field(default_factory=dict)
56
57
  config: Dict[str, Any] = field(default_factory=dict)
57
58
 
@@ -76,7 +77,7 @@ class DiffusionInferencePipeline(InferencePipeline):
76
77
  Returns:
77
78
  DiffusionInferencePipeline instance
78
79
  """
79
- states, config, run = load_from_wandb_run(
80
+ states, config, run, artifact = load_from_wandb_run(
80
81
  wandb_run,
81
82
  project=project,
82
83
  entity=entity,
@@ -95,6 +96,7 @@ class DiffusionInferencePipeline(InferencePipeline):
95
96
  best_state=best_state,
96
97
  rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
97
98
  run=run,
99
+ artifact=artifact,
98
100
  )
99
101
  return pipeline
100
102
 
@@ -119,7 +121,7 @@ class DiffusionInferencePipeline(InferencePipeline):
119
121
  Returns:
120
122
  DiffusionInferencePipeline instance
121
123
  """
122
- states, config, run = load_from_wandb_registry(
124
+ states, config, run, artifact = load_from_wandb_registry(
123
125
  modelname=modelname,
124
126
  project=project,
125
127
  entity=entity,
@@ -140,6 +142,7 @@ class DiffusionInferencePipeline(InferencePipeline):
140
142
  best_state=best_state,
141
143
  rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
142
144
  run=run,
145
+ artifact=artifact,
143
146
  )
144
147
  return pipeline
145
148
 
@@ -151,11 +154,14 @@ class DiffusionInferencePipeline(InferencePipeline):
151
154
  best_state: Optional[Dict[str, Any]] = None,
152
155
  rngstate: Optional[RandomMarkovState] = None,
153
156
  run=None,
157
+ artifact=None,
154
158
  ):
155
159
  if rngstate is None:
156
160
  rngstate = RandomMarkovState(jax.random.PRNGKey(42))
157
161
  # Build and return pipeline
158
162
  return cls(
163
+ name=run.name if run else None,
164
+ artifact=artifact,
159
165
  model=config['model'],
160
166
  state=state,
161
167
  best_state=best_state,
@@ -165,7 +171,6 @@ class DiffusionInferencePipeline(InferencePipeline):
165
171
  autoencoder=config['autoencoder'],
166
172
  input_config=config['input_config'],
167
173
  config=config,
168
- wandb_run=run,
169
174
  )
170
175
 
171
176
  def get_sampler(
@@ -292,7 +292,7 @@ def load_from_wandb_run(
292
292
  config = run.config
293
293
  except Exception as e:
294
294
  print(f"Warning: Failed to load model from wandb: {e}")
295
- return states, config, run
295
+ return states, config, run, artifact
296
296
 
297
297
  def load_from_wandb_registry(
298
298
  modelname: str,
@@ -318,4 +318,4 @@ def load_from_wandb_registry(
318
318
  config = run.config
319
319
  except Exception as e:
320
320
  print(f"Warning: Failed to load model from wandb: {e}")
321
- return states, config, run
321
+ return states, config, run, artifact
@@ -335,73 +335,4 @@ class ResidualBlock(nn.Module):
335
335
 
336
336
  out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
337
337
 
338
- return out
339
-
340
- # Convert Hilbert index d to 2D coordinates (x, y) for an n x n grid
341
- def _d2xy(n, d):
342
- x = 0
343
- y = 0
344
- t = d
345
- s = 1
346
- while s < n:
347
- rx = (t // 2) & 1
348
- ry = (t ^ rx) & 1
349
- if ry == 0:
350
- if rx == 1:
351
- x = n - 1 - x
352
- y = n - 1 - y
353
- x, y = y, x
354
- x += s * rx
355
- y += s * ry
356
- t //= 4
357
- s *= 2
358
- return x, y
359
-
360
- # Hilbert index mapping for a rectangular grid of patches H_P x W_P
361
-
362
- def hilbert_indices(H_P, W_P):
363
- size = max(H_P, W_P)
364
- order = math.ceil(math.log2(size))
365
- n = 1 << order
366
- coords = []
367
- for d in range(n * n):
368
- x, y = _d2xy(n, d)
369
- # x is column index, y is row index
370
- if x < W_P and y < H_P:
371
- coords.append((y, x)) # (row, col)
372
- if len(coords) == H_P * W_P:
373
- break
374
- # Convert (row, col) to linear indices row-major
375
- indices = [r * W_P + c for r, c in coords]
376
- return jnp.array(indices, dtype=jnp.int32)
377
-
378
- # Inverse permutation: given idx where idx[i] = new position of element i, return inv such that inv[idx[i]] = i
379
-
380
- def inverse_permutation(idx):
381
- inv = jnp.zeros_like(idx)
382
- inv = inv.at[idx].set(jnp.arange(idx.shape[0], dtype=idx.dtype))
383
- return inv
384
-
385
- # Patchify using Hilbert ordering: extract patches and reorder sequence
386
-
387
- def hilbert_patchify(x, patch_size):
388
- B, H, W, C = x.shape
389
- H_P = H // patch_size
390
- W_P = W // patch_size
391
- # Extract patches in row-major
392
- patches = rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
393
- idx = hilbert_indices(H_P, W_P)
394
- return patches[:, idx, :]
395
-
396
- # Unpatchify from Hilbert ordering: reorder sequence back and reconstruct image
397
-
398
- def hilbert_unpatchify(patches, patch_size, H, W, C):
399
- B, N, D = patches.shape
400
- H_P = H // patch_size
401
- W_P = W // patch_size
402
- inv = inverse_permutation(hilbert_indices(H_P, W_P))
403
- # Reorder back to row-major
404
- linear = patches[:, inv, :]
405
- # Reconstruct image
406
- x = rearrange(linear, 'b (h w) (p1 p2 c) -> b (h p1) (w p2) c', h=H_P, w=W_P, p1=patch_size, p2=patch_size, c=C)
407
- return x
338
+ return out