flaxdiff 0.2.6.1__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -113,12 +113,14 @@ class ImageTFDSSource(DataSource):
113
113
  class ImageTFDSAugmenter(DataAugmenter):
114
114
  """Augmenter for TFDS image datasets."""
115
115
 
116
- def __init__(self, label_path: str = "/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt"):
116
+ def __init__(self, label_path: str = None):
117
117
  """Initialize a TFDS image augmenter.
118
118
 
119
119
  Args:
120
120
  label_path: Path to the labels file for datasets like Oxford Flowers.
121
121
  """
122
+ if label_path is None:
123
+ label_path = os.path.join(os.path.expanduser("~"), "tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt")
122
124
  self.label_path = label_path
123
125
 
124
126
  def create_transform(self, image_scale: int = 256, method: Any = None) -> Callable[[], pygrain.MapTransform]:
@@ -25,6 +25,7 @@ from flaxdiff.inference.utils import parse_config, load_from_wandb_run, load_fro
25
25
  @dataclass
26
26
  class InferencePipeline:
27
27
  """Inference pipeline for a general model."""
28
+ name: str = None
28
29
  model: nn.Module = None
29
30
  state: SimpleTrainState = None
30
31
  best_state: SimpleTrainState = None
@@ -44,6 +45,7 @@ class DiffusionInferencePipeline(InferencePipeline):
44
45
  This pipeline handles loading models from wandb and generating samples using the
45
46
  DiffusionSampler from FlaxDiff.
46
47
  """
48
+ artifact: Any = None
47
49
  state: TrainState = None
48
50
  best_state: TrainState = None
49
51
  rngstate: Optional[RandomMarkovState] = None
@@ -51,7 +53,6 @@ class DiffusionInferencePipeline(InferencePipeline):
51
53
  model_output_transform: DiffusionPredictionTransform = None
52
54
  autoencoder: AutoEncoder = None
53
55
  input_config: DiffusionInputConfig = None
54
- wandb_run = None
55
56
  samplers: Dict[Type[DiffusionSampler], Dict[float, DiffusionSampler]] = field(default_factory=dict)
56
57
  config: Dict[str, Any] = field(default_factory=dict)
57
58
 
@@ -76,7 +77,7 @@ class DiffusionInferencePipeline(InferencePipeline):
76
77
  Returns:
77
78
  DiffusionInferencePipeline instance
78
79
  """
79
- states, config, run = load_from_wandb_run(
80
+ states, config, run, artifact = load_from_wandb_run(
80
81
  wandb_run,
81
82
  project=project,
82
83
  entity=entity,
@@ -95,6 +96,7 @@ class DiffusionInferencePipeline(InferencePipeline):
95
96
  best_state=best_state,
96
97
  rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
97
98
  run=run,
99
+ artifact=artifact,
98
100
  )
99
101
  return pipeline
100
102
 
@@ -119,7 +121,7 @@ class DiffusionInferencePipeline(InferencePipeline):
119
121
  Returns:
120
122
  DiffusionInferencePipeline instance
121
123
  """
122
- states, config, run = load_from_wandb_registry(
124
+ states, config, run, artifact = load_from_wandb_registry(
123
125
  modelname=modelname,
124
126
  project=project,
125
127
  entity=entity,
@@ -140,6 +142,7 @@ class DiffusionInferencePipeline(InferencePipeline):
140
142
  best_state=best_state,
141
143
  rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
142
144
  run=run,
145
+ artifact=artifact,
143
146
  )
144
147
  return pipeline
145
148
 
@@ -151,11 +154,14 @@ class DiffusionInferencePipeline(InferencePipeline):
151
154
  best_state: Optional[Dict[str, Any]] = None,
152
155
  rngstate: Optional[RandomMarkovState] = None,
153
156
  run=None,
157
+ artifact=None,
154
158
  ):
155
159
  if rngstate is None:
156
160
  rngstate = RandomMarkovState(jax.random.PRNGKey(42))
157
161
  # Build and return pipeline
158
162
  return cls(
163
+ name=run.name if run else None,
164
+ artifact=artifact,
159
165
  model=config['model'],
160
166
  state=state,
161
167
  best_state=best_state,
@@ -165,7 +171,6 @@ class DiffusionInferencePipeline(InferencePipeline):
165
171
  autoencoder=config['autoencoder'],
166
172
  input_config=config['input_config'],
167
173
  config=config,
168
- wandb_run=run,
169
174
  )
170
175
 
171
176
  def get_sampler(
@@ -292,7 +292,7 @@ def load_from_wandb_run(
292
292
  config = run.config
293
293
  except Exception as e:
294
294
  print(f"Warning: Failed to load model from wandb: {e}")
295
- return states, config, run
295
+ return states, config, run, artifact
296
296
 
297
297
  def load_from_wandb_registry(
298
298
  modelname: str,
@@ -318,4 +318,4 @@ def load_from_wandb_registry(
318
318
  config = run.config
319
319
  except Exception as e:
320
320
  print(f"Warning: Failed to load model from wandb: {e}")
321
- return states, config, run
321
+ return states, config, run, artifact
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.6.1
3
+ Version: 0.2.7
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -9,13 +9,13 @@ flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTS
9
9
  flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
10
10
  flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
11
11
  flaxdiff/data/sources/base.py,sha256=uhF0odJSYRy0SLw1xnI9Q_q_xiVht2DmEYcX1j9AWT4,4246
12
- flaxdiff/data/sources/images.py,sha256=PSAJaECkXGeO0HWBveuW29AyNxAEBIb2wkNeyZNVJVE,11716
12
+ flaxdiff/data/sources/images.py,sha256=RFLtKW1xzw6ZPVXtCMmnTg1MPb8dc7rP77rZWbK7qpo,11796
13
13
  flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
14
14
  flaxdiff/data/sources/videos.py,sha256=CVpOH6A4P2D8iv3gZIhd2GB5ATUD8Vsm_wVYbbugWD4,9359
15
15
  flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
16
16
  flaxdiff/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- flaxdiff/inference/pipeline.py,sha256=2eQEeQYuoWL6_d_AdTmAkvtz_poDmwPr9fihOnCt43Y,9045
18
- flaxdiff/inference/utils.py,sha256=0QEkPjFdqTmMsRaSUHgI9GV8gC2uCo1z40PvsnNYWaw,12303
17
+ flaxdiff/inference/pipeline.py,sha256=8S30FAlXEjvrDd87H-qdD6biySQZ3cJUflU8gdmPxig,9223
18
+ flaxdiff/inference/utils.py,sha256=MVnWl0LnC-1ILk0SsLd1YFu6igaQFR7mGhzo0jE797E,12323
19
19
  flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
20
20
  flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
21
21
  flaxdiff/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -62,7 +62,7 @@ flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9O
62
62
  flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
63
63
  flaxdiff/trainer/general_diffusion_trainer.py,sha256=BeDpJzgR8bUClJI4epQXlAul27MwiSfRW0lIBZSiPWk,28342
64
64
  flaxdiff/trainer/simple_trainer.py,sha256=raLS1shwpjJBT_bYXLAB2E4kA9MbwasDTzDTUqfCCUc,24312
65
- flaxdiff-0.2.6.1.dist-info/METADATA,sha256=zNcAo6k99Wb9N5gG8h685M1_1vC5Q_XEJ9lZpfkuBvM,24059
66
- flaxdiff-0.2.6.1.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
67
- flaxdiff-0.2.6.1.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
68
- flaxdiff-0.2.6.1.dist-info/RECORD,,
65
+ flaxdiff-0.2.7.dist-info/METADATA,sha256=nwglJYeF2lH_MNq5PeFLR8TSPU-I9tzJUcBbTaLYxRM,24057
66
+ flaxdiff-0.2.7.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
67
+ flaxdiff-0.2.7.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
68
+ flaxdiff-0.2.7.dist-info/RECORD,,