classy-szfast 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -87,27 +87,62 @@ class Restore_NN(tf.keras.Model):
87
87
  print(multiline_str)
88
88
 
89
89
 
90
- # restore attributes
91
- def restore(self,
92
- filename
93
- ):
90
+
91
+ # from https://github.com/HTJense/cosmopower/blob/packaging-paper/cosmopower/cosmopower_NN.py
92
+ def restore(self, filename: str, allow_pickle: bool = False) -> None:
94
93
  r"""
95
- Load pre-trained model
94
+ Load pre-trained model.
95
+ The default file format is compressed numpy files (.npz). The
96
+ method appends this file extension to `filename` and restores
97
+ from there (i.e. it looks for `filename.npz`). If this file does
98
+ not exist, and `allow_pickle` is set to True, then the method
99
+ falls back to reading `filename.pkl` via `restore_pickle`.
100
+
101
+ Note: the extension is appended to `filename` exactly as given; an
102
+ existing extension is not stripped, so call `restore("filename")`, not `restore("filename.npz")`.
96
103
 
97
104
  Parameters:
98
- filename (str):
99
- filename tag (without suffix) where model was saved
105
+ :param filename: filename (without suffix) where model was saved.
106
+ :param allow_pickle: whether or not to permit passing this filename
107
+ to the `restore_pickle` function.
100
108
  """
101
- # load attributes
102
- with open(filename + ".pkl", 'rb') as f:
103
- self.W_, self.b_, self.alphas_, self.betas_, \
104
- self.parameters_mean_, self.parameters_std_, \
105
- self.features_mean_, self.features_std_, \
106
- self.n_parameters, self.parameters, \
107
- self.n_modes, self.modes, \
108
- self.n_hidden, self.n_layers, self.architecture = pickle.load(f)
109
-
110
-
109
+ # Check if npz file exists.
110
+ filename_npz = filename + ".npz"
111
+ if not os.path.exists(filename_npz):
112
+ # Can we load this file as a pickle file?
113
+ filename_pkl = filename + ".pkl"
114
+ if allow_pickle and os.path.exists(filename_pkl):
115
+ self.restore_pickle(filename_pkl)
116
+ return
117
+
118
+ raise IOError(f"Failed to restore network from {filename}: "
119
+ + (" is a pickle file, try setting 'allow_pickle = \
120
+ True'" if os.path.exists(filename_pkl) else
121
+ " does not exist."))
122
+
123
+ with open(filename_npz, "rb") as fp:
124
+ fpz = np.load(fp)
125
+
126
+ self.architecture = fpz["architecture"]
127
+ self.n_layers = fpz["n_layers"]
128
+ self.n_hidden = fpz["n_hidden"]
129
+ self.n_parameters = fpz["n_parameters"]
130
+ self.n_modes = fpz["n_modes"]
131
+
132
+ self.parameters = list(fpz["parameters"])
133
+ self.modes = fpz["modes"]
134
+
135
+ self.parameters_mean_ = fpz["parameters_mean"]
136
+ self.parameters_std_ = fpz["parameters_std"]
137
+ self.features_mean_ = fpz["features_mean"]
138
+ self.features_std_ = fpz["features_std"]
139
+
140
+ self.W_ = [fpz[f"W_{i}"] for i in range(self.n_layers)]
141
+ self.b_ = [fpz[f"b_{i}"] for i in range(self.n_layers)]
142
+ self.alphas_ = [
143
+ fpz[f"alphas_{i}"] for i in range(self.n_layers - 1)
144
+ ]
145
+ self.betas_ = [fpz[f"betas_{i}"] for i in range(self.n_layers - 1)]
111
146
 
112
147
  # auxiliary function to sort input parameters
113
148
  def dict_to_ordered_arr_np(self,
@@ -274,29 +309,66 @@ class Restore_PCAplusNN(tf.keras.Model):
274
309
 
275
310
 
276
311
 
277
- # restore attributes
278
- def restore(self,
279
- filename,
280
- ):
312
+ # from https://github.com/HTJense/cosmopower/blob/packaging-paper/cosmopower/cosmopower_PCAplusNN.py
313
+ def restore(self, filename: str, allow_pickle: bool = False) -> None:
281
314
  r"""
282
- Load pre-trained model
315
+ Load pre-trained model.
316
+ The default file format is compressed numpy files (.npz). The
317
+ method appends this file extension to `filename` and restores
318
+ from there (i.e. it looks for `filename.npz`). If this file does
319
+ not exist, and `allow_pickle` is set to True, then the method
320
+ falls back to reading `filename.pkl` via `restore_pickle`.
321
+
322
+ Note: the extension is appended to `filename` exactly as given; an
323
+ existing extension is not stripped, so call `restore("filename")`, not `restore("filename.npz")`.
283
324
 
284
325
  Parameters:
285
- filename (str):
286
- filename tag (without suffix) where model was saved
326
+ :param filename: filename (without suffix) where model was saved.
327
+ :param allow_pickle: whether or not to permit passing this filename to
328
+ the `restore_pickle` function.
287
329
  """
288
- # load attributes
289
- f = open(filename + ".pkl", 'rb')
290
- self.W_, self.b_, self.alphas_, self.betas_, \
291
- self.parameters_mean_, self.parameters_std_, \
292
- self.pca_mean_, self.pca_std_, \
293
- self.features_mean_, self.features_std_, \
294
- self.parameters, self.n_parameters, \
295
- self.modes, self.n_modes, \
296
- self.n_pcas, self.pca_transform_matrix_, \
297
- self.n_hidden, self.n_layers, self.architecture = pickle.load(f)
298
- f.close()
299
-
330
+ # Check if npz file exists.
331
+ filename_npz = filename + ".npz"
332
+ if not os.path.exists(filename_npz):
333
+ # Can we load this file as a pickle file?
334
+ filename_pkl = filename + ".pkl"
335
+ if allow_pickle and os.path.exists(filename_pkl):
336
+ self.restore_pickle(filename_pkl)
337
+ return
338
+
339
+ raise IOError(f"Failed to restore network from {filename}: "
340
+ + (" is a pickle file, try setting 'allow_pickle = \
341
+ True'" if os.path.exists(filename_pkl) else
342
+ " does not exist."))
343
+
344
+ with open(filename_npz, "rb") as fp:
345
+ fpz = np.load(fp)
346
+
347
+ self.architecture = fpz["architecture"]
348
+ self.n_layers = fpz["n_layers"]
349
+ self.n_hidden = fpz["n_hidden"]
350
+ self.n_parameters = fpz["n_parameters"]
351
+ self.n_modes = fpz["n_modes"]
352
+
353
+ self.parameters = fpz["parameters"]
354
+ self.modes = fpz["modes"]
355
+
356
+ self.parameters_mean_ = fpz["parameters_mean"]
357
+ self.parameters_std_ = fpz["parameters_std"]
358
+ self.features_mean_ = fpz["features_mean"]
359
+ self.features_std_ = fpz["features_std"]
360
+
361
+ self.pca_mean_ = fpz["pca_mean"]
362
+ self.pca_std_ = fpz["pca_std"]
363
+ self.n_pcas = fpz["n_pcas"]
364
+ self.pca_transform_matrix_ = fpz["pca_transform_matrix"]
365
+
366
+ self.W_ = [fpz[f"W_{i}"] for i in range(self.n_layers)]
367
+ self.b_ = [fpz[f"b_{i}"] for i in range(self.n_layers)]
368
+ self.alphas_ = [
369
+ fpz[f"alphas_{i}"] for i in range(self.n_layers - 1)
370
+ ]
371
+ self.betas_ = [fpz[f"betas_{i}"] for i in range(self.n_layers - 1)]
300
372
 
301
373
 
302
374
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: classy_szfast
3
- Version: 0.0.14
3
+ Version: 0.0.15
4
4
  Summary: The accelerator of the class_sz code from https://github.com/CLASS-SZ
5
5
  Maintainer-email: Boris Bolliet <bb667@cam.ac.uk>
6
6
  Project-URL: Homepage, https://github.com/CLASS-SZ
@@ -5,14 +5,14 @@ classy_szfast/config.py,sha256=4CvejtLcFOQR30bJ8tlEeBHhu3Rr7LakeLO6dbFgPSU,210
5
5
  classy_szfast/cosmopower.py,sha256=eym72TFAcSJSTUlrwD-sAg8_9e2GdZq0m3lLPQ7uvPU,9858
6
6
  classy_szfast/cosmosis_classy_szfast_interface.py,sha256=zAnxvFtn73a5yS7jgs59zpWFEYKCIQyraYPs5hQ4Le8,11483
7
7
  classy_szfast/pks_and_sigmas.py,sha256=drtuujE1HhlrYY1hY92DyY5lXlYS1uE15MSuVI4uo6k,6625
8
- classy_szfast/restore_nn.py,sha256=OyxaRRk9D4hOJTvUSY3c5wAWTPCZJRMxBtin4kq_xd0,14149
8
+ classy_szfast/restore_nn.py,sha256=tmR6qPLvf9JzEwUECeDeF8pbbmvoOGKuQDPSbC4kDu0,18010
9
9
  classy_szfast/suppress_warnings.py,sha256=6wIBml2Sj9DyRGZlZWhuA9hqvpxqrNyYjuz6BPK_a6E,202
10
10
  classy_szfast/utils.py,sha256=VdaRsJK2ttHI9zkyxVhergxHPC6t99usrlycblyqcP8,1464
11
11
  classy_szfast/custom_bias/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  classy_szfast/custom_bias/custom_bias.py,sha256=aR2t5RTIwv7P0m2bsEU0Eq6BTkj4pG10AebH6QpG4qM,486
13
13
  classy_szfast/custom_profiles/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  classy_szfast/custom_profiles/custom_profiles.py,sha256=4LZwb2XoqwCyWNmW2s24Z7AJdmgVdaRG7yYaBYe-d9Q,1188
15
- classy_szfast-0.0.14.dist-info/METADATA,sha256=Y_uk64BUO3RTDgeBFbhe2m7ALSfkC5uAyNrhLbzPY80,472
16
- classy_szfast-0.0.14.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
17
- classy_szfast-0.0.14.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
18
- classy_szfast-0.0.14.dist-info/RECORD,,
15
+ classy_szfast-0.0.15.dist-info/METADATA,sha256=OGZWiFxqf0ZY1yVr_VsMH-E5Fgs49R2ldrryTWlpGl0,472
16
+ classy_szfast-0.0.15.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
17
+ classy_szfast-0.0.15.dist-info/top_level.txt,sha256=hRgqpilUck4lx2KkaWI2y9aCDKqF6pFfGHfNaoPFxv0,14
18
+ classy_szfast-0.0.15.dist-info/RECORD,,