flaxdiff 0.2.6.1__tar.gz → 0.2.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/PKG-INFO +1 -1
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/sources/images.py +3 -1
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/inference/pipeline.py +9 -4
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/inference/utils.py +2 -2
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/pyproject.toml +1 -1
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/README.md +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/benchmark_decord.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/dataloaders.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/dataset_map.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/sources/audio_utils.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/sources/av_example.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/sources/av_utils.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/sources/base.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/sources/utils.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/sources/videos.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/data/sources/voxceleb2.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/inference/__init__.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/inputs/__init__.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/inputs/encoders.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/metrics/__init__.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/metrics/common.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/metrics/images.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/metrics/inception.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/metrics/psnr.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/metrics/ssim.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/metrics/utils.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/better_uvit.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/general.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/unet_3d.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/models/unet_3d_blocks.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/trainer/general_diffusion_trainer.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.2.6.1 → flaxdiff-0.2.7}/setup.cfg +0 -0
@@ -113,12 +113,14 @@ class ImageTFDSSource(DataSource):
|
|
113
113
|
class ImageTFDSAugmenter(DataAugmenter):
|
114
114
|
"""Augmenter for TFDS image datasets."""
|
115
115
|
|
116
|
-
def __init__(self, label_path: str =
|
116
|
+
def __init__(self, label_path: str = None):
|
117
117
|
"""Initialize a TFDS image augmenter.
|
118
118
|
|
119
119
|
Args:
|
120
120
|
label_path: Path to the labels file for datasets like Oxford Flowers.
|
121
121
|
"""
|
122
|
+
if label_path is None:
|
123
|
+
label_path = os.path.join(os.path.expanduser("~"), "tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt")
|
122
124
|
self.label_path = label_path
|
123
125
|
|
124
126
|
def create_transform(self, image_scale: int = 256, method: Any = None) -> Callable[[], pygrain.MapTransform]:
|
@@ -25,6 +25,7 @@ from flaxdiff.inference.utils import parse_config, load_from_wandb_run, load_fro
|
|
25
25
|
@dataclass
|
26
26
|
class InferencePipeline:
|
27
27
|
"""Inference pipeline for a general model."""
|
28
|
+
name: str = None
|
28
29
|
model: nn.Module = None
|
29
30
|
state: SimpleTrainState = None
|
30
31
|
best_state: SimpleTrainState = None
|
@@ -44,6 +45,7 @@ class DiffusionInferencePipeline(InferencePipeline):
|
|
44
45
|
This pipeline handles loading models from wandb and generating samples using the
|
45
46
|
DiffusionSampler from FlaxDiff.
|
46
47
|
"""
|
48
|
+
artifact: Any = None
|
47
49
|
state: TrainState = None
|
48
50
|
best_state: TrainState = None
|
49
51
|
rngstate: Optional[RandomMarkovState] = None
|
@@ -51,7 +53,6 @@ class DiffusionInferencePipeline(InferencePipeline):
|
|
51
53
|
model_output_transform: DiffusionPredictionTransform = None
|
52
54
|
autoencoder: AutoEncoder = None
|
53
55
|
input_config: DiffusionInputConfig = None
|
54
|
-
wandb_run = None
|
55
56
|
samplers: Dict[Type[DiffusionSampler], Dict[float, DiffusionSampler]] = field(default_factory=dict)
|
56
57
|
config: Dict[str, Any] = field(default_factory=dict)
|
57
58
|
|
@@ -76,7 +77,7 @@ class DiffusionInferencePipeline(InferencePipeline):
|
|
76
77
|
Returns:
|
77
78
|
DiffusionInferencePipeline instance
|
78
79
|
"""
|
79
|
-
states, config, run = load_from_wandb_run(
|
80
|
+
states, config, run, artifact = load_from_wandb_run(
|
80
81
|
wandb_run,
|
81
82
|
project=project,
|
82
83
|
entity=entity,
|
@@ -95,6 +96,7 @@ class DiffusionInferencePipeline(InferencePipeline):
|
|
95
96
|
best_state=best_state,
|
96
97
|
rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
|
97
98
|
run=run,
|
99
|
+
artifact=artifact,
|
98
100
|
)
|
99
101
|
return pipeline
|
100
102
|
|
@@ -119,7 +121,7 @@ class DiffusionInferencePipeline(InferencePipeline):
|
|
119
121
|
Returns:
|
120
122
|
DiffusionInferencePipeline instance
|
121
123
|
"""
|
122
|
-
states, config, run = load_from_wandb_registry(
|
124
|
+
states, config, run, artifact = load_from_wandb_registry(
|
123
125
|
modelname=modelname,
|
124
126
|
project=project,
|
125
127
|
entity=entity,
|
@@ -140,6 +142,7 @@ class DiffusionInferencePipeline(InferencePipeline):
|
|
140
142
|
best_state=best_state,
|
141
143
|
rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
|
142
144
|
run=run,
|
145
|
+
artifact=artifact,
|
143
146
|
)
|
144
147
|
return pipeline
|
145
148
|
|
@@ -151,11 +154,14 @@ class DiffusionInferencePipeline(InferencePipeline):
|
|
151
154
|
best_state: Optional[Dict[str, Any]] = None,
|
152
155
|
rngstate: Optional[RandomMarkovState] = None,
|
153
156
|
run=None,
|
157
|
+
artifact=None,
|
154
158
|
):
|
155
159
|
if rngstate is None:
|
156
160
|
rngstate = RandomMarkovState(jax.random.PRNGKey(42))
|
157
161
|
# Build and return pipeline
|
158
162
|
return cls(
|
163
|
+
name=run.name if run else None,
|
164
|
+
artifact=artifact,
|
159
165
|
model=config['model'],
|
160
166
|
state=state,
|
161
167
|
best_state=best_state,
|
@@ -165,7 +171,6 @@ class DiffusionInferencePipeline(InferencePipeline):
|
|
165
171
|
autoencoder=config['autoencoder'],
|
166
172
|
input_config=config['input_config'],
|
167
173
|
config=config,
|
168
|
-
wandb_run=run,
|
169
174
|
)
|
170
175
|
|
171
176
|
def get_sampler(
|
@@ -292,7 +292,7 @@ def load_from_wandb_run(
|
|
292
292
|
config = run.config
|
293
293
|
except Exception as e:
|
294
294
|
print(f"Warning: Failed to load model from wandb: {e}")
|
295
|
-
return states, config, run
|
295
|
+
return states, config, run, artifact
|
296
296
|
|
297
297
|
def load_from_wandb_registry(
|
298
298
|
modelname: str,
|
@@ -318,4 +318,4 @@ def load_from_wandb_registry(
|
|
318
318
|
config = run.config
|
319
319
|
except Exception as e:
|
320
320
|
print(f"Warning: Failed to load model from wandb: {e}")
|
321
|
-
return states, config, run
|
321
|
+
return states, config, run, artifact
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|