flaxdiff 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataloaders.py +32 -3
- flaxdiff/data/dataset_map.py +1 -1
- flaxdiff/data/sources/images.py +17 -0
- flaxdiff/metrics/__init__.py +0 -0
- flaxdiff/metrics/common.py +11 -0
- flaxdiff/metrics/images.py +59 -0
- flaxdiff/trainer/general_diffusion_trainer.py +67 -23
- flaxdiff/trainer/simple_trainer.py +7 -5
- {flaxdiff-0.2.2.dist-info → flaxdiff-0.2.4.dist-info}/METADATA +3 -1
- {flaxdiff-0.2.2.dist-info → flaxdiff-0.2.4.dist-info}/RECORD +12 -9
- {flaxdiff-0.2.2.dist-info → flaxdiff-0.2.4.dist-info}/WHEEL +1 -1
- {flaxdiff-0.2.2.dist-info → flaxdiff-0.2.4.dist-info}/top_level.txt +0 -0
flaxdiff/data/dataloaders.py
CHANGED
@@ -258,7 +258,7 @@ def get_dataset_grain(
|
|
258
258
|
image_scale=256,
|
259
259
|
count=None,
|
260
260
|
num_epochs=None,
|
261
|
-
method=jax.image.ResizeMethod.LANCZOS3,
|
261
|
+
method=None, #jax.image.ResizeMethod.LANCZOS3,
|
262
262
|
worker_count=32,
|
263
263
|
read_thread_count=64,
|
264
264
|
read_buffer_size=50,
|
@@ -291,7 +291,7 @@ def get_dataset_grain(
|
|
291
291
|
|
292
292
|
local_batch_size = batch_size // jax.process_count()
|
293
293
|
|
294
|
-
|
294
|
+
train_sampler = pygrain.IndexSampler(
|
295
295
|
num_records=len(data_source) if count is None else count,
|
296
296
|
shuffle=True,
|
297
297
|
seed=seed,
|
@@ -299,6 +299,14 @@ def get_dataset_grain(
|
|
299
299
|
shard_options=pygrain.ShardByJaxProcess(),
|
300
300
|
)
|
301
301
|
|
302
|
+
# val_sampler = pygrain.IndexSampler(
|
303
|
+
# num_records=len(data_source) if count is None else count,
|
304
|
+
# shuffle=False,
|
305
|
+
# seed=seed,
|
306
|
+
# num_epochs=num_epochs,
|
307
|
+
# shard_options=pygrain.ShardByJaxProcess(),
|
308
|
+
# )
|
309
|
+
|
302
310
|
def get_trainset():
|
303
311
|
transformations = [
|
304
312
|
augmenter(),
|
@@ -307,7 +315,7 @@ def get_dataset_grain(
|
|
307
315
|
|
308
316
|
loader = pygrain.DataLoader(
|
309
317
|
data_source=data_source,
|
310
|
-
sampler=
|
318
|
+
sampler=train_sampler,
|
311
319
|
operations=transformations,
|
312
320
|
worker_count=worker_count,
|
313
321
|
read_options=pygrain.ReadOptions(
|
@@ -316,10 +324,31 @@ def get_dataset_grain(
|
|
316
324
|
worker_buffer_size=worker_buffer_size,
|
317
325
|
)
|
318
326
|
return loader
|
327
|
+
|
328
|
+
# def get_valset():
|
329
|
+
# transformations = [
|
330
|
+
# augmenter(),
|
331
|
+
# pygrain.Batch(local_batch_size, drop_remainder=True),
|
332
|
+
# ]
|
333
|
+
|
334
|
+
# loader = pygrain.DataLoader(
|
335
|
+
# data_source=data_source,
|
336
|
+
# sampler=val_sampler,
|
337
|
+
# operations=transformations,
|
338
|
+
# worker_count=worker_count,
|
339
|
+
# read_options=pygrain.ReadOptions(
|
340
|
+
# read_thread_count, read_buffer_size
|
341
|
+
# ),
|
342
|
+
# worker_buffer_size=worker_buffer_size,
|
343
|
+
# )
|
344
|
+
# return loader
|
345
|
+
get_valset = get_trainset # For now, use the same function for validation
|
319
346
|
|
320
347
|
return {
|
321
348
|
"train": get_trainset,
|
322
349
|
"train_len": len(data_source),
|
350
|
+
"val": get_valset,
|
351
|
+
"val_len": len(data_source),
|
323
352
|
"local_batch_size": local_batch_size,
|
324
353
|
"global_batch_size": batch_size,
|
325
354
|
}
|
flaxdiff/data/dataset_map.py
CHANGED
@@ -21,7 +21,7 @@ datasetMap = {
|
|
21
21
|
"augmenter": gcs_augmenters,
|
22
22
|
},
|
23
23
|
"laiona_coco": {
|
24
|
-
"source": data_source_gcs('
|
24
|
+
"source": data_source_gcs('datasets/laion12m+mscoco'),
|
25
25
|
"augmenter": gcs_augmenters,
|
26
26
|
},
|
27
27
|
"aesthetic_coyo": {
|
flaxdiff/data/sources/images.py
CHANGED
@@ -167,6 +167,16 @@ class ImageTFDSAugmenter(DataAugmenter):
|
|
167
167
|
|
168
168
|
return TFDSTransform
|
169
169
|
|
170
|
+
"""
|
171
|
+
Batch structure:
|
172
|
+
{
|
173
|
+
"image": image_batch,
|
174
|
+
"text": {
|
175
|
+
"input_ids": input_ids_batch,
|
176
|
+
"attention_mask": attention_mask_batch,
|
177
|
+
}
|
178
|
+
|
179
|
+
"""
|
170
180
|
|
171
181
|
# ----------------------------------------------------------------------------------
|
172
182
|
# GCS Image Source
|
@@ -248,6 +258,13 @@ class ImageGCSAugmenter(DataAugmenter):
|
|
248
258
|
A callable that returns a pygrain.MapTransform.
|
249
259
|
"""
|
250
260
|
labelizer = self.labelizer
|
261
|
+
if method is None:
|
262
|
+
if image_scale > 256:
|
263
|
+
method = cv2.INTER_CUBIC
|
264
|
+
else:
|
265
|
+
method = cv2.INTER_AREA
|
266
|
+
|
267
|
+
print(f"Using method: {method}")
|
251
268
|
|
252
269
|
class GCSTransform(pygrain.MapTransform):
|
253
270
|
def __init__(self, *args, **kwargs):
|
File without changes
|
@@ -0,0 +1,11 @@
|
|
1
|
+
from typing import Callable
|
2
|
+
from dataclasses import dataclass
|
3
|
+
|
4
|
+
@dataclass
|
5
|
+
class EvaluationMetric:
|
6
|
+
"""
|
7
|
+
Evaluation metrics for the diffusion model.
|
8
|
+
The function is given generated samples batch [B, H, W, C] and the original batch.
|
9
|
+
"""
|
10
|
+
function: Callable
|
11
|
+
name: str
|
@@ -0,0 +1,59 @@
|
|
1
|
+
from .common import EvaluationMetric
|
2
|
+
import jax
|
3
|
+
import jax.numpy as jnp
|
4
|
+
|
5
|
+
def get_clip_metric(
|
6
|
+
modelname: str = "openai/clip-vit-large-patch14",
|
7
|
+
):
|
8
|
+
from transformers import AutoProcessor, FlaxCLIPModel
|
9
|
+
model = FlaxCLIPModel.from_pretrained(modelname, dtype=jnp.float16)
|
10
|
+
processor = AutoProcessor.from_pretrained(modelname, use_fast=True, dtype=jnp.float16)
|
11
|
+
|
12
|
+
@jax.jit
|
13
|
+
def calc(pixel_values, input_ids, attention_mask):
|
14
|
+
# Get the logits
|
15
|
+
generated_out = model(
|
16
|
+
pixel_values=pixel_values,
|
17
|
+
input_ids=input_ids,
|
18
|
+
attention_mask=attention_mask,
|
19
|
+
)
|
20
|
+
|
21
|
+
gen_img_emb = generated_out.image_embeds
|
22
|
+
txt_emb = generated_out.text_embeds
|
23
|
+
|
24
|
+
# 1. Normalize embeddings (essential for cosine similarity/distance)
|
25
|
+
gen_img_emb = gen_img_emb / (jnp.linalg.norm(gen_img_emb, axis=-1, keepdims=True) + 1e-6)
|
26
|
+
txt_emb = txt_emb / (jnp.linalg.norm(txt_emb, axis=-1, keepdims=True) + 1e-6)
|
27
|
+
|
28
|
+
# 2. Calculate cosine similarity
|
29
|
+
# Using einsum for batch dot product: batch (b), embedding_dim (d) -> bd,bd->b
|
30
|
+
# Calculate cosine similarity
|
31
|
+
similarity = jnp.einsum('bd,bd->b', gen_img_emb, txt_emb)
|
32
|
+
|
33
|
+
scaled_distance = (1.0 - similarity)
|
34
|
+
# 4. Average over the batch
|
35
|
+
mean_scaled_distance = jnp.mean(scaled_distance)
|
36
|
+
|
37
|
+
return mean_scaled_distance
|
38
|
+
|
39
|
+
def clip_metric(
|
40
|
+
generated: jnp.ndarray,
|
41
|
+
batch
|
42
|
+
):
|
43
|
+
original_conditions = batch['text']
|
44
|
+
|
45
|
+
# Convert samples from [-1, 1] to [0, 255] and uint8
|
46
|
+
generated = (((generated + 1.0) / 2.0) * 255).astype(jnp.uint8)
|
47
|
+
|
48
|
+
generated_inputs = processor(images=generated, return_tensors="jax", padding=True,)
|
49
|
+
|
50
|
+
pixel_values = generated_inputs['pixel_values']
|
51
|
+
input_ids = original_conditions['input_ids']
|
52
|
+
attention_mask = original_conditions['attention_mask']
|
53
|
+
|
54
|
+
return calc(pixel_values, input_ids, attention_mask)
|
55
|
+
|
56
|
+
return EvaluationMetric(
|
57
|
+
function=clip_metric,
|
58
|
+
name='clip_similarity'
|
59
|
+
)
|
@@ -18,15 +18,17 @@ from ..samplers.ddim import DDIMSampler
|
|
18
18
|
from flaxdiff.utils import RandomMarkovState, serialize_model, get_latest_checkpoint
|
19
19
|
from flaxdiff.inputs import ConditioningEncoder, ConditionalInputConfig, DiffusionInputConfig
|
20
20
|
|
21
|
-
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
21
|
+
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics, convert_to_global_tree
|
22
22
|
|
23
23
|
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
24
24
|
from flax.training import dynamic_scale as dynamic_scale_lib
|
25
25
|
|
26
26
|
# Reuse the TrainState from the DiffusionTrainer
|
27
|
-
from
|
27
|
+
from .diffusion_trainer import TrainState, DiffusionTrainer
|
28
28
|
import shutil
|
29
29
|
|
30
|
+
from flaxdiff.metrics.common import EvaluationMetric
|
31
|
+
|
30
32
|
def generate_modelname(
|
31
33
|
dataset_name: str,
|
32
34
|
noise_schedule_name: str,
|
@@ -126,6 +128,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
126
128
|
native_resolution: int = None,
|
127
129
|
frames_per_sample: int = None,
|
128
130
|
wandb_config: Dict[str, Any] = None,
|
131
|
+
eval_metrics: List[EvaluationMetric] = None,
|
129
132
|
**kwargs
|
130
133
|
):
|
131
134
|
"""
|
@@ -150,6 +153,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
150
153
|
autoencoder=autoencoder,
|
151
154
|
)
|
152
155
|
self.input_config = input_config
|
156
|
+
self.eval_metrics = eval_metrics
|
153
157
|
|
154
158
|
if wandb_config is not None:
|
155
159
|
# If input_config is not in wandb_config, add it
|
@@ -363,7 +367,6 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
363
367
|
def generate_samples(
|
364
368
|
val_state: TrainState,
|
365
369
|
batch,
|
366
|
-
sampler: DiffusionSampler,
|
367
370
|
diffusion_steps: int,
|
368
371
|
):
|
369
372
|
# Process all conditional inputs
|
@@ -385,7 +388,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
385
388
|
model_conditioning_inputs=tuple(model_conditioning_inputs),
|
386
389
|
)
|
387
390
|
|
388
|
-
return
|
391
|
+
return generate_samples
|
389
392
|
|
390
393
|
def _get_image_size(self):
|
391
394
|
"""Helper to determine image size from available information."""
|
@@ -415,32 +418,73 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
415
418
|
"""
|
416
419
|
Run validation and log samples for both image and video diffusion.
|
417
420
|
"""
|
418
|
-
|
419
|
-
|
421
|
+
global_device_count = jax.device_count()
|
422
|
+
local_device_count = jax.local_device_count()
|
423
|
+
process_index = jax.process_index()
|
424
|
+
generate_samples = val_step_fn
|
420
425
|
|
426
|
+
val_ds = iter(val_ds()) if val_ds else None
|
427
|
+
# Evaluation step
|
421
428
|
try:
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
429
|
+
metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
|
430
|
+
for i in range(val_steps_per_epoch):
|
431
|
+
if val_ds is None:
|
432
|
+
batch = None
|
433
|
+
else:
|
434
|
+
batch = next(val_ds)
|
435
|
+
if self.distributed_training and global_device_count > 1:
|
436
|
+
batch = convert_to_global_tree(self.mesh, batch)
|
437
|
+
# Generate samples
|
438
|
+
samples = generate_samples(
|
439
|
+
val_state,
|
440
|
+
batch,
|
441
|
+
diffusion_steps,
|
442
|
+
)
|
433
443
|
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
444
|
+
if self.eval_metrics is not None:
|
445
|
+
for metric in self.eval_metrics:
|
446
|
+
try:
|
447
|
+
# Evaluate metrics
|
448
|
+
metric_val = metric.function(samples, batch)
|
449
|
+
metrics[metric.name].append(metric_val)
|
450
|
+
except Exception as e:
|
451
|
+
print("Error in evaluation metrics:", e)
|
452
|
+
import traceback
|
453
|
+
traceback.print_exc()
|
454
|
+
pass
|
455
|
+
|
456
|
+
if i == 0:
|
457
|
+
print(f"Evaluation started for process index {process_index}")
|
458
|
+
# Log samples to wandb
|
459
|
+
if getattr(self, 'wandb', None) is not None and self.wandb:
|
460
|
+
import numpy as np
|
461
|
+
|
462
|
+
# Process samples differently based on dimensionality
|
463
|
+
if len(samples.shape) == 5: # [B,T,H,W,C] - Video data
|
464
|
+
self._log_video_samples(samples, current_step)
|
465
|
+
else: # [B,H,W,C] - Image data
|
466
|
+
self._log_image_samples(samples, current_step)
|
439
467
|
|
468
|
+
if getattr(self, 'wandb', None) is not None and self.wandb:
|
469
|
+
# metrics is a dict of metrics
|
470
|
+
if metrics and type(metrics) == dict:
|
471
|
+
# Flatten the metrics
|
472
|
+
metrics = {k: np.mean(v) for k, v in metrics.items()}
|
473
|
+
# Log the metrics
|
474
|
+
for key, value in metrics.items():
|
475
|
+
if isinstance(value, jnp.ndarray):
|
476
|
+
value = np.array(value)
|
477
|
+
self.wandb.log({
|
478
|
+
f"val/{key}": value,
|
479
|
+
}, step=current_step)
|
480
|
+
|
481
|
+
except StopIteration:
|
482
|
+
print(f"Validation dataset exhausted for process index {process_index}")
|
440
483
|
except Exception as e:
|
441
|
-
print("Error
|
484
|
+
print(f"Error during validation for process index {process_index}: {e}")
|
442
485
|
import traceback
|
443
486
|
traceback.print_exc()
|
487
|
+
|
444
488
|
|
445
489
|
def _log_video_samples(self, samples, current_step):
|
446
490
|
"""Helper to log video samples to wandb."""
|
@@ -411,7 +411,9 @@ class SimpleTrainer:
|
|
411
411
|
train_ds,
|
412
412
|
train_steps_per_epoch,
|
413
413
|
current_step,
|
414
|
-
rng_state
|
414
|
+
rng_state,
|
415
|
+
save_every:int=None,
|
416
|
+
val_every=None,
|
415
417
|
):
|
416
418
|
global_device_count = jax.device_count()
|
417
419
|
process_index = jax.process_index()
|
@@ -491,8 +493,8 @@ class SimpleTrainer:
|
|
491
493
|
"train/loss": loss,
|
492
494
|
}, step=current_step)
|
493
495
|
# Save the model every few steps
|
494
|
-
if i %
|
495
|
-
print(f"Saving model after
|
496
|
+
if save_every and i % save_every == 0 and i > 0:
|
497
|
+
print(f"Saving model after {save_every} step {current_step}")
|
496
498
|
print(f"Devices: {len(jax.devices())}") # To sync the devices
|
497
499
|
self.save(current_epoch, current_step, train_state, rng_state)
|
498
500
|
print(f"Saving done by process index {process_index}")
|
@@ -518,7 +520,7 @@ class SimpleTrainer:
|
|
518
520
|
self.validation_loop(
|
519
521
|
train_state,
|
520
522
|
val_step,
|
521
|
-
data.get('
|
523
|
+
data.get('val', data.get('test', None)),
|
522
524
|
val_steps_per_epoch,
|
523
525
|
self.latest_step,
|
524
526
|
)
|
@@ -569,7 +571,7 @@ class SimpleTrainer:
|
|
569
571
|
self.validation_loop(
|
570
572
|
train_state,
|
571
573
|
val_step,
|
572
|
-
data.get('test', None),
|
574
|
+
data.get('val', data.get('test', None)),
|
573
575
|
val_steps_per_epoch,
|
574
576
|
current_step,
|
575
577
|
)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: flaxdiff
|
3
|
-
Version: 0.2.
|
3
|
+
Version: 0.2.4
|
4
4
|
Summary: A versatile and easy to understand Diffusion library
|
5
5
|
Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
|
6
6
|
License-Expression: MIT
|
@@ -22,6 +22,8 @@ Requires-Dist: python-dotenv
|
|
22
22
|
|
23
23
|
# 
|
24
24
|
|
25
|
+
**This project is being used for the UMD Course project MSML 605: MLOps**
|
26
|
+
|
25
27
|
**This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
|
26
28
|
|
27
29
|
## A Versatile and simple Diffusion Library
|
@@ -2,14 +2,14 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
|
3
3
|
flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
|
4
4
|
flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
|
5
|
-
flaxdiff/data/dataloaders.py,sha256=
|
6
|
-
flaxdiff/data/dataset_map.py,sha256=
|
5
|
+
flaxdiff/data/dataloaders.py,sha256=LV8ugqoB86yihfYeOJZHHdRZJNmZ63A2NQkdILMR9QA,23564
|
6
|
+
flaxdiff/data/dataset_map.py,sha256=_6SYnmrYO-URDd8vPAmALTV6r0eMGWWmwUtsdjKGXnA,5072
|
7
7
|
flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
|
8
8
|
flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
|
9
9
|
flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
|
10
10
|
flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
|
11
11
|
flaxdiff/data/sources/base.py,sha256=uhF0odJSYRy0SLw1xnI9Q_q_xiVht2DmEYcX1j9AWT4,4246
|
12
|
-
flaxdiff/data/sources/images.py,sha256=
|
12
|
+
flaxdiff/data/sources/images.py,sha256=P7Rea7Zu0h9l7Zoc33zEHKdLI1ST6JEqgl1-bRwORM4,11460
|
13
13
|
flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
|
14
14
|
flaxdiff/data/sources/videos.py,sha256=CVpOH6A4P2D8iv3gZIhd2GB5ATUD8Vsm_wVYbbugWD4,9359
|
15
15
|
flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
|
@@ -18,6 +18,9 @@ flaxdiff/inference/pipeline.py,sha256=oMBRjvTtlC3Yzl1FqiBHcI4V34HXGAecCg8UvQbKoO
|
|
18
18
|
flaxdiff/inference/utils.py,sha256=SRNYo-YtHzEPRpNv0fD8ZrUvnRIK941Rh4tjlsOGRgM,12278
|
19
19
|
flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
|
20
20
|
flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
|
21
|
+
flaxdiff/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
22
|
+
flaxdiff/metrics/common.py,sha256=E0MkL43dicImzNNa-RyQ3sVcrUbpeLlooIQsKTIStvo,285
|
23
|
+
flaxdiff/metrics/images.py,sha256=sIuF_Sa2VmPOKrFFoUpzhqOqNa9P7NF0njbrYi93AvE,2128
|
21
24
|
flaxdiff/metrics/inception.py,sha256=a5kjMCPMT9gB88c_HCKiek-2vsAyoE35K7nDt4h4pVI,31843
|
22
25
|
flaxdiff/metrics/psnr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
26
|
flaxdiff/metrics/ssim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -56,9 +59,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
|
|
56
59
|
flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
|
57
60
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
|
58
61
|
flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
|
59
|
-
flaxdiff/trainer/general_diffusion_trainer.py,sha256=
|
60
|
-
flaxdiff/trainer/simple_trainer.py,sha256=
|
61
|
-
flaxdiff-0.2.
|
62
|
-
flaxdiff-0.2.
|
63
|
-
flaxdiff-0.2.
|
64
|
-
flaxdiff-0.2.
|
62
|
+
flaxdiff/trainer/general_diffusion_trainer.py,sha256=9c3Ys5sN4_eTehusLjS6IKW5XPOkxoguik-6G0cyQc4,27082
|
63
|
+
flaxdiff/trainer/simple_trainer.py,sha256=raLS1shwpjJBT_bYXLAB2E4kA9MbwasDTzDTUqfCCUc,24312
|
64
|
+
flaxdiff-0.2.4.dist-info/METADATA,sha256=mqm2um1TtgjQzrGlvl7x_CCb_09hK376_dsRECe6qLQ,24057
|
65
|
+
flaxdiff-0.2.4.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
|
66
|
+
flaxdiff-0.2.4.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
67
|
+
flaxdiff-0.2.4.dist-info/RECORD,,
|
File without changes
|