diffcrysgen 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,88 @@
1
+ # diffcrysgen/analyze_generated_structures.py
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import spglib
6
+ from diffcrysgen.invert_pcr import InvertPCR
7
+ from ase import Atoms
8
+ import ase.io
9
+ import os
10
+
11
def analyze_ircr_data(generated_ircr_path="generated_ircr.npy", output_csv="generated_material_data.csv", min_valid_distance=0.5, save_cifs=False, cif_dir="cif-files"):
    """
    Analyzes generated IRCR data and extracts crystallographic info.

    Args:
        generated_ircr_path (str): Path to the generated IRCR .npy file
        output_csv (str): Path to save the resulting CSV
        min_valid_distance (float): Distance threshold for validity check in Angstrom
        save_cifs (bool): If True, saves CIF files into subfolders
        cif_dir (str): Base directory to store valid/invalid CIFs

    Returns:
        pd.DataFrame: one row per successfully decoded material (materials
        that fail to decode are skipped with a message).
    """
    data = np.load(generated_ircr_path)
    Nmat = data.shape[0]
    print(f"Loaded {Nmat} materials from {generated_ircr_path}")

    results = {
        "ID": [], "Formula": [], "type": [], "Natoms": [], "min-pairwise-distance": [],
        "Zmax": [], "SPG-symbol": [], "SPG-number": [], "validity": [],
        "a": [], "b": [], "c": [], "alpha": [], "beta": [], "gamma": []
    }

    if save_cifs:
        os.makedirs(os.path.join(cif_dir, "valid"), exist_ok=True)
        os.makedirs(os.path.join(cif_dir, "invalid"), exist_ok=True)

    for i in range(Nmat):
        mat = data[i, :]
        # 94 = largest atomic number encoded (H..Pu), 20 = max atomic sites.
        ipcr = InvertPCR(mat, 94, 20)

        try:
            atoms = ipcr.get_atoms_object()
            formula = atoms.get_chemical_formula()
            atomic_numbers = atoms.get_atomic_numbers()
            unique_elements = len(set(atomic_numbers))

            mat_type = {1: "elemental", 2: "binary", 3: "ternary"}.get(unique_elements, "complex")
            a, b, c = atoms.cell.lengths()
            alpha, beta, gamma = atoms.cell.angles()

            # spglib expects a (cell, scaled_positions, numbers) tuple; the
            # returned string looks like "Fm-3m (225)".
            structure = (atoms.get_cell(), atoms.get_scaled_positions(), atomic_numbers)
            spg_symbol, spg_number = spglib.get_spacegroup(structure, symprec=0.1).split(" ")
            spg_number = int(spg_number.strip("()"))

            distances = ipcr.get_distances()
            # Single-atom cells have no pairwise distances; treat them as
            # trivially valid with a sentinel distance of 3.0 A.
            min_dist = min(distances, default=3.0)
            is_valid = min_dist >= min_valid_distance
            mat_id = f"smps-{i+1}"

            if save_cifs:
                cif_path = os.path.join(cif_dir, "valid" if is_valid else "invalid", f"{mat_id}.cif")
                ase.io.write(cif_path, atoms)

            # Append data
            results["ID"].append(mat_id)
            results["Formula"].append(formula)
            results["type"].append(mat_type)
            results["Natoms"].append(len(atomic_numbers))
            results["min-pairwise-distance"].append(min_dist)
            results["Zmax"].append(max(atomic_numbers))
            results["SPG-symbol"].append(spg_symbol)
            results["SPG-number"].append(spg_number)
            results["validity"].append("valid" if is_valid else "invalid")
            results["a"].append(a)
            results["b"].append(b)
            results["c"].append(c)
            results["alpha"].append(alpha)
            results["beta"].append(beta)
            results["gamma"].append(gamma)

        except Exception as e:
            # Decoding can fail at several points (structure inversion,
            # spglib symmetry search returning None, CIF writing).  Report
            # which material failed and why instead of the previous fixed,
            # misleading "spglib failed" message.
            print(f"[Material {i}] skipped: {e}")

    df = pd.DataFrame(results)
    df.to_csv(output_csv, index=False)
    print(f"Saved extracted data to: {output_csv}")
    return df
88
+
@@ -0,0 +1,163 @@
1
+
2
+ from ase.spacegroup import crystal
3
+ import numpy as np
4
+ import ase.io
5
+
6
+
7
+ # Dictionary containing atomic number with respective element
8
+
9
# Element symbols indexed by atomic number, H (Z=1) through Pu (Z=94).
_ELEMENT_SYMBOLS = (
    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
    "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
    "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
    "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr",
    "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
    "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
    "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
    "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
    "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
    "Pa", "U", "Np", "Pu",
)

# Dictionary mapping atomic number (as a string key) to element symbol;
# identical in content to the original hand-written literal dict.
atomic_name = {str(z): symbol for z, symbol in enumerate(_ELEMENT_SYMBOLS, start=1)}
19
+
20
+
21
class InvertPCR:
    """Extract crystallographic information from a point-cloud representation
    (PCR) so that a CIF file / ASE Atoms object can be reconstructed.

    Row layout of the 2-D PCR array (3 columns):

        rows [0, z_max)                              element block E
        rows [z_max, z_max + 2)                      lattice block L: a, b, c, alpha, beta, gamma
        rows [z_max + 2, z_max + 2 + n_sites)        fractional-coordinate block C
        rows [..., z_max + 2 + 2 * n_sites)          site-occupancy block O
        next 8 rows                                  property block P
    """

    def __init__(self, PCR, z_max, n_sites):
        """
        Args:
            PCR (np.ndarray): 2-D point-cloud representation (see class doc).
            z_max (int): number of element rows (largest atomic number encoded).
            n_sites (int): maximum number of atomic sites encoded.
        """
        self.PCR = PCR
        self.z_max = z_max
        self.n_sites = n_sites

    def get_element_matrix(self):
        """Return the element block with entries below 0.5 zeroed.

        Thresholding is done on a copy: the original implementation wrote
        through a NumPy view, silently mutating ``self.PCR`` (and the
        caller's array) in place.
        """
        ele_mat = self.PCR[0:self.z_max, :].copy()
        ele_mat[ele_mat < 0.5] = 0
        return ele_mat

    def get_lattice_matrix(self):
        """Return the two lattice-parameter rows (read-only view)."""
        return self.PCR[self.z_max:self.z_max + 2, :]

    def get_lattice_parameters(self):
        """Return the six cell parameters [a, b, c, alpha, beta, gamma]."""
        lattice_parameters = list(self.get_lattice_matrix().flatten())
        a, b, c, alpha, beta, gamma = lattice_parameters[:6]
        return [a, b, c, alpha, beta, gamma]

    def get_basis_matrix(self):
        """Return the fractional-coordinate block (n_sites rows)."""
        return self.PCR[self.z_max + 2:self.z_max + 2 + self.n_sites, :]

    def get_site_matrix(self):
        """Return the site-occupancy block with entries below 0.5 zeroed.

        Copies for the same reason as :meth:`get_element_matrix`.
        """
        start = self.z_max + 2 + self.n_sites
        site_mat = self.PCR[start:start + self.n_sites, :].copy()
        site_mat[site_mat < 0.5] = 0
        return site_mat

    def get_property_matrix(self):
        """Return the 8-row property block."""
        start = self.z_max + 2 + 2 * self.n_sites
        return self.PCR[start:start + 8, :]

    def get_unique_atomic_numbers(self):
        """Return one atomic number per element column (0 = empty column).

        The atomic number is the row index of the column maximum plus one.
        """
        element_matrix = self.get_element_matrix()
        z_unique = []
        # One column per distinct element slot (3 columns in practice).
        for i in range(element_matrix.shape[1]):
            col = list(element_matrix[:, i])
            max_value = max(col)
            if max_value == 0:
                z_unique.append(0)
            else:
                z_unique.append(col.index(max_value) + 1)
        return z_unique

    def get_unique_elements(self):
        """Return element symbols per column, "-" for empty columns."""
        return ["-" if z == 0 else atomic_name[str(z)]
                for z in self.get_unique_atomic_numbers()]

    def get_atomic_numbers(self):
        """Return one atomic number per site (0 for vacant sites).

        Each site row selects (via the position of its maximum) one of the
        element columns returned by :meth:`get_unique_atomic_numbers`.
        """
        site_matrix = self.get_site_matrix()
        z_unique = self.get_unique_atomic_numbers()
        z_total = []
        for i in range(self.n_sites):
            row = list(site_matrix[i, :])
            val = max(row)
            if val == 0:
                z_total.append(0)
            else:
                z_total.append(z_unique[row.index(val)])
        return z_total

    def get_elements_basis(self):
        """Return (element symbols, fractional coordinates) for occupied sites."""
        basis = self.get_basis_matrix()
        z_total = self.get_atomic_numbers()
        occupied = [i for i, z in enumerate(z_total) if z != 0]
        final_elements = [atomic_name[str(z_total[i])] for i in occupied]
        # Fancy indexing copies, so the returned basis is safe to mutate.
        final_basis = basis[occupied, :]
        return final_elements, final_basis

    def get_formula(self):
        """Build a chemical formula string such as "Na4Cl4" by run-length
        encoding the per-site element symbols.

        NOTE(review): assumes identical elements occupy *contiguous* site
        rows; an interleaved ordering (e.g. A, B, A) would produce repeated
        symbols in the formula. Behavior kept identical to the original.
        """
        symbols, _ = self.get_elements_basis()
        elements = [symbols[0]]
        change_points = []
        for i in range(len(symbols)):
            if symbols[i] != elements[-1]:
                elements.append(symbols[i])
                change_points.append(i)
        change_points.append(len(symbols))

        # Convert change indices into per-run counts.
        counts = [change_points[0]]
        for i in range(1, len(change_points)):
            counts.append(change_points[i] - change_points[i - 1])

        return "".join(ele + str(cnt) for ele, cnt in zip(elements, counts))

    def get_atoms_object(self):
        """Build an ``ase.Atoms`` object via ``ase.spacegroup.crystal``."""
        symbol = self.get_formula()
        _, basis = self.get_elements_basis()
        lattice = self.get_lattice_parameters()
        return crystal(symbols=symbol, basis=basis, cellpar=lattice)

    def get_distances(self):
        """Return all pairwise interatomic distances, rounded to 4 decimals.

        NOTE(review): plain Cartesian distances within the cell — periodic
        images (minimum-image convention) are not considered; confirm this
        is acceptable for the validity check.
        """
        atoms = self.get_atoms_object()
        positions = atoms.get_positions()
        nele = positions.shape[0]
        distances = []
        for i in range(nele):
            for j in range(i + 1, nele):
                dist = np.linalg.norm(positions[i, :] - positions[j, :])
                distances.append(np.around(dist, 4))
        return distances
163
+
diffcrysgen/model.py ADDED
@@ -0,0 +1,316 @@
1
+
2
+
3
+ # Importing necessary libraries
4
+
5
+ import torch
6
+ import numpy as np
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ # Sinusoidal Positional Embedding
12
+
13
class PositionalEmbedding(torch.nn.Module):
    """
    Sinusoidal positional embedding for scalar noise labels.

    Given a 1-D batch of noise levels, returns a (batch, emb_dim) tensor in
    which even columns are sines and odd columns are cosines of the label
    multiplied by a geometric ladder of frequencies.
    """
    def __init__(self, emb_dim, max_positions=10000):
        """
        Args:
            emb_dim (int): embedding dimension (assumed even: emb_dim//2
                sine columns interleaved with emb_dim//2 cosine columns).
            max_positions (int): base of the geometric frequency ladder.
        """
        super(PositionalEmbedding, self).__init__()
        self.emb_dim = emb_dim
        self.max_positions = max_positions

    def forward(self, x):
        """
        Args:
            x (Tensor): 1-D tensor of noise levels, one per batch element.

        Returns:
            Tensor of shape (len(x), emb_dim).
        """
        half = self.emb_dim // 2

        # Frequencies: (1/max_positions) ** (k/half) for k = 0..half-1.
        freqs = torch.arange(0, half, dtype=torch.float32, device=x.device)
        freqs = freqs / half
        freqs = (1 / self.max_positions) ** freqs

        # Outer product: angles[i, k] = x[i] * freqs[k].
        angles = torch.tensordot(x, freqs, dims=0)

        sin_block = torch.sin(angles)
        cos_block = torch.cos(angles)

        # Interleave so even output columns come from the sine block and odd
        # columns from the cosine block.  (The original also built
        # cat([sin, cos], dim=1) solely to read its batch size — that dead
        # allocation is removed here; the output is unchanged.)
        interleaved = torch.cat([sin_block.unsqueeze(2), cos_block.unsqueeze(2)], dim=2)
        return torch.reshape(interleaved, (sin_block.shape[0], -1))
50
+
51
+
52
+ # Module for Preconditioning
53
+ # Adapted from "Elucidating the Design Space of Diffusion-Based
54
+ # Generative Models" by Karras et al.
55
+
56
class Precond(nn.Module):
    """EDM preconditioning wrapper (Karras et al., 2022).

    Scales the network input/output with the sigma-dependent coefficients
    c_skip, c_out, c_in, c_noise so the inner network F can be trained on a
    well-conditioned target: D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise).
    """

    def __init__(self,
                 denoise_fn,
                 sigma_min=0,                 # minimum supported noise level
                 sigma_max=float("inf"),      # maximum supported noise level
                 sigma_data=0.5,              # expected std-dev of training data
                 ):
        super().__init__()

        self.denoise_fn_F = denoise_fn
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data

    def forward(self, x, sigma):
        """Return the preconditioned denoiser output D(x; sigma)."""
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1)

        # Preconditioning coefficients from the EDM paper (Table 1).
        sd_sq = self.sigma_data ** 2
        total = sigma ** 2 + sd_sq
        c_skip = sd_sq / total
        c_out = sigma * self.sigma_data / total.sqrt()
        c_in = 1 / total.sqrt()
        c_noise = sigma.log() / 4

        F_x = self.denoise_fn_F((c_in * x).to(torch.float32), c_noise.flatten())
        assert F_x.dtype == torch.float32

        return c_skip * x + c_out * F_x.to(torch.float32)

    def round_sigma(self, sigma):
        """Coerce sigma (scalar/list/array) to a tensor, unchanged in value."""
        return torch.as_tensor(sigma)
89
+
90
+
91
+ # EDM Loss
92
+
93
def EDMLoss(denoise_fn, data):
    """Per-element weighted EDM denoising loss (Karras et al., 2022).

    Samples one log-normal noise level per example, corrupts the data, and
    returns lambda(sigma) * (D(y + n, sigma) - y)**2 without reduction.
    """
    P_mean = -1.2
    P_std = 1.2
    sigma_data = 0.5

    # One log-normal noise level per batch element.  (RNG call order matches
    # the original exactly: batch randn first, then randn_like.)
    rnd_normal = torch.randn(data.shape[0], device=data.device)
    sigma = (rnd_normal * P_std + P_mean).exp()

    # lambda(sigma) loss weighting from the EDM paper.
    weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2

    noise = torch.randn_like(data) * sigma.unsqueeze(1).unsqueeze(1)
    denoised = denoise_fn(data + noise, sigma)

    return weight.unsqueeze(1).unsqueeze(1) * ((denoised - data) ** 2)
106
+
107
+
108
+ # Final Model
109
+
110
class Model(nn.Module):
    """Top-level diffusion model: wraps the denoiser in EDM preconditioning
    and exposes the scalar training loss as the forward pass."""

    def __init__(self, denoise_fn, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
        super().__init__()
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.denoise_fn_D = Precond(denoise_fn)

    def forward(self, x):
        """Return the scalar EDM loss for a batch x."""
        per_element = EDMLoss(self.denoise_fn_D, x)
        # Reduce exactly as the original: mean over the last axis, then the rest.
        return per_element.mean(-1).mean()
121
+
122
+
123
+ # time-encoding layer
124
+
125
def make_te(dim_in, dim_out):
    """Two-layer SiLU MLP that projects the time embedding to a channel bias."""
    return nn.Sequential(
        nn.Linear(dim_in, dim_out),
        nn.SiLU(),
        nn.Linear(dim_out, dim_out),
    )
127
+
128
+
129
+
130
+ #=================================== UNet=================================================
131
+
132
def weight_standardization(weight: torch.Tensor, eps: float):
    """Standardize a Conv1d kernel to zero mean / unit variance per output channel.

    Args:
        weight: kernel of shape (c_out, c_in, kernel_size).
        eps: numerical-stability constant added to the variance.
    """
    c_out, c_in, kernel_size = weight.shape
    flat = weight.view(c_out, -1)
    mean = flat.mean(dim=1, keepdim=True)
    var = flat.var(dim=1, keepdim=True)
    standardized = (flat - mean) / torch.sqrt(var + eps)
    return standardized.view(c_out, c_in, kernel_size)
139
+
140
+
141
class WSConv1d(torch.nn.Conv1d):
    """``nn.Conv1d`` variant that standardizes its kernel (zero mean, unit
    variance per output channel) on every forward pass."""

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups: int = 1, bias: bool = True,
                 padding_mode: str = 'zeros', eps: float = 1e-5):
        super(WSConv1d, self).__init__(in_channels, out_channels, kernel_size,
                                       stride=stride, padding=padding, dilation=dilation,
                                       groups=groups, bias=bias, padding_mode=padding_mode)
        # Numerical-stability constant for the variance normalization.
        self.eps = eps

    def forward(self, x: torch.Tensor):
        # Standardize the kernel, then run a plain conv1d with it; the stored
        # self.weight parameter itself is never modified.
        standardized = weight_standardization(self.weight, self.eps)
        return F.conv1d(x, standardized, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
155
+
156
+
157
class ResNetBlock(nn.Module):
    """Residual block of two weight-standardized 1-D convolutions with
    GroupNorm and SiLU activation, plus a (possibly projected) skip path."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, groups=1, eps=1e-5):
        super(ResNetBlock, self).__init__()

        same_pad = kernel_size // 2  # "same" padding: sequence length preserved

        # Two conv -> norm stages, both weight-standardized.
        self.conv1 = WSConv1d(in_channels, out_channels, kernel_size, stride=stride, padding=same_pad, groups=groups, eps=eps)
        self.norm1 = nn.GroupNorm(num_groups=1, num_channels=out_channels)
        self.conv2 = WSConv1d(out_channels, out_channels, kernel_size, stride=stride, padding=same_pad, groups=groups, eps=eps)
        self.norm2 = nn.GroupNorm(num_groups=1, num_channels=out_channels)

        # 1x1 projection so the skip path matches when channel counts differ.
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
        else:
            self.shortcut = nn.Identity()

        # SiLU throughout (the old inline comment incorrectly said ReLU).
        self.activation = nn.SiLU()

    def forward(self, x):
        residual = self.shortcut(x)

        # conv -> norm -> SiLU, then conv -> norm.
        out = self.activation(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))

        # Add the skip path, then a final SiLU.
        out = out + residual
        return self.activation(out)
196
+
197
+
198
+ # attention block
199
+ # we apply attention block at a resolution level of 17
200
+
201
class AttentionBlock(nn.Module):
    """Self-attention + feed-forward sub-block with post-norm residuals.

    Intended layout per the original notes: data of shape
    (batch, num_channels, resolution) with embedding_dim == resolution.

    NOTE(review): ``nn.MultiheadAttention`` defaults to ``batch_first=False``
    (expects (seq, batch, emb)), while the tensor is passed in
    (batch, channels, resolution) order — so dim 0 is treated as the
    sequence axis.  Preserved as-is here because changing it would alter the
    behavior of already-trained checkpoints; worth confirming upstream.
    """

    def __init__(self, emb_dim, num_heads=1, dropout=0.1):
        super(AttentionBlock, self).__init__()

        self.attention = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=num_heads, dropout=dropout)

        # Post-attention / post-FFN layer norms.
        self.layernorm1 = nn.LayerNorm(emb_dim)
        self.layernorm2 = nn.LayerNorm(emb_dim)

        # Position-wise feed-forward network with 4x expansion.
        self.ffn = nn.Sequential(
            nn.Linear(emb_dim, emb_dim * 4),
            nn.ReLU(),
            nn.Linear(emb_dim * 4, emb_dim)
        )

    def forward(self, x):
        # Self-attention: query, key, and value are all x.
        attn_output, _ = self.attention(x, x, x)
        x = self.layernorm1(x + attn_output)   # add & norm

        # Feed-forward, then the second add & norm.
        x = self.layernorm2(x + self.ffn(x))
        return x
239
+
240
+
241
class UNet(nn.Module):
    """1-D UNet denoiser for IRCR tensors of shape (batch, 3, 136).

    Encoder: three ResNet stages (32 -> 64 -> 128 channels), each followed by
    a stride-2 conv halving the length (136 -> 68 -> 34 -> 17).  Bottleneck:
    self-attention + ResNet stack at length 17.  Decoder: mirror path with
    transposed convs and encoder skip connections (channel-concatenated).
    The noise level is injected per stage through small MLPs (te*) whose
    output is broadcast-added along the length dimension.

    NOTE(review): ``in_c`` and ``out_c`` are accepted but never used — the
    input/output channel counts are hard-coded to 3 below; confirm intent.
    """
    def __init__(self, in_c, out_c, time_emb_dim):
        super(UNet, self).__init__()

        # map noise labels to a sinusoidal time embedding of size time_emb_dim
        self.map_noise = PositionalEmbedding(emb_dim=time_emb_dim)

        # encoder path
        self.te1 = make_te(time_emb_dim, 3)
        self.b1 = nn.Sequential(ResNetBlock(3, 32),ResNetBlock(32, 32),ResNetBlock(32,32)) # shape = [batch,32,136]
        self.down1 = nn.Conv1d(32, 32, 4, 2, 1) # shape = [batch,32,68]

        self.te2 = make_te(time_emb_dim, 32)
        self.b2 = nn.Sequential(ResNetBlock(32, 64),ResNetBlock(64, 64),ResNetBlock(64,64)) # shape = [batch,64,68]
        self.down2 = nn.Conv1d(64, 64, 4, 2, 1) # shape = [batch,64,34]

        self.te3 = make_te(time_emb_dim, 64)
        self.b3 = nn.Sequential(ResNetBlock(64, 128),ResNetBlock(128, 128),ResNetBlock(128,128)) # shape = [batch,128,34]
        self.down3 = nn.Conv1d(128, 128, 4, 2, 1) # shape = [batch,128,17]

        # attention at the coarsest resolution (length 17)
        self.attn1 = AttentionBlock(emb_dim=17)

        # Bottleneck
        self.te_mid = make_te(time_emb_dim, 128)
        self.b_mid = nn.Sequential(ResNetBlock(128, 64),ResNetBlock(64, 64),ResNetBlock(64,128)) # shape = [batch,128,17]

        self.attn2 = AttentionBlock(emb_dim=17)


        # decoder path (input channels double where encoder skips are concatenated)
        self.up1 = nn.ConvTranspose1d(128, 128, 4, 2, 1) # shape = [batch,128,34]
        self.te4 = make_te(time_emb_dim, 256)
        self.b4 = nn.Sequential(ResNetBlock(256,128),ResNetBlock(128,64),ResNetBlock(64,64)) # shape = [batch,64,34]

        self.up2 = nn.ConvTranspose1d(64, 64, 4, 2, 1) # shape = [batch,64,68]
        self.te5 = make_te(time_emb_dim, 128)
        self.b5 = nn.Sequential(ResNetBlock(128, 64),ResNetBlock(64, 32),ResNetBlock(32, 32)) # shape = [batch,32,68]

        self.up3 = nn.ConvTranspose1d(32, 32, 4, 2, 1) # shape = [batch,32,136]
        self.te6 = make_te(time_emb_dim, 64)
        self.b6 = nn.Sequential(ResNetBlock(64, 32),ResNetBlock(32, 32),ResNetBlock(32, 32)) # shape = [batch,32,136]

        # output projection back to 3 channels
        self.conv_out = nn.Conv1d(32, 3, 3, 1, 1) # shape = [batch,3,136]

    def forward(self, x, noise_labels):
        """Denoise x (batch, 3, 136) conditioned on per-sample noise labels."""
        t = self.map_noise(noise_labels)  # (batch, time_emb_dim)
        n = len(x)                        # batch size for reshaping the time bias
        # Encoder: each stage adds a per-channel time bias before its ResNet stack.
        out1 = self.b1(x + self.te1(t).reshape(n, -1, 1))
        out2 = self.b2(self.down1(out1) + self.te2(t).reshape(n, -1, 1))
        out3 = self.b3(self.down2(out2) + self.te3(t).reshape(n, -1, 1))

        # Bottleneck: attention before and after the middle ResNet stack.
        out_mid = self.b_mid(self.attn1(self.down3(out3)) + self.te_mid(t).reshape(n, -1, 1))

        # self-attention
        out_mid = self.attn2(out_mid)

        # Decoder: upsample, concatenate the matching encoder skip, add time bias.
        out4 = torch.cat((out3,self.up1(out_mid)), dim=1)
        out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1))
        out5 = torch.cat((out2, self.up2(out4)), dim=1)
        out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1))

        out6 = torch.cat((out1, self.up3(out5)), dim=1)
        out6 = self.b6(out6 + self.te6(t).reshape(n, -1, 1))

        out = self.conv_out(out6)
        return out
308
+
309
+
310
+
311
+
312
+
313
+
314
+
315
+
316
+
diffcrysgen/sampler.py ADDED
@@ -0,0 +1,116 @@
1
+
2
+ import torch
3
+ import numpy as np
4
+ import tqdm
5
+ import time
6
+ from diffcrysgen.model import Model, UNet
7
+ from diffcrysgen.utils import *
8
+
9
+
10
+ # Adapted from "Elucidating the Design Space of Diffusion-Based
11
+ # Generative Models" by Karras et al.
12
+
13
+
14
def round_sigma(sigma):
    """Coerce *sigma* (scalar, list, array, or tensor) to a torch tensor.

    Values are passed through unchanged; an input that is already a tensor
    is returned as-is by ``torch.as_tensor``.
    """
    return torch.as_tensor(sigma)
16
+
17
+
18
def generate_samples(num_samples: int = 100, batch_size: int = 1000, model_path: str = "assets/saved-model/sdm.pt"):
    """
    Generate IRCR samples using our pre-trained model, following the
    stochastic Heun (2nd-order) sampler of Karras et al. (EDM).

    Args:
        num_samples (int): Total number of samples to generate.
        batch_size (int): Batch size for sampling.
        model_path (str): Path to the saved model (.pt file).

    Returns:
        np.ndarray: Generated IRCR array of shape (num_samples, F, C)
    """
    # Load scaler (inverts the min-max normalization applied at training time)
    scaler = load_saved_diffusion_scaler()
    print("Scaler loaded.")

    # Load device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Load the model (UNet wrapped in EDM preconditioning via Model)
    denoise_fn = UNet(in_c=3, out_c=3, time_emb_dim=256).to(device)
    model = Model(denoise_fn=denoise_fn).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()
    print("Model loaded from:", model_path)

    # Sampling hyperparameters (EDM defaults; S_churn=1 adds a little noise
    # back per step, S_noise=1 leaves its scale unchanged)
    num_steps = 100
    sigma_min = 0.002
    sigma_max = 80
    rho = 7
    S_churn = 1
    S_min = 0
    S_max = float("inf")
    S_noise = 1

    # net is the preconditioned denoiser D(x; sigma)
    net = model.denoise_fn_D
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time steps: Karras rho-schedule interpolating sigma_max -> sigma_min
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # append t_N = 0

    # Fixed IRCR dimensions: 3 channels x 136 feature rows
    img_channels = 3
    img_resolution = 136

    total_loops = int(np.ceil(num_samples / batch_size))
    final_data = []

    start_time = time.time()

    with torch.no_grad():
        for loop in range(total_loops):
            # Last batch may be smaller than batch_size
            curr_batch_size = min(batch_size, num_samples - loop * batch_size)
            latents = torch.randn([curr_batch_size, img_channels, img_resolution], device=device)
            x_next = latents * t_steps[0]

            # Main Sampling Loop (one Heun step per (t_cur, t_next) pair)
            for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:]))), desc=f"Sampling {loop+1}/{total_loops}", leave=False):
                x_cur = x_next

                # Increase Noise temporarily ("churn"): bump sigma by gamma
                gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
                t_hat = round_sigma(t_cur + gamma * t_cur)

                # Broadcast the scalar sigmas to one value per batch element
                t_cur = t_cur.repeat(curr_batch_size)
                t_next = t_next.repeat(curr_batch_size)
                t_hat = t_hat.repeat(curr_batch_size)

                # Add the extra noise corresponding to the sigma increase
                x_hat = x_cur + (t_hat[:, None, None] ** 2 - t_cur[:, None, None] ** 2).sqrt() * S_noise * torch.randn_like(x_cur)

                # Euler (1st-order) step using d = (x - D(x)) / sigma
                denoised = net(x_hat, t_hat).float()
                d_cur = (x_hat - denoised) / t_hat[:, None, None]
                x_next = x_hat + (t_next[:, None, None] - t_hat[:, None, None]) * d_cur

                # Heun 2nd-order correction (skipped on the final step where t_next = 0)
                if i < num_steps - 1:
                    denoised = net(x_next, t_next).float()
                    d_prime = (x_next - denoised) / t_next[:, None, None]
                    x_next = x_hat + (t_next[:, None, None] - t_hat[:, None, None]) * (0.5 * d_cur + 0.5 * d_prime)

            # Undo training-time normalization: [-1,1] -> [0,1] -> original scale
            x_next = x_next.permute(0, 2, 1).cpu().numpy() # shape: (batch, F, C)
            x_next = (x_next + 1) / 2
            x_next = inv_minmax(x_next, scaler)
            final_data.append(x_next)

            print(f"Sampling batch {loop + 1}/{total_loops} done.")

    elapsed = (time.time() - start_time) / 60
    print(f"Total sampling time: {elapsed:.2f} minutes")

    # Trim any overshoot from the final batch
    generated_array = np.concatenate(final_data, axis=0)[:num_samples]
    print("Generated IRCR shape:", generated_array.shape)
    return generated_array
114
+
115
+
116
+
diffcrysgen/trainer.py ADDED
@@ -0,0 +1,128 @@
1
+
2
+
3
+ import time
4
+ import torch
5
+ import numpy as np
6
+ import pandas as pd
7
+ from torch.utils.data import DataLoader
8
+ from torch.optim.lr_scheduler import ReduceLROnPlateau as RLROP
9
+ from tqdm.auto import tqdm
10
+ from model import UNet, Model
11
+ from utils import minmax
12
+
13
def load_data(train_path, test_path, batch_size=64):
    """Load train/test IRCR arrays, scale them to [-1, 1], and wrap them in
    DataLoaders.

    The full IRCR has shape (N, 142, 3) = {E, L, C, O, P}; the property
    block P is dropped (only the first 136 feature rows are kept) for
    unconditional generation.

    Returns:
        (train_loader, test_loader, train_scaler, test_scaler)

    NOTE(review): the test split is scaled with its own fitted scaler rather
    than the training scaler, and neither DataLoader shuffles.  Both choices
    are preserved here to keep behavior identical; worth confirming.
    """
    train_data = np.load(train_path)[:, :136, :]
    test_data = np.load(test_path)[:, :136, :]

    print(f"Training data shape: {train_data.shape}")
    print(f"Test data shape: {test_data.shape}")

    # Per-feature min-max scaling to [0, 1], then affine map to [-1, 1].
    train_scaled, train_scaler = minmax(train_data)
    test_scaled, test_scaler = minmax(test_data)
    train_scaled = 2 * train_scaled - 1
    test_scaled = 2 * test_scaled - 1

    # (N, F, C) -> (N, C, F) float tensors for the 1-D conv UNet.
    train_tensor = torch.from_numpy(train_scaled).float().permute(0, 2, 1)
    test_tensor = torch.from_numpy(test_scaled).float().permute(0, 2, 1)

    print(f"Transformed training shape: {train_tensor.shape}")
    print(f"Transformed test shape: {test_tensor.shape}")
    print(f"Train value range: {train_tensor.min().item()} to {train_tensor.max().item()}")

    train_loader = DataLoader(train_tensor, batch_size=batch_size)
    test_loader = DataLoader(test_tensor, batch_size=batch_size)

    return train_loader, test_loader, train_scaler, test_scaler
42
+
43
+
44
def train_diffcrysgen(
    train_loader,
    test_loader,
    num_epochs=1000,
    save_path="sdm.pt",
    initial_lr=5e-4,
    device=None,
):
    """Train the DiffCrysGen diffusion model and checkpoint the best weights.

    Args:
        train_loader: DataLoader yielding (batch, 3, 136) training tensors.
        test_loader: DataLoader for the held-out split.
        num_epochs (int): number of training epochs.
        save_path (str): where the best (lowest test-loss) state_dict is saved.
        initial_lr (float): Adam learning rate before the manual drops below.
        device: torch device; auto-selects CUDA when available if None.

    Side effects: saves the best checkpoint to ``save_path`` and writes the
    per-epoch losses to "epoch-loss.csv".  Returns None.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    denoise_fn = UNet(in_c=3, out_c=3, time_emb_dim=256).to(device)
    model = Model(denoise_fn=denoise_fn).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr, weight_decay=0)
    # NOTE(review): this scheduler is constructed but never stepped (the
    # scheduler.step calls below are commented out) — the effective LR
    # schedule is the manual epoch-based one inside the loop.
    scheduler = RLROP(optimizer, mode="min", factor=0.3, patience=50, verbose=True)

    best_loss = float("inf")
    train_losses = []
    test_losses = []
    lrs = []

    start_time = time.time()

    for epoch in tqdm(range(num_epochs), desc="Training Progress", colour="green"):
        # Manual LR scheduling based on epochs (step drops at 500 and 800)
        if epoch == 500:
            for group in optimizer.param_groups:
                group["lr"] = 1e-4
        elif epoch == 800:
            for group in optimizer.param_groups:
                group["lr"] = 5e-5

        model.train()
        train_loss = 0
        for batch in tqdm(train_loader, leave=False, desc=f"Epoch {epoch+1}/{num_epochs}", colour="blue"):
            batch = batch.to(device)
            # model(batch) returns the scalar EDM loss directly
            loss = model(batch)
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Average the summed batch losses over the epoch
        train_loss /= len(train_loader)
        train_losses.append(train_loss)

        # Evaluation pass (loss is still stochastic: EDMLoss samples noise)
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for batch in tqdm(test_loader, leave=False, desc=f"Epoch {epoch+1}/{num_epochs}", colour="blue"):
                batch = batch.to(device)
                loss = model(batch)
                test_loss += loss.item()
        test_loss /= len(test_loader)
        test_losses.append(test_loss)

        # step the scheduler
        # scheduler.step(test_loss)

        current_lr = optimizer.param_groups[0]["lr"]
        lrs.append(current_lr)

        # Checkpoint whenever the test loss improves
        log_msg = f"Epoch {epoch+1}: train_loss: {train_loss:.4f}, test_loss: {test_loss:.4f}, lr: {current_lr:.6f}"
        if test_loss < best_loss:
            best_loss = test_loss
            torch.save(model.state_dict(), save_path)
            log_msg += " --> Best model ever (stored)"
        print(log_msg)

        # scheduler.step(val_loss)

    print(f"Training completed in {(time.time() - start_time):.2f}s")

    # Save epoch-wise loss
    results_df = pd.DataFrame({
        "epoch": np.arange(num_epochs),
        "train_loss": train_losses,
        "test_loss": test_losses,
        "lr": lrs
    })
    results_df.to_csv("epoch-loss.csv", index=False)
    print(results_df.tail())
128
+
diffcrysgen/utils.py ADDED
@@ -0,0 +1,38 @@
1
+
2
+ import os
3
+ import joblib
4
+ import numpy as np
5
+ from sklearn.preprocessing import MinMaxScaler
6
+
7
+ # This is the normalization script adapted from FTCP work.
8
+ # It works for FTCP, IRCR.
9
+
10
def minmax(pcr):
    """Fit a MinMaxScaler across the feature axis and scale *pcr* to [0, 1].

    The (N, F, C) array is reshaped so each of the F feature rows is scaled
    over all N*C values it takes, then restored to the original layout.

    Returns:
        (scaled array with the same shape as *pcr*, fitted MinMaxScaler)
    """
    n_samples, n_features, n_channels = pcr.shape
    scaler = MinMaxScaler()

    # (N, F, C) -> (F, N*C): one row per feature across all samples/channels.
    flat = np.transpose(pcr, (1, 0, 2)).reshape(n_features, n_samples * n_channels)

    # MinMaxScaler scales columns, hence the transpose round-trip.
    scaled = scaler.fit_transform(flat.T).T

    # Restore the (N, F, C) layout.
    pcr_normed = np.transpose(scaled.reshape(n_features, n_samples, n_channels), (1, 0, 2))
    return pcr_normed, scaler
20
+
21
def inv_minmax(pcr_normed, scaler):
    """Invert the per-feature min-max scaling applied by ``minmax``.

    Args:
        pcr_normed: scaled array of shape (N, F, C).
        scaler: object exposing ``inverse_transform`` (the fitted MinMaxScaler).

    Returns:
        Array of shape (N, F, C) on the original scale.
    """
    n_samples, n_features, n_channels = pcr_normed.shape

    # Same (N, F, C) -> (F, N*C) flattening as the forward transform.
    flat = np.transpose(pcr_normed, (1, 0, 2)).reshape(n_features, n_samples * n_channels)
    restored = scaler.inverse_transform(flat.T).T

    # Restore the original (N, F, C) layout.
    return np.transpose(restored.reshape(n_features, n_samples, n_channels), (1, 0, 2))
30
+
31
def load_saved_diffusion_scaler(path=None):
    """Load the pickled diffusion MinMaxScaler shipped with the package.

    Args:
        path: explicit path to the .pkl file; when None, resolves
            ``../assets/ircr_diffusion_scaler.pkl`` relative to this module.

    Returns:
        The unpickled scaler object.
    """
    if path is None:
        package_dir = os.path.dirname(__file__)
        path = os.path.abspath(
            os.path.join(package_dir, "..", "assets", "ircr_diffusion_scaler.pkl")
        )
    return joblib.load(path)
37
+
38
+
@@ -0,0 +1,103 @@
1
+ Metadata-Version: 2.4
2
+ Name: diffcrysgen
3
+ Version: 0.1.0
4
+ Summary: "DiffCrysGen is a score-based diffusion model for accelerated design of diverse inorganic crystalline materials."
5
+ Home-page: https://github.com/SouravMal/DiffCrysGen.git
6
+ Classifier: Programming Language :: Python :: 3
7
+ Classifier: Programming Language :: Python :: 3.11
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: Topic :: Scientific/Engineering :: Chemistry
12
+ Classifier: Topic :: Scientific/Engineering :: Physics
13
+ Requires-Python: >=3.11
14
+ Description-Content-Type: text/markdown
15
+ License-File: LICENSE
16
+ Requires-Dist: torch~=2.5.1
17
+ Requires-Dist: numpy~=2.0.1
18
+ Requires-Dist: pandas~=2.2.3
19
+ Requires-Dist: scikit-learn~=1.6.1
20
+ Requires-Dist: tqdm
21
+ Requires-Dist: ase~=3.24.0
22
+ Requires-Dist: spglib~=2.5.0
23
+ Requires-Dist: joblib
24
+ Dynamic: license-file
25
+
26
+ # DiffCrysGen [![Project Version](https://img.shields.io/badge/version-v0.1.0-blue)](https://github.com/SouravMal/DiffCrysGen)
27
+
28
+ DiffCrysGen is a score-based diffusion model. It treats the entire materials representation with a single, unified diffusion process, allowing a single denoising neural network to predict a holistic score for the entire noisy crystal data. This unified treatment significantly simplifies the architecture and improves the computational efficiency.
29
+
30
+
31
+ <p align="center">
32
+ <img src="images/logo-DiffCrysGen.png" alt="DiffCrysGen Logo" width="350"/>
33
+ </p>
34
+
35
+
36
+ ## Generative diffusion framework in DiffCrysGen
37
+ <img src="images/diffusion-schematic.png" alt="DiffCrysGen Schematic" width="550">
38
+
39
+ ---
40
+
41
+ ## Architecture of the denoising neural network
42
+ <img src="images/architecture.png" alt="DiffCrysGen Architecture" width="750">
43
+
44
+ ---
45
+
46
+ ## Installation
47
+
48
+ ### Prerequisites
49
+
50
+ The package requires specific environments and dependencies.
51
+ Using a virtual environment is highly recommended.
52
+ **Environment Setup using Conda**
53
+
54
+ ```
55
+ conda create -n diffcrysgen python=3.11
56
+ conda activate diffcrysgen
57
+ ```
58
+
59
+ ### Install from PyPI
60
+ ```
61
+ pip install diffcrysgen
62
+ ```
63
+
64
+ ### Install from Source Code
65
+ ```
66
+ git clone https://github.com/SouravMal/DiffCrysGen.git
67
+ cd DiffCrysGen
68
+ pip install -e .
69
+ ```
70
+
71
+ ## Quick Start
72
+ For a simple walkthrough of generating materials and analyzing them, see the [DiffCrysGen Demo Notebook](./notebooks/DiffCrysGen-demo.ipynb).
73
+
74
+ ## License
75
+
76
+ This project is licensed under the **MIT License**.
77
+
78
+ See the [LICENSE](LICENSE) file for details.
79
+
80
+ Developed by: [Sourav Mal](https://github.com/SouravMal) at Harish-Chandra Research Institute (HRI) (https://www.hri.res.in/), Prayagraj, India.
81
+
82
+
83
+ ## Citation
84
+
85
+ Please consider citing our work if you find it helpful:
86
+
87
+ ```bibtex
88
+ @misc{mal2025generativediffusionmodeldiffcrysgen,
89
+ title={Generative Diffusion Model DiffCrysGen Discovers Rare Earth-Free Magnetic Materials},
90
+ author={Sourav Mal and Nehad Ahmed and Subhankar Mishra and Prasenjit Sen},
91
+ year={2025},
92
+ eprint={2510.12329},
93
+ archivePrefix={arXiv},
94
+ primaryClass={cond-mat.mtrl-sci},
95
+ url={https://arxiv.org/abs/2510.12329},
96
+ }
97
+ ```
98
+
99
+
100
+ ## Contact
101
+
102
+ If you have any questions, feel free to reach us at:
103
+ **Sourav Mal** <souravmal492@gmail.com>
@@ -0,0 +1,12 @@
1
+ diffcrysgen/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ diffcrysgen/analyze_generated_structures.py,sha256=fbybCbvH0o_xv_1v_tSV1OfbDEau7FHPs8oHyCAlf0Q,3475
3
+ diffcrysgen/invert_pcr.py,sha256=Q6j7AVb6iZWR7t3VjB1yFLnuPB9hCc4_0Vp1K72gcak,5944
4
+ diffcrysgen/model.py,sha256=sb42SQJPm9nQNSGMNN88RXwWE0tIsompypFZxmJob3k,11166
5
+ diffcrysgen/sampler.py,sha256=k5r2D0MmWcQRebcg0KBugl1AEg6FPrnvfobZ_uyznP0,4092
6
+ diffcrysgen/trainer.py,sha256=SI988DGpG0rgfixV9Sz9Ue1Lk5qG6uZvC1ka-BP9KOA,4189
7
+ diffcrysgen/utils.py,sha256=FI6bkDGMt7E5_WbL55la--MGkOryz11KBgmVfivBvao,1094
8
+ diffcrysgen-0.1.0.dist-info/licenses/LICENSE,sha256=7WpR2tbfkkPUcuvEnP_B3ijJczHpi-_U_i2ngeJGqSY,1067
9
+ diffcrysgen-0.1.0.dist-info/METADATA,sha256=KhTsxU3Q6uqtvBWe9UrwHrrkUcFyMDvOJrcCt3x-_t4,3209
10
+ diffcrysgen-0.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
11
+ diffcrysgen-0.1.0.dist-info/top_level.txt,sha256=aN4XuKXo6kncvD9N8mx9SOVKSFfNqBglK7G9jm5u3Fs,12
12
+ diffcrysgen-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Sourav Mal
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1 @@
1
+ diffcrysgen