diffcrysgen 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffcrysgen/__init__.py +0 -0
- diffcrysgen/analyze_generated_structures.py +88 -0
- diffcrysgen/invert_pcr.py +163 -0
- diffcrysgen/model.py +316 -0
- diffcrysgen/sampler.py +116 -0
- diffcrysgen/trainer.py +128 -0
- diffcrysgen/utils.py +38 -0
- diffcrysgen-0.1.0.dist-info/METADATA +103 -0
- diffcrysgen-0.1.0.dist-info/RECORD +12 -0
- diffcrysgen-0.1.0.dist-info/WHEEL +5 -0
- diffcrysgen-0.1.0.dist-info/licenses/LICENSE +21 -0
- diffcrysgen-0.1.0.dist-info/top_level.txt +1 -0
diffcrysgen/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# diffcrysgen/analyze_generated_structures.py
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import spglib
|
|
6
|
+
from diffcrysgen.invert_pcr import InvertPCR
|
|
7
|
+
from ase import Atoms
|
|
8
|
+
import ase.io
|
|
9
|
+
import os
|
|
10
|
+
|
|
11
|
+
def analyze_ircr_data(generated_ircr_path="generated_ircr.npy", output_csv="generated_material_data.csv", min_valid_distance=0.5, save_cifs=False, cif_dir="cif-files"):
    """
    Analyzes generated IRCR data and extracts crystallographic info.

    Args:
        generated_ircr_path (str): Path to the generated IRCR .npy file
        output_csv (str): Path to save the resulting CSV
        min_valid_distance (float): Distance threshold for validity check in Angstrom
        save_cifs (bool): If True, saves CIF files into subfolders
        cif_dir (str): Base directory to store valid/invalid CIFs

    Returns:
        pandas.DataFrame: one row per material that could be reconstructed;
        materials whose reconstruction or symmetry analysis fails are skipped.
    """
    data = np.load(generated_ircr_path)
    Nmat = data.shape[0]
    print(f"Loaded {Nmat} materials from {generated_ircr_path}")

    results = {
        "ID": [], "Formula": [], "type": [], "Natoms": [], "min-pairwise-distance": [],
        "Zmax": [], "SPG-symbol": [], "SPG-number": [], "validity": [],
        "a": [], "b": [], "c": [], "alpha": [], "beta": [], "gamma": []
    }

    if save_cifs:
        os.makedirs(os.path.join(cif_dir, "valid"), exist_ok=True)
        os.makedirs(os.path.join(cif_dir, "invalid"), exist_ok=True)

    for i in range(Nmat):
        mat = data[i, :]
        # 94 = highest atomic number encoded in the IRCR element block (Pu),
        # 20 = number of candidate atomic sites.
        ipcr = InvertPCR(mat, 94, 20)

        try:
            atoms = ipcr.get_atoms_object()
            formula = atoms.get_chemical_formula()
            atomic_numbers = atoms.get_atomic_numbers()
            unique_elements = len(set(atomic_numbers))

            mat_type = {1: "elemental", 2: "binary", 3: "ternary"}.get(unique_elements, "complex")
            a, b, c = atoms.cell.lengths()
            alpha, beta, gamma = atoms.cell.angles()

            structure = (atoms.get_cell(), atoms.get_scaled_positions(), atomic_numbers)
            # spglib returns e.g. "Fm-3m (225)"; rsplit keeps the symbol
            # intact even if it ever contained a space.
            spg_symbol, spg_number = spglib.get_spacegroup(structure, symprec=0.1).rsplit(" ", 1)
            spg_number = int(spg_number.strip("()"))

            distances = ipcr.get_distances()
            # A single-atom cell has no pairwise distances; assign a distance
            # well above any sensible threshold so it counts as valid.
            min_dist = min(distances) if distances else 3.0
            is_valid = min_dist >= min_valid_distance
            mat_id = f"smps-{i+1}"

            if save_cifs:
                cif_path = os.path.join(cif_dir, "valid" if is_valid else "invalid", f"{mat_id}.cif")
                ase.io.write(cif_path, atoms)

            # Append all columns together so the table stays rectangular even
            # when an earlier material failed and was skipped.
            results["ID"].append(mat_id)
            results["Formula"].append(formula)
            results["type"].append(mat_type)
            results["Natoms"].append(len(atomic_numbers))
            results["min-pairwise-distance"].append(min_dist)
            results["Zmax"].append(max(atomic_numbers))
            results["SPG-symbol"].append(spg_symbol)
            results["SPG-number"].append(spg_number)
            results["validity"].append("valid" if is_valid else "invalid")
            results["a"].append(a)
            results["b"].append(b)
            results["c"].append(c)
            results["alpha"].append(alpha)
            results["beta"].append(beta)
            results["gamma"].append(gamma)

        except Exception as e:
            # Report the actual failure instead of a fixed "spglib failed"
            # message — reconstruction itself can also raise. Material is skipped.
            print(f"[Material {i}] analysis failed: {e}")

    df = pd.DataFrame(results)
    df.to_csv(output_csv, index=False)
    print(f"Saved extracted data to: {output_csv}")
    return df
|
|
88
|
+
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
|
|
2
|
+
from ase.spacegroup import crystal
|
|
3
|
+
import numpy as np
|
|
4
|
+
import ase.io
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Dictionary containing atomic number with respective element
|
|
8
|
+
|
|
9
|
+
# Chemical symbols for atomic numbers 1 (H) through 94 (Pu), in order.
_ELEMENT_SYMBOLS = (
    "H He Li Be B C N O F Ne Na Mg Al Si P S Cl Ar K Ca "
    "Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr Rb Sr Y Zr "
    "Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe Cs Ba La Ce Pr Nd "
    "Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu Hf Ta W Re Os Ir Pt Au Hg "
    "Tl Pb Bi Po At Rn Fr Ra Ac Th Pa U Np Pu"
).split()

# Dictionary mapping atomic number (as a string key) to its element symbol.
atomic_name = {str(z): symbol for z, symbol in enumerate(_ELEMENT_SYMBOLS, start=1)}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class InvertPCR :

    """Recover crystal-structure information from a point-cloud representation
    (PCR) so a CIF file / ASE Atoms object can be constructed.

    The PCR is a 2-D array (rows, 3) laid out row-wise as:
        rows [0, z_max)                            element one-hot block
        rows [z_max, z_max+2)                      lattice parameters a,b,c,alpha,beta,gamma
        rows [z_max+2, z_max+2+n_sites)            fractional basis coordinates
        rows [z_max+2+n_sites, z_max+2+2*n_sites)  site-occupancy block
        next 8 rows                                property block (unused here)
    """

    def __init__(self, PCR, z_max, n_sites):
        """
        Args:
            PCR (np.ndarray): point-cloud representation, shape (rows, 3).
            z_max (int): highest atomic number encoded in the element block.
            n_sites (int): number of candidate atomic sites.
        """
        self.PCR = PCR
        self.z_max = z_max
        self.n_sites = n_sites

    def get_element_matrix(self):
        """Return the element block with sub-threshold (< 0.5) entries zeroed.

        Works on a copy so the caller's PCR array is never mutated (the
        previous implementation zeroed the input array in place).
        """
        ele_mat = self.PCR[0:self.z_max, :].copy()
        ele_mat[ele_mat < 0.5] = 0
        return ele_mat

    def get_lattice_matrix(self):
        """Return the 2-row lattice-parameter block."""
        return self.PCR[self.z_max:self.z_max + 2, :]

    def get_lattice_parameters(self):
        """Return [a, b, c, alpha, beta, gamma] as a flat list of floats."""
        lattice_matrix = self.get_lattice_matrix()
        lattice_parameters = list(lattice_matrix.flatten())
        a = lattice_parameters[0]
        b = lattice_parameters[1]
        c = lattice_parameters[2]
        alpha = lattice_parameters[3]
        beta = lattice_parameters[4]
        gamma = lattice_parameters[5]
        scaled_par = [a, b, c, alpha, beta, gamma]
        return scaled_par

    def get_basis_matrix(self):
        """Return the fractional-coordinate block (n_sites rows)."""
        return self.PCR[self.z_max + 2:self.z_max + 2 + self.n_sites, :]

    def get_site_matrix(self):
        """Return the site-occupancy block, thresholded at 0.5.

        Works on a copy so the caller's PCR array is never mutated.
        """
        site_mat = self.PCR[self.z_max + 2 + self.n_sites:self.z_max + 2 + 2 * self.n_sites, :].copy()
        site_mat[site_mat < 0.5] = 0
        return site_mat

    def get_property_matrix(self):
        """Return the 8-row property block (not used for reconstruction)."""
        return self.PCR[self.z_max + 2 + 2 * self.n_sites:self.z_max + 2 + 2 * self.n_sites + 8, :]

    def get_unique_atomic_numbers(self):
        """Return one atomic number per element-matrix column (up to three
        distinct species, i.e. up to ternary compounds); 0 marks an unused
        column. The row index of the column maximum encodes Z - 1."""
        element_matrix = self.get_element_matrix()
        z_unique = []
        for i in range(3):
            col = list(element_matrix[:, i])
            max_value = max(col)
            if max_value == 0:
                z_unique.append(0)
            else:
                z_unique.append(col.index(max_value) + 1)
        return z_unique

    def get_unique_elements(self):
        """Return the chemical symbols for get_unique_atomic_numbers();
        unused columns map to '-'."""
        z_unique = self.get_unique_atomic_numbers()
        unique_element = []
        for z in z_unique:
            if z == 0:
                unique_element.append("-")
            else:
                unique_element.append(atomic_name[str(z)])
        return unique_element

    def get_atomic_numbers(self):
        """Return one atomic number per candidate site (0 = empty site).
        Each site row points, via its arg-max column, to one of the unique
        atomic numbers."""
        site_matrix = self.get_site_matrix()
        z_unique = self.get_unique_atomic_numbers()
        z_total = []
        for i in range(self.n_sites):
            row = list(site_matrix[i, :])
            val = max(row)
            if val == 0:
                z_total.append(0)
            else:
                z_total.append(z_unique[row.index(val)])
        return z_total

    def get_elements_basis(self):
        """Return (symbols, basis) for occupied sites only, preserving
        site order."""
        basis = self.get_basis_matrix()
        z_total = self.get_atomic_numbers()
        final_atoms_index = []
        final_elements = []
        for i in range(len(z_total)):
            z = z_total[i]
            if z != 0:
                final_atoms_index.append(i)
                final_elements.append(atomic_name[str(z)])
        final_basis = basis[final_atoms_index, :]

        return final_elements, final_basis

    def get_formula(self):
        """Build a chemical formula string by run-length encoding the
        occupied-site symbol sequence (e.g. ['Fe','Fe','O'] -> 'Fe2O1').

        Raises:
            ValueError: if no site is occupied.
        """
        a, _ = self.get_elements_basis()
        if not a:
            raise ValueError("PCR decodes to no occupied sites; cannot build a formula")
        elements = [a[0]]
        numbers = []

        # Record the index where each new run of symbols starts.
        for i in range(len(a)):
            ele = a[i]
            if ele != elements[len(elements) - 1]:
                elements.append(ele)
                numbers.append(i)

        numbers.append(len(a))

        # Convert run boundaries into run lengths.
        final_numbers = []
        final_numbers.append(numbers[0])
        for i in range(1, len(numbers)):
            final_numbers.append(numbers[i] - numbers[i - 1])

        formula = elements[0] + str(final_numbers[0])
        for i in range(1, len(elements)):
            formula += elements[i] + str(final_numbers[i])

        return formula

    def get_atoms_object(self):
        """Build an ASE Atoms object from the decoded formula, basis and
        lattice parameters (via ase.spacegroup.crystal, default P1)."""
        symbol = self.get_formula()
        _, basis = self.get_elements_basis()
        lattice = self.get_lattice_parameters()
        atoms = crystal(symbols=symbol,
                        basis=basis,
                        cellpar=lattice)
        return atoms

    def get_distances(self):
        """Return all pairwise Cartesian distances (rounded to 4 decimals)
        between atoms in the reconstructed cell. Periodic images are not
        considered."""
        atoms = self.get_atoms_object()
        positions = atoms.get_positions()
        nele = positions.shape[0]
        distances = []
        for i in range(nele):
            for j in range(i + 1, nele):
                vec = positions[i, :] - positions[j, :]
                distances.append(np.around(np.linalg.norm(vec), 4))
        return distances
|
|
163
|
+
|
diffcrysgen/model.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
# Importing necessary libraries
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Sinusoidal Positional Embedding
|
|
12
|
+
|
|
13
|
+
class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal positional embedding for scalar noise labels.

    Maps each scalar in a 1-D batch to an `emb_dim`-dimensional vector whose
    even columns are sines and odd columns are cosines of the scaled input.
    """

    def __init__(self, emb_dim, max_positions=10000):
        """
        Args:
            emb_dim (int): embedding dimension (output width is 2*(emb_dim//2)).
            max_positions (int): base of the geometric frequency ladder.
        """
        super(PositionalEmbedding, self).__init__()
        self.emb_dim = emb_dim
        self.max_positions = max_positions

    def forward(self, x):
        """
        Args:
            x: 1-D tensor of noise levels, one per batch element.

        Returns:
            Tensor of shape (len(x), 2*(emb_dim//2)) with interleaved sin/cos.
        """
        half = self.emb_dim // 2
        # Geometric frequencies: (1/max_positions) ** (k/half), k = 0..half-1.
        freqs = torch.arange(0, half, dtype=torch.float32, device=x.device)
        freqs = (1 / self.max_positions) ** (freqs / half)

        # Outer product: (batch,) x (half,) -> (batch, half).
        angles = torch.tensordot(x, freqs, dims=0)
        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)
        # Interleave so even output columns are sin and odd columns are cos.
        stacked = torch.cat([sin_part.unsqueeze(2), cos_part.unsqueeze(2)], dim=2)
        return torch.reshape(stacked, (angles.shape[0], -1))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Module for Preconditioning
|
|
53
|
+
# Adapted from "Elucidating the Design Space of Diffusion-Based
|
|
54
|
+
# Generative Models" by Karras et al.
|
|
55
|
+
|
|
56
|
+
class Precond(nn.Module):
    """EDM preconditioning wrapper, adapted from "Elucidating the Design
    Space of Diffusion-Based Generative Models" (Karras et al.).

    Scales the raw network's input and output so that the resulting
    denoiser D(x; sigma) is well-conditioned across all noise levels.
    """

    def __init__(self,
                 denoise_fn,
                 sigma_min=0,                 # minimum supported noise level
                 sigma_max=float("inf"),      # maximum supported noise level
                 sigma_data=0.5,              # expected std-dev of training data
                 ):
        super().__init__()

        self.denoise_fn_F = denoise_fn
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data

    def forward(self, x, sigma):
        """Return the preconditioned denoiser output D(x; sigma)."""
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1)
        dtype = torch.float32

        # EDM preconditioning coefficients (eq. 7 in the paper).
        variance_sum = sigma ** 2 + self.sigma_data ** 2
        c_skip = self.sigma_data ** 2 / variance_sum
        c_out = sigma * self.sigma_data / variance_sum.sqrt()
        c_in = 1 / variance_sum.sqrt()
        c_noise = sigma.log() / 4

        raw_out = self.denoise_fn_F((c_in * x).to(dtype), c_noise.flatten())

        assert raw_out.dtype == dtype
        # Skip connection mixes the noisy input back in.
        return c_skip * x + c_out * raw_out.to(torch.float32)

    def round_sigma(self, sigma):
        """Identity conversion of sigma to a tensor (hook for quantization)."""
        return torch.as_tensor(sigma)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# EDM Loss
|
|
92
|
+
|
|
93
|
+
def EDMLoss(denoise_fn, data):
|
|
94
|
+
P_mean = -1.2
|
|
95
|
+
P_std = 1.2
|
|
96
|
+
sigma_data = 0.5
|
|
97
|
+
rnd_normal = torch.randn(data.shape[0], device=data.device)
|
|
98
|
+
sigma = (rnd_normal * P_std + P_mean).exp()
|
|
99
|
+
weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2
|
|
100
|
+
y = data
|
|
101
|
+
n = torch.randn_like(y) * sigma.unsqueeze(1).unsqueeze(1)
|
|
102
|
+
D_yn = denoise_fn(y + n, sigma)
|
|
103
|
+
target = y
|
|
104
|
+
loss = weight.unsqueeze(1).unsqueeze(1) * ((D_yn - target) ** 2)
|
|
105
|
+
return loss
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# Final Model
|
|
109
|
+
|
|
110
|
+
class Model(nn.Module):
    """Top-level diffusion model: wraps a denoiser in EDM preconditioning
    and exposes the mean training loss as the forward pass."""

    def __init__(self, denoise_fn, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
        super().__init__()
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        # Preconditioned denoiser D(x; sigma).
        self.denoise_fn_D = Precond(denoise_fn)

    def forward(self, x):
        """Return the scalar mean EDM loss for batch x."""
        per_element = EDMLoss(self.denoise_fn_D, x)
        return per_element.mean(-1).mean()
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# time-encoding layer
|
|
124
|
+
|
|
125
|
+
def make_te(dim_in, dim_out):
    """Build a small time-encoding MLP: Linear -> SiLU -> Linear."""
    layers = [
        nn.Linear(dim_in, dim_out),
        nn.SiLU(),
        nn.Linear(dim_out, dim_out),
    ]
    return nn.Sequential(*layers)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
#=================================== UNet=================================================
|
|
131
|
+
|
|
132
|
+
def weight_standardization(weight: torch.Tensor, eps: float):
|
|
133
|
+
c_out, c_in, kernel_size = weight.shape
|
|
134
|
+
weight = weight.view(c_out, -1)
|
|
135
|
+
var, mean = torch.var_mean(weight, dim=1, keepdim=True)
|
|
136
|
+
# Standardize weights
|
|
137
|
+
weight = (weight - mean) / (torch.sqrt(var + eps))
|
|
138
|
+
return weight.view(c_out, c_in, kernel_size)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class WSConv1d(torch.nn.Conv1d):
|
|
142
|
+
def __init__(self, in_channels, out_channels, kernel_size,
|
|
143
|
+
stride=1, padding=0, dilation=1, groups: int = 1, bias: bool = True,
|
|
144
|
+
padding_mode: str = 'zeros', eps: float = 1e-5):
|
|
145
|
+
super(WSConv1d, self).__init__(in_channels, out_channels, kernel_size,
|
|
146
|
+
stride=stride, padding=padding, dilation=dilation,
|
|
147
|
+
groups=groups, bias=bias, padding_mode=padding_mode)
|
|
148
|
+
self.eps = eps
|
|
149
|
+
|
|
150
|
+
def forward(self, x: torch.Tensor):
|
|
151
|
+
# Apply weight standardization before convolution
|
|
152
|
+
standardized_weight = weight_standardization(self.weight, self.eps)
|
|
153
|
+
return F.conv1d(x, standardized_weight, self.bias, self.stride,
|
|
154
|
+
self.padding, self.dilation, self.groups)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class ResNetBlock(nn.Module):
    """Residual block: two weight-standardized Conv1d + GroupNorm layers
    with SiLU activations and a (possibly projected) shortcut connection."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1, eps=1e-5):
        super(ResNetBlock, self).__init__()

        # conv-norm pair #1
        self.conv1 = WSConv1d(in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size//2, groups=groups, eps=eps)
        self.norm1 = nn.GroupNorm(num_groups=1, num_channels=out_channels)

        # conv-norm pair #2
        self.conv2 = WSConv1d(out_channels, out_channels, kernel_size, stride=stride, padding=kernel_size//2, groups=groups, eps=eps)
        self.norm2 = nn.GroupNorm(num_groups=1, num_channels=out_channels)

        # 1x1 projection so the residual matches when channel counts differ.
        self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) if in_channels != out_channels else nn.Identity()

        self.activation = nn.SiLU()

    def forward(self, x):
        """conv-norm-SiLU, conv-norm, add shortcut, final SiLU."""
        residual = self.shortcut(x)
        hidden = self.activation(self.norm1(self.conv1(x)))
        hidden = self.norm2(self.conv2(hidden))
        # SiLU applied after the residual addition.
        return self.activation(hidden + residual)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
# attention block
|
|
199
|
+
# we apply attention block at a resolution level of 17
|
|
200
|
+
|
|
201
|
+
class AttentionBlock(nn.Module):
|
|
202
|
+
"""
|
|
203
|
+
MHA expects input of shape (batch_size, seq_length, embedding_dim).
|
|
204
|
+
Our data shape (batch_size, num_channels, resolution), e.g. (64,128,17).
|
|
205
|
+
embedding_dim == resolution
|
|
206
|
+
seq_length = num_channels
|
|
207
|
+
"""
|
|
208
|
+
def __init__(self, emb_dim, num_heads=1, dropout=0.1):
|
|
209
|
+
super(AttentionBlock, self).__init__()
|
|
210
|
+
|
|
211
|
+
self.attention = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=num_heads, dropout=dropout)
|
|
212
|
+
|
|
213
|
+
# Layer normalization
|
|
214
|
+
self.layernorm1 = nn.LayerNorm(emb_dim)
|
|
215
|
+
self.layernorm2 = nn.LayerNorm(emb_dim)
|
|
216
|
+
|
|
217
|
+
# Feed-forward network (after attention)
|
|
218
|
+
self.ffn = nn.Sequential(
|
|
219
|
+
nn.Linear(emb_dim, emb_dim * 4),
|
|
220
|
+
nn.ReLU(),
|
|
221
|
+
nn.Linear(emb_dim * 4, emb_dim)
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
def forward(self, x):
|
|
225
|
+
|
|
226
|
+
# Apply self-attention
|
|
227
|
+
attn_output, attn_weights = self.attention(x, x, x) # Query, Key, Value all are x (self-attention)
|
|
228
|
+
|
|
229
|
+
# Residual connection followed by layer normalization : add & norm
|
|
230
|
+
x = self.layernorm1(x + attn_output)
|
|
231
|
+
|
|
232
|
+
# Feed-forward network
|
|
233
|
+
ffn_output = self.ffn(x)
|
|
234
|
+
|
|
235
|
+
# Residual connection followed by layer normalization : add & norm
|
|
236
|
+
x = self.layernorm2(x + ffn_output)
|
|
237
|
+
|
|
238
|
+
return x
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class UNet(nn.Module):
    """1-D U-Net denoiser operating on (batch, 3, 136) IRCR tensors.

    Three down-sampling stages (136 -> 68 -> 34 -> 17), a bottleneck with
    self-attention at resolution 17, and three mirrored up-sampling stages
    with encoder skip connections. The noise level is injected per stage
    through small time-encoding MLPs (te*), broadcast over the length axis.

    NOTE(review): in_c and out_c are accepted but never used — the channel
    counts below are hard-coded to 3; confirm before relying on them.
    """
    def __init__(self, in_c, out_c, time_emb_dim):
        super(UNet, self).__init__()

        # map noise labels to a sinusoidal embedding of width time_emb_dim
        self.map_noise = PositionalEmbedding(emb_dim=time_emb_dim)

        # encoder path
        self.te1 = make_te(time_emb_dim, 3)
        self.b1 = nn.Sequential(ResNetBlock(3, 32),ResNetBlock(32, 32),ResNetBlock(32,32)) # shape = [batch,32,136]
        self.down1 = nn.Conv1d(32, 32, 4, 2, 1) # shape = [batch,32,68]

        self.te2 = make_te(time_emb_dim, 32)
        self.b2 = nn.Sequential(ResNetBlock(32, 64),ResNetBlock(64, 64),ResNetBlock(64,64)) # shape = [batch,64,68]
        self.down2 = nn.Conv1d(64, 64, 4, 2, 1) # shape = [batch,64,34]

        self.te3 = make_te(time_emb_dim, 64)
        self.b3 = nn.Sequential(ResNetBlock(64, 128),ResNetBlock(128, 128),ResNetBlock(128,128)) # shape = [batch,128,34]
        self.down3 = nn.Conv1d(128, 128, 4, 2, 1) # shape = [batch,128,17]

        # self-attention before the bottleneck, at the coarsest resolution (17)
        self.attn1 = AttentionBlock(emb_dim=17)

        # Bottleneck
        self.te_mid = make_te(time_emb_dim, 128)
        self.b_mid = nn.Sequential(ResNetBlock(128, 64),ResNetBlock(64, 64),ResNetBlock(64,128)) # shape = [batch,128,17]

        self.attn2 = AttentionBlock(emb_dim=17)


        # decoder path (channel counts double at each stage input because the
        # matching encoder output is concatenated in forward())
        self.up1 = nn.ConvTranspose1d(128, 128, 4, 2, 1) # shape = [batch,128,34]
        self.te4 = make_te(time_emb_dim, 256)
        self.b4 = nn.Sequential(ResNetBlock(256,128),ResNetBlock(128,64),ResNetBlock(64,64)) # shape = [batch,64,34]

        self.up2 = nn.ConvTranspose1d(64, 64, 4, 2, 1) # shape = [batch,64,68]
        self.te5 = make_te(time_emb_dim, 128)
        self.b5 = nn.Sequential(ResNetBlock(128, 64),ResNetBlock(64, 32),ResNetBlock(32, 32)) # shape = [batch,32,68]

        self.up3 = nn.ConvTranspose1d(32, 32, 4, 2, 1) # shape = [batch,32,136]
        self.te6 = make_te(time_emb_dim, 64)
        self.b6 = nn.Sequential(ResNetBlock(64, 32),ResNetBlock(32, 32),ResNetBlock(32, 32)) # shape = [batch,32,136]

        # output projection back to 3 channels
        self.conv_out = nn.Conv1d(32, 3, 3, 1, 1) # shape = [batch,3,136]

    def forward(self, x, noise_labels):
        """Denoise x of shape (batch, 3, 136) conditioned on per-sample
        noise_labels (1-D tensor of length batch)."""
        t = self.map_noise(noise_labels)
        n = len(x)
        # Encoder: each stage adds a per-channel time encoding (broadcast
        # over the length axis) before its ResNet blocks.
        out1 = self.b1(x + self.te1(t).reshape(n, -1, 1))
        out2 = self.b2(self.down1(out1) + self.te2(t).reshape(n, -1, 1))
        out3 = self.b3(self.down2(out2) + self.te3(t).reshape(n, -1, 1))

        # Bottleneck: attention on the downsampled features, then ResNet blocks.
        out_mid = self.b_mid(self.attn1(self.down3(out3)) + self.te_mid(t).reshape(n, -1, 1))

        # self-attention
        out_mid = self.attn2(out_mid)

        # Decoder: upsample, concatenate the matching encoder skip, refine.
        out4 = torch.cat((out3,self.up1(out_mid)), dim=1)
        out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1))
        out5 = torch.cat((out2, self.up2(out4)), dim=1)
        out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1))

        out6 = torch.cat((out1, self.up3(out5)), dim=1)
        out6 = self.b6(out6 + self.te6(t).reshape(n, -1, 1))

        out = self.conv_out(out6)
        return out
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
|
diffcrysgen/sampler.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
|
|
2
|
+
import torch
|
|
3
|
+
import numpy as np
|
|
4
|
+
import tqdm
|
|
5
|
+
import time
|
|
6
|
+
from diffcrysgen.model import Model, UNet
|
|
7
|
+
from diffcrysgen.utils import *
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Adapted from "Elucidating the Design Space of Diffusion-Based
|
|
11
|
+
# Generative Models" by Karras et al.
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def round_sigma(sigma):
    """Convert a noise level to a torch tensor (no copy when already one)."""
    sigma_tensor = torch.as_tensor(sigma)
    return sigma_tensor
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def generate_samples(num_samples: int = 100, batch_size: int = 1000, model_path: str = "assets/saved-model/sdm.pt"):
    """
    Generate IRCR samples using our pre-trained version.

    Runs the EDM stochastic sampler (Karras et al.: Heun's 2nd-order method
    with optional churn) and maps the results back to IRCR space via the
    saved min-max scaler.

    Args:
        num_samples (int): Total number of samples to generate.
        batch_size (int): Batch size for sampling.
        model_path (str): Path to the saved model (.pt file).

    Returns:
        np.ndarray: Generated IRCR array of shape (num_samples, F, C)
    """
    # Load scaler
    scaler = load_saved_diffusion_scaler()
    print("Scaler loaded.")

    # Load device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Load the model (architecture must match the checkpoint)
    denoise_fn = UNet(in_c=3, out_c=3, time_emb_dim=256).to(device)
    model = Model(denoise_fn=denoise_fn).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()
    print("Model loaded from:", model_path)

    # Sampling hyperparameters (EDM defaults; S_churn=1 adds a little noise
    # back in at every step, S_min/S_max leave churn always enabled)
    num_steps = 100
    sigma_min = 0.002
    sigma_max = 80
    rho = 7
    S_churn = 1
    S_min = 0
    S_max = float("inf")
    S_noise = 1

    # The preconditioned denoiser D(x; sigma)
    net = model.denoise_fn_D
    # Clamp the schedule to the noise range the network supports
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time steps: rho-warped schedule from sigma_max down to sigma_min
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # append t_N = 0

    # IRCR without the property block: 3 channels x 136 features
    img_channels = 3
    img_resolution = 136

    total_loops = int(np.ceil(num_samples / batch_size))
    final_data = []

    start_time = time.time()

    with torch.no_grad():
        for loop in range(total_loops):
            # Last batch may be smaller than batch_size
            curr_batch_size = min(batch_size, num_samples - loop * batch_size)
            latents = torch.randn([curr_batch_size, img_channels, img_resolution], device=device)
            # Start from pure noise at the highest sigma
            x_next = latents * t_steps[0]

            # Main Sampling Loop
            for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:]))), desc=f"Sampling {loop+1}/{total_loops}", leave=False):
                x_cur = x_next

                # Increase Noise temporarily (churn): raise sigma by a factor
                # gamma before taking the step
                gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
                t_hat = round_sigma(t_cur + gamma * t_cur)

                # Broadcast the scalar noise levels over the batch
                t_cur = t_cur.repeat(curr_batch_size)
                t_next = t_next.repeat(curr_batch_size)
                t_hat = t_hat.repeat(curr_batch_size)

                # Add the extra noise needed to move from t_cur up to t_hat
                x_hat = x_cur + (t_hat[:, None, None] ** 2 - t_cur[:, None, None] ** 2).sqrt() * S_noise * torch.randn_like(x_cur)

                # Euler step from t_hat to t_next
                denoised = net(x_hat, t_hat).float()
                d_cur = (x_hat - denoised) / t_hat[:, None, None]
                x_next = x_hat + (t_next[:, None, None] - t_hat[:, None, None]) * d_cur

                # Heun's 2nd-order correction; skipped on the final step where
                # t_next = 0 (division by t_next would be undefined)
                if i < num_steps - 1:
                    denoised = net(x_next, t_next).float()
                    d_prime = (x_next - denoised) / t_next[:, None, None]
                    x_next = x_hat + (t_next[:, None, None] - t_hat[:, None, None]) * (0.5 * d_cur + 0.5 * d_prime)

            x_next = x_next.permute(0, 2, 1).cpu().numpy() # shape: (batch, F, C)
            # Undo training normalization: [-1, 1] -> [0, 1] -> original scale
            x_next = (x_next + 1) / 2
            x_next = inv_minmax(x_next, scaler)
            final_data.append(x_next)

            print(f"Sampling batch {loop + 1}/{total_loops} done.")

    elapsed = (time.time() - start_time) / 60
    print(f"Total sampling time: {elapsed:.2f} minutes")

    # Concatenate batches and trim any overshoot from the last batch
    generated_array = np.concatenate(final_data, axis=0)[:num_samples]
    print("Generated IRCR shape:", generated_array.shape)
    return generated_array
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
|
diffcrysgen/trainer.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
import torch
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from torch.utils.data import DataLoader
|
|
8
|
+
from torch.optim.lr_scheduler import ReduceLROnPlateau as RLROP
|
|
9
|
+
from tqdm.auto import tqdm
|
|
10
|
+
from model import UNet, Model
|
|
11
|
+
from utils import minmax
|
|
12
|
+
|
|
13
|
+
def load_data(train_path, test_path, batch_size=64):
    """Load train/test IRCR arrays, scale them to [-1, 1] and wrap them in
    DataLoaders.

    Args:
        train_path (str): path to the training .npy file.
        test_path (str): path to the test .npy file.
        batch_size (int): batch size for both loaders.

    Returns:
        tuple: (train_loader, test_loader, train_scaler, test_scaler)

    NOTE(review): the test split is scaled with its own fitted scaler rather
    than the training scaler — confirm this is intended, since it makes the
    two splits live in slightly different scaled spaces.
    NOTE(review): the training DataLoader is built without shuffle=True, so
    batches are seen in file order every epoch — confirm this is intended.
    """
    # Original IRCR has shape [batch,142,3] : IRCR={E,L,C,O,P}
    # We are not considering property matrix (P) for unconditional generation.
    train_data = np.load(train_path)[:, :136, :]
    test_data = np.load(test_path)[:, :136, :]

    print(f"Training data shape: {train_data.shape}")
    print(f"Test data shape: {test_data.shape}")

    # Normalize [0, 1]
    train_scaled, train_scaler = minmax(train_data)
    test_scaled, test_scaler = minmax(test_data)

    # Normalize to [-1, 1]
    train_scaled = 2 * train_scaled - 1
    test_scaled = 2 * test_scaled - 1

    # Convert to torch tensors (N, 3, F) — channels-first for Conv1d
    train_tensor = torch.from_numpy(train_scaled).float().permute(0, 2, 1)
    test_tensor = torch.from_numpy(test_scaled).float().permute(0, 2, 1)

    print(f"Transformed training shape: {train_tensor.shape}")
    print(f"Transformed test shape: {test_tensor.shape}")
    print(f"Train value range: {train_tensor.min().item()} to {train_tensor.max().item()}")

    train_loader = DataLoader(train_tensor, batch_size=batch_size)
    test_loader = DataLoader(test_tensor, batch_size=batch_size)

    return train_loader, test_loader, train_scaler, test_scaler
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def train_diffcrysgen(
    train_loader,
    test_loader,
    num_epochs=1000,
    save_path="sdm.pt",
    initial_lr=5e-4,
    device=None,
    lr_schedule=None,
):
    """Train the DiffCrysGen diffusion model and checkpoint the best weights.

    Trains with Adam and a fixed step learning-rate schedule. After every
    epoch the model is evaluated on ``test_loader``; the state dict is
    saved to ``save_path`` whenever the test loss improves. Epoch-wise
    losses and learning rates are written to ``epoch-loss.csv``.

    Args:
        train_loader: DataLoader yielding training batches (shape (B, 3, F)).
        test_loader: DataLoader yielding evaluation batches.
        num_epochs (int): Total number of training epochs.
        save_path (str): Where to store the best model state dict.
        initial_lr (float): Initial Adam learning rate.
        device: torch.device to train on; auto-detects CUDA when None.
        lr_schedule (dict | None): Mapping of epoch index -> new learning
            rate applied at the start of that epoch. Defaults to
            {500: 1e-4, 800: 5e-5}, the original hard-coded schedule.

    Returns:
        None. Side effects: model checkpoint, ``epoch-loss.csv``, console
        logging.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Default matches the original hard-coded milestones.
    if lr_schedule is None:
        lr_schedule = {500: 1e-4, 800: 5e-5}

    denoise_fn = UNet(in_c=3, out_c=3, time_emb_dim=256).to(device)
    model = Model(denoise_fn=denoise_fn).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr, weight_decay=0)
    # NOTE: the original code also built a ReduceLROnPlateau scheduler that
    # was never stepped (only dead commented-out calls); it has been removed
    # in favor of the explicit milestone schedule below.

    best_loss = float("inf")
    train_losses = []
    test_losses = []
    lrs = []

    start_time = time.time()

    for epoch in tqdm(range(num_epochs), desc="Training Progress", colour="green"):
        # Step schedule: drop the learning rate at fixed epoch milestones.
        if epoch in lr_schedule:
            for group in optimizer.param_groups:
                group["lr"] = lr_schedule[epoch]

        model.train()
        train_loss = 0
        for batch in tqdm(train_loader, leave=False, desc=f"Epoch {epoch+1}/{num_epochs}", colour="blue"):
            batch = batch.to(device)
            loss = model(batch)  # Model.forward returns the diffusion loss
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss /= len(train_loader)
        train_losses.append(train_loss)

        model.eval()
        test_loss = 0
        with torch.no_grad():
            for batch in tqdm(test_loader, leave=False, desc=f"Epoch {epoch+1}/{num_epochs}", colour="blue"):
                batch = batch.to(device)
                loss = model(batch)
                test_loss += loss.item()
        test_loss /= len(test_loader)
        test_losses.append(test_loss)

        current_lr = optimizer.param_groups[0]["lr"]
        lrs.append(current_lr)

        log_msg = f"Epoch {epoch+1}: train_loss: {train_loss:.4f}, test_loss: {test_loss:.4f}, lr: {current_lr:.6f}"
        # Checkpoint on best *test* loss only.
        if test_loss < best_loss:
            best_loss = test_loss
            torch.save(model.state_dict(), save_path)
            log_msg += " --> Best model ever (stored)"
        print(log_msg)

    print(f"Training completed in {(time.time() - start_time):.2f}s")

    # Save the epoch-wise loss history for later plotting/analysis.
    results_df = pd.DataFrame({
        "epoch": np.arange(num_epochs),
        "train_loss": train_losses,
        "test_loss": test_losses,
        "lr": lrs
    })
    results_df.to_csv("epoch-loss.csv", index=False)
    print(results_df.tail())
|
|
128
|
+
|
diffcrysgen/utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
|
|
2
|
+
import os
|
|
3
|
+
import joblib
|
|
4
|
+
import numpy as np
|
|
5
|
+
from sklearn.preprocessing import MinMaxScaler
|
|
6
|
+
|
|
7
|
+
# This is the normalization script adapted from FTCP work.
|
|
8
|
+
# It works for FTCP, IRCR.
|
|
9
|
+
|
|
10
|
+
def minmax(pcr):
    """Scale a (N, F, C) PCR/IRCR array feature-wise to [0, 1].

    Each of the F feature rows is scaled independently across all samples
    and columns, mirroring the normalization used in the FTCP work.

    Args:
        pcr (np.ndarray): Array of shape (N, F, C).

    Returns:
        tuple: (normalized array with the same shape, fitted MinMaxScaler).
    """
    n_samples, n_features, n_cols = pcr.shape
    scaler = MinMaxScaler()
    # Bring the feature axis to the front and flatten the rest so the
    # scaler fits one (min, max) pair per feature row.
    flat = np.transpose(pcr, (1, 0, 2)).reshape(n_features, n_samples * n_cols)
    scaled = scaler.fit_transform(flat.T).T
    # Undo the flattening to recover the original (N, F, C) layout.
    restored = scaled.reshape(n_features, n_samples, n_cols)
    return np.transpose(restored, (1, 0, 2)), scaler
|
|
20
|
+
|
|
21
|
+
def inv_minmax(pcr_normed, scaler):
    """Invert the feature-wise min-max scaling applied by ``minmax``.

    Args:
        pcr_normed (np.ndarray): Normalized array of shape (N, F, C).
        scaler: Fitted scaler exposing ``inverse_transform``.

    Returns:
        np.ndarray: Array of shape (N, F, C) on the original scale.
    """
    n_samples, n_features, n_cols = pcr_normed.shape
    # Same feature-major flattening that minmax used before fitting.
    flat = np.transpose(pcr_normed, (1, 0, 2)).reshape(n_features, n_samples * n_cols)
    unscaled = scaler.inverse_transform(flat.T).T
    # Restore the original (N, F, C) layout.
    restored = unscaled.reshape(n_features, n_samples, n_cols)
    return np.transpose(restored, (1, 0, 2))
|
|
30
|
+
|
|
31
|
+
def load_saved_diffusion_scaler(path=None):
    """Load the pre-fitted IRCR diffusion scaler from disk.

    Args:
        path (str | None): Location of the pickled scaler. When None, the
            bundled ``assets/ircr_diffusion_scaler.pkl`` located one level
            above this package is used.

    Returns:
        The unpickled scaler object.
    """
    if path is None:
        package_dir = os.path.dirname(__file__)
        default = os.path.join(package_dir, "..", "assets", "ircr_diffusion_scaler.pkl")
        path = os.path.abspath(default)
    return joblib.load(path)
|
|
37
|
+
|
|
38
|
+
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: diffcrysgen
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: "DiffCrysGen is a score-based diffusion model for accelerated design of diverse inorganic crystalline materials."
|
|
5
|
+
Home-page: https://github.com/SouravMal/DiffCrysGen.git
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
|
11
|
+
Classifier: Topic :: Scientific/Engineering :: Chemistry
|
|
12
|
+
Classifier: Topic :: Scientific/Engineering :: Physics
|
|
13
|
+
Requires-Python: >=3.11
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
Requires-Dist: torch~=2.5.1
|
|
17
|
+
Requires-Dist: numpy~=2.0.1
|
|
18
|
+
Requires-Dist: pandas~=2.2.3
|
|
19
|
+
Requires-Dist: scikit-learn~=1.6.1
|
|
20
|
+
Requires-Dist: tqdm
|
|
21
|
+
Requires-Dist: ase~=3.24.0
|
|
22
|
+
Requires-Dist: spglib~=2.5.0
|
|
23
|
+
Requires-Dist: joblib
|
|
24
|
+
Dynamic: license-file
|
|
25
|
+
|
|
26
|
+
# DiffCrysGen [](https://github.com/SouravMal/DiffCrysGen)
|
|
27
|
+
|
|
28
|
+
DiffCrysGen is a score-based diffusion model. It treats the entire materials representation with a single, unified diffusion process, allowing a single denoising neural network to predict a holistic score for the entire noisy crystal data. This unified treatment significantly simplifies the architecture and improves the computational efficiency.
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
<p align="center">
|
|
32
|
+
<img src="images/logo-DiffCrysGen.png" alt="DiffCrysGen Logo" width="350"/>
|
|
33
|
+
</p>
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
## Generative diffusion framework in DiffCrysGen
|
|
37
|
+
<img src="images/diffusion-schematic.png" alt="DiffCrysGen Schematic" width="550">
|
|
38
|
+
|
|
39
|
+
---
|
|
40
|
+
|
|
41
|
+
## Architecture of the denoising neural network
|
|
42
|
+
<img src="images/architecture.png" alt="DiffCrysGen Architecture" width="750">
|
|
43
|
+
|
|
44
|
+
---
|
|
45
|
+
|
|
46
|
+
## Installation
|
|
47
|
+
|
|
48
|
+
### Prerequisites
|
|
49
|
+
|
|
50
|
+
The package requires specific environments and dependencies.
|
|
51
|
+
Using a virtual environment is highly recommended.
|
|
52
|
+
**Environment Setup using Conda**
|
|
53
|
+
|
|
54
|
+
```
|
|
55
|
+
conda create -n diffcrysgen python=3.11
|
|
56
|
+
conda activate diffcrysgen
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
### Install from PyPI
|
|
60
|
+
```
|
|
61
|
+
pip install diffcrysgen
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
### Install from Source Code
|
|
65
|
+
```
|
|
66
|
+
git clone https://github.com/SouravMal/DiffCrysGen.git
|
|
67
|
+
cd DiffCrysGen
|
|
68
|
+
pip install -e .
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
## Quick Start
|
|
72
|
+
For a simple walkthrough of generating materials and analyzing them, see the [DiffCrysGen Demo Notebook](./notebooks/DiffCrysGen-demo.ipynb).
|
|
73
|
+
|
|
74
|
+
## License
|
|
75
|
+
|
|
76
|
+
This project is licensed under the **MIT License**.
|
|
77
|
+
|
|
78
|
+
See the [LICENSE](LICENSE) file for details.
|
|
79
|
+
|
|
80
|
+
Developed by: [Sourav Mal](https://github.com/SouravMal) at Harish-Chandra Research Institute (HRI) (https://www.hri.res.in/), Prayagraj, India.
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
## Citation
|
|
84
|
+
|
|
85
|
+
Please consider citing our work if you find it helpful:
|
|
86
|
+
|
|
87
|
+
```bibtex
|
|
88
|
+
@misc{mal2025generativediffusionmodeldiffcrysgen,
|
|
89
|
+
title={Generative Diffusion Model DiffCrysGen Discovers Rare Earth-Free Magnetic Materials},
|
|
90
|
+
author={Sourav Mal and Nehad Ahmed and Subhankar Mishra and Prasenjit Sen},
|
|
91
|
+
year={2025},
|
|
92
|
+
eprint={2510.12329},
|
|
93
|
+
archivePrefix={arXiv},
|
|
94
|
+
primaryClass={cond-mat.mtrl-sci},
|
|
95
|
+
url={https://arxiv.org/abs/2510.12329},
|
|
96
|
+
}
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
## Contact
|
|
101
|
+
|
|
102
|
+
If you have any questions, feel free to reach us at:
|
|
103
|
+
**Sourav Mal** <souravmal492@gmail.com>
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
diffcrysgen/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
diffcrysgen/analyze_generated_structures.py,sha256=fbybCbvH0o_xv_1v_tSV1OfbDEau7FHPs8oHyCAlf0Q,3475
|
|
3
|
+
diffcrysgen/invert_pcr.py,sha256=Q6j7AVb6iZWR7t3VjB1yFLnuPB9hCc4_0Vp1K72gcak,5944
|
|
4
|
+
diffcrysgen/model.py,sha256=sb42SQJPm9nQNSGMNN88RXwWE0tIsompypFZxmJob3k,11166
|
|
5
|
+
diffcrysgen/sampler.py,sha256=k5r2D0MmWcQRebcg0KBugl1AEg6FPrnvfobZ_uyznP0,4092
|
|
6
|
+
diffcrysgen/trainer.py,sha256=SI988DGpG0rgfixV9Sz9Ue1Lk5qG6uZvC1ka-BP9KOA,4189
|
|
7
|
+
diffcrysgen/utils.py,sha256=FI6bkDGMt7E5_WbL55la--MGkOryz11KBgmVfivBvao,1094
|
|
8
|
+
diffcrysgen-0.1.0.dist-info/licenses/LICENSE,sha256=7WpR2tbfkkPUcuvEnP_B3ijJczHpi-_U_i2ngeJGqSY,1067
|
|
9
|
+
diffcrysgen-0.1.0.dist-info/METADATA,sha256=KhTsxU3Q6uqtvBWe9UrwHrrkUcFyMDvOJrcCt3x-_t4,3209
|
|
10
|
+
diffcrysgen-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
11
|
+
diffcrysgen-0.1.0.dist-info/top_level.txt,sha256=aN4XuKXo6kncvD9N8mx9SOVKSFfNqBglK7G9jm5u3Fs,12
|
|
12
|
+
diffcrysgen-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Sourav Mal
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
diffcrysgen
|