calphy 1.4.4__py3-none-any.whl → 1.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
calphy/helpers.py CHANGED
@@ -3,14 +3,14 @@ calphy: a Python library and command line interface for automated free
3
3
  energy calculations.
4
4
 
5
5
  Copyright 2021 (c) Sarath Menon^1, Yury Lysogorskiy^2, Ralf Drautz^2
6
- ^1: Max Planck Institut für Eisenforschung, Dusseldorf, Germany
6
+ ^1: Max Planck Institut für Eisenforschung, Dusseldorf, Germany
7
7
  ^2: Ruhr-University Bochum, Bochum, Germany
8
8
 
9
- calphy is published and distributed under the Academic Software License v1.0 (ASL).
10
- calphy is distributed in the hope that it will be useful for non-commercial academic research,
11
- but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
9
+ calphy is published and distributed under the Academic Software License v1.0 (ASL).
10
+ calphy is distributed in the hope that it will be useful for non-commercial academic research,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
12
12
  calphy API is published and distributed under the BSD 3-Clause "New" or "Revised" License
13
- See the LICENSE FILE for more details.
13
+ See the LICENSE FILE for more details.
14
14
 
15
15
  More information about the program can be found in:
16
16
  Menon, Sarath, Yury Lysogorskiy, Jutta Rogal, and Ralf Drautz.
@@ -34,6 +34,7 @@ from ase.io import read, write
34
34
  import pyscal3.core as pc
35
35
  from pyscal3.trajectory import Trajectory
36
36
 
37
+
37
38
  class LammpsScript:
38
39
  def __init__(self):
39
40
  self.script = []
@@ -42,12 +43,14 @@ class LammpsScript:
42
43
  self.script.append(command_str)
43
44
 
def write(self, infile):
    """
    Dump the accumulated LAMMPS commands to a script file.

    Every entry of ``self.script`` is written on its own line, in the
    order it was appended. An existing file at ``infile`` is overwritten.

    Parameters
    ----------
    infile : str
        Path of the LAMMPS input script to create.
    """
    with open(infile, "w") as fout:
        fout.writelines(f"{line}\n" for line in self.script)
49
+
48
50
 
49
- def create_object(cores, directory, timestep, cmdargs="",
50
- init_commands=(), script_mode=False):
51
+ def create_object(
52
+ cores, directory, timestep, cmdargs="", init_commands=(), script_mode=False
53
+ ):
51
54
  """
52
55
  Create LAMMPS object
53
56
 
@@ -71,28 +74,28 @@ def create_object(cores, directory, timestep, cmdargs="",
71
74
  else:
72
75
  if cmdargs == "":
73
76
  cmdargs = None
74
- lmp = LammpsLibrary(
75
- cores=cores, working_directory=directory, cmdargs=cmdargs
76
- )
77
+ lmp = LammpsLibrary(cores=cores, working_directory=directory, cmdargs=cmdargs)
77
78
 
78
- commands = [["units", "metal"],
79
- ["boundary", "p p p"],
80
- ["atom_style", "atomic"],
81
- ["timestep", str(timestep)],
82
- ["box", "tilt large"]]
79
+ commands = [
80
+ ["units", "metal"],
81
+ ["boundary", "p p p"],
82
+ ["atom_style", "atomic"],
83
+ ["timestep", str(timestep)],
84
+ ["box", "tilt large"],
85
+ ]
83
86
 
84
87
  if len(init_commands) > 0:
85
- #we need to replace some initial commands
88
+ # we need to replace some initial commands
86
89
  for rc in init_commands:
87
- #split the command
90
+ # split the command
88
91
  raw = rc.split()
89
92
  for x in range(len(commands)):
90
93
  if raw[0] == commands[x][0]:
91
- #we found a matching command
94
+ # we found a matching command
92
95
  commands[x] = [rc]
93
96
  break
94
97
  else:
95
- #its a new command, add it to the list
98
+ # its a new command, add it to the list
96
99
  commands.append([rc])
97
100
 
98
101
  for command in commands:
@@ -121,12 +124,12 @@ def create_structure(lmp, calc):
121
124
 
122
125
 
def set_mass(lmp, options):
    """
    Issue LAMMPS ``mass`` commands for the configured elements.

    Outside composition-scaling mode, atom type ``k`` (1-based) receives
    ``options.mass[k-1]`` for each of the ``options.n_elements`` types. In
    composition-scaling mode a single ``mass *`` command assigns the last
    entry of ``options.mass`` to every type.

    Parameters
    ----------
    lmp : LammpsLibrary object
    options : calculation options; ``mode``, ``mass`` and ``n_elements`` are read

    Returns
    -------
    lmp : LammpsLibrary object
    """
    if options.mode != "composition_scaling":
        for type_index in range(options.n_elements):
            lmp.command(f"mass {type_index+1} {options.mass[type_index]}")
        return lmp

    # composition scaling: all types share the final mass value
    lmp.command(f"mass * {options.mass[-1]}")
    return lmp
131
134
 
132
135
 
@@ -144,19 +147,21 @@ def set_potential(lmp, options):
144
147
  -------
145
148
  lmp : LammpsLibrary object
146
149
  """
147
- #lmp.pair_style(options.pair_style_with_options[0])
148
- #lmp.pair_coeff(options.pair_coeff[0])
149
- lmp.command(f'pair_style {options._pair_style_with_options[0]}')
150
- lmp.command(f'pair_coeff {options.pair_coeff[0]}')
150
+ # lmp.pair_style(options.pair_style_with_options[0])
151
+ # lmp.pair_coeff(options.pair_coeff[0])
152
+ lmp.command(f"pair_style {options._pair_style_with_options[0]}")
153
+ lmp.command(f"pair_coeff {options.pair_coeff[0]}")
151
154
 
152
155
  lmp = set_mass(lmp, options)
153
156
 
154
157
  return lmp
155
158
 
159
+
def read_data(lmp, file):
    """
    Load a LAMMPS data file into the running LAMMPS instance.

    Parameters
    ----------
    lmp : LammpsLibrary object
    file : path of the data file, passed verbatim to ``read_data``

    Returns
    -------
    lmp : LammpsLibrary object
    """
    read_command = f"read_data {file}"
    lmp.command(read_command)
    return lmp
159
163
 
164
+
160
165
  def get_structures(file, species, index=None):
161
166
  traj = Trajectory(file)
162
167
  if index is None:
@@ -165,6 +170,7 @@ def get_structures(file, species, index=None):
165
170
  aseobjs = traj[index].to_ase(species=species)
166
171
  return aseobjs
167
172
 
173
+
168
174
  def remap_box(lmp, x, y, z):
169
175
  lmp.command("run 0")
170
176
  lmp.command(
@@ -222,8 +228,15 @@ def write_data(lmp, file):
222
228
  lmp.command(f"write_data {file}")
223
229
  return lmp
224
230
 
231
+
225
232
  def prepare_log(file, screen=False):
226
233
  logger = logging.getLogger(__name__)
234
+
235
+ # Remove all existing handlers to prevent duplicate logging
236
+ for handler in logger.handlers[:]:
237
+ handler.close()
238
+ logger.removeHandler(handler)
239
+
227
240
  handler = logging.FileHandler(file)
228
241
  formatter = logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
229
242
  handler.setFormatter(formatter)
@@ -238,6 +251,7 @@ def prepare_log(file, screen=False):
238
251
  logger.addHandler(scr)
239
252
  return logger
240
253
 
254
+
241
255
  def check_if_any_is_none(data):
242
256
  """
243
257
  Check if any elements of a list is None, if so return True
calphy/input.py CHANGED
@@ -49,7 +49,7 @@ from pyscal3.core import structure_dict, element_dict, _make_crystal
49
49
  from ase.io import read, write
50
50
  import shutil
51
51
 
52
- __version__ = "1.4.4"
52
+ __version__ = "1.4.6"
53
53
 
54
54
 
55
55
  def _check_equal(val):
@@ -92,6 +92,49 @@ def _to_float(val):
92
92
  return [float(x) for x in val]
93
93
 
94
94
 
def _is_element_symbol(token):
    """
    Return True if ``token`` is a chemical element symbol known to mendeleev.

    Element symbols are 1-2 characters and start with an uppercase letter.
    That cheap check runs first, so wildcards, numbers and file names never
    trigger a mendeleev lookup.
    """
    if len(token) > 2 or not token[0].isupper():
        return False
    try:
        mendeleev.element(token)
    except Exception:
        # Any lookup failure is treated as "not an element"
        return False
    return True


def _extract_elements_from_pair_coeff(pair_coeff_string):
    """
    Extract element symbols from pair_coeff string.
    Returns None if pair_coeff doesn't contain element specifications.

    Only the trailing run of element symbols is returned. For many-body
    pair styles, LAMMPS maps elements onto atom types using the *last*
    arguments of ``pair_coeff * * <file(s)> el1 el2 ...``.

    Some styles list element symbols more than once. With ``meam``, for
    example, ``* * library.meam Cu Zr CuZr.meam Cu Zr`` would give
    ``[Cu, Zr, Cu, Zr]`` if every symbol were collected, and that triggers a
    spurious ordering error. The trailing run ``[Cu, Zr]`` is the actual
    type mapping.

    If the command ends in a non-element token (for example ``NULL`` in a
    hybrid mapping), no elements are returned.

    Parameters
    ----------
    pair_coeff_string : str
        The pair_coeff command string (e.g., "* * potential.eam.fs Cu Zr")

    Returns
    -------
    list or None
        List of element symbols in order, or None if no elements found
    """
    if pair_coeff_string is None:
        return None

    elements = []
    # Walk backwards from the end of the command and stop at the first token
    # that is not an element symbol (file name, wildcard, number, ...)
    for token in reversed(pair_coeff_string.strip().split()):
        if not _is_element_symbol(token):
            break
        elements.append(token)
    elements.reverse()

    return elements if elements else None
136
+
137
+
95
138
  class UFMP(BaseModel, title="UFM potential input options"):
96
139
  p: Annotated[float, Field(default=50.0)]
97
140
  sigma: Annotated[float, Field(default=1.5)]
@@ -165,7 +208,6 @@ class Queue(BaseModel, title="Options for configuring queue"):
165
208
  memory: Annotated[str, Field(default="3GB")]
166
209
  commands: Annotated[List, Field(default=[])]
167
210
  options: Annotated[List, Field(default=[])]
168
- modules: Annotated[List, Field(default=[])]
169
211
 
170
212
 
171
213
  class Tolerance(BaseModel, title="Tolerance settings for convergence"):
@@ -181,10 +223,17 @@ class MeltingTemperature(BaseModel, title="Input options for melting temperature
181
223
  step: Annotated[int, Field(default=200, ge=20)]
182
224
  attempts: Annotated[int, Field(default=5, ge=1)]
183
225
 
184
- class MaterialsProject(BaseModel, title='Input options for materials project'):
226
+
227
+ class MaterialsProject(BaseModel, title="Input options for materials project"):
185
228
  api_key: Annotated[str, Field(default="", exclude=True)]
186
229
  conventional: Annotated[bool, Field(default=True)]
187
- target_natoms: Annotated[int, Field(default=1500, description='The structure parsed from materials project would be repeated to approximately this value')]
230
+ target_natoms: Annotated[
231
+ int,
232
+ Field(
233
+ default=1500,
234
+ description="The structure parsed from materials project would be repeated to approximately this value",
235
+ ),
236
+ ]
188
237
 
189
238
  @field_validator("api_key", mode="after")
190
239
  def resolve_api_key(cls, v: str) -> str:
@@ -198,6 +247,7 @@ class MaterialsProject(BaseModel, title='Input options for materials project'):
198
247
  )
199
248
  return value
200
249
 
250
+
201
251
  class Calculation(BaseModel, title="Main input class"):
202
252
  monte_carlo: Optional[MonteCarlo] = MonteCarlo()
203
253
  composition_scaling: Optional[CompositionScaling] = CompositionScaling()
@@ -306,6 +356,43 @@ class Calculation(BaseModel, title="Main input class"):
306
356
 
307
357
  self.n_elements = len(self.element)
308
358
 
359
+ # Validate element/mass/pair_coeff ordering consistency
360
+ # This is critical for multi-element systems where LAMMPS type numbers
361
+ # are assigned based on element order: element[0]=Type1, element[1]=Type2, etc.
362
+ if (
363
+ len(self.element) > 1
364
+ and self.pair_coeff is not None
365
+ and len(self.pair_coeff) > 0
366
+ ):
367
+ extracted_elements = _extract_elements_from_pair_coeff(self.pair_coeff[0])
368
+
369
+ if extracted_elements is not None:
370
+ # pair_coeff specifies elements - check ordering
371
+ if set(extracted_elements) != set(self.element):
372
+ raise ValueError(
373
+ f"Element mismatch between 'element' and 'pair_coeff'!\n"
374
+ f" element: {self.element}\n"
375
+ f" pair_coeff: {extracted_elements}\n"
376
+ f"The elements specified must be the same."
377
+ )
378
+
379
+ if list(extracted_elements) != list(self.element):
380
+ raise ValueError(
381
+ f"Element ordering mismatch detected!\n\n"
382
+ f" element: {self.element}\n"
383
+ f" pair_coeff: {extracted_elements}\n"
384
+ f" mass: {self.mass}\n\n"
385
+ f"For multi-element systems, all three must be in the SAME order.\n\n"
386
+ f"Why this matters:\n"
387
+ f" - Element order determines LAMMPS type numbers:\n"
388
+ f" element[0] → Type 1, element[1] → Type 2, etc.\n"
389
+ f" - The pair_coeff elements must match this type order\n"
390
+ f" - The mass values must correspond to the same order\n"
391
+ f" - Composition transformations depend on this ordering\n\n"
392
+ f"Please reorder your input so element, mass, and pair_coeff\n"
393
+ f"all use the same element ordering."
394
+ )
395
+
309
396
  self._pressure_input = copy.copy(self.pressure)
310
397
  if self.pressure is None:
311
398
  self._iso = True
@@ -515,40 +602,58 @@ class Calculation(BaseModel, title="Main input class"):
515
602
  self._original_lattice = self.lattice.lower()
516
603
  write_structure_file = True
517
604
 
518
- elif self.lattice.split('-')[0] == 'mp':
519
- #confirm here that API key exists
605
+ elif self.lattice.split("-")[0] == "mp":
606
+ # confirm here that API key exists
520
607
  if not self.materials_project.api_key:
521
- raise ValueError('could not find API KEY, pls set it.')
522
- #now we need to fetch the structure
608
+ raise ValueError("could not find API KEY, pls set it.")
609
+ # now we need to fetch the structure
523
610
  try:
524
611
  from mp_api.client import MPRester
525
612
  except ImportError:
526
- raise ImportError('Could not import mp_api, make sure you install mp_api package!')
527
- #now all good
613
+ raise ImportError(
614
+ "Could not import mp_api, make sure you install mp_api package!"
615
+ )
616
+ # now all good
528
617
  rest = {
529
- "use_document_model": False,
530
- "include_user_agent": True,
531
- "api_key": self.materials_project.api_key,
532
- }
618
+ "use_document_model": False,
619
+ "include_user_agent": True,
620
+ "api_key": self.materials_project.api_key,
621
+ }
533
622
  with MPRester(**rest) as mpr:
534
623
  docs = mpr.materials.summary.search(material_ids=[self.lattice])
535
624
 
536
625
  structures = []
537
626
  for doc in docs:
538
- struct = doc['structure']
627
+ struct = doc["structure"]
539
628
  if self.materials_project.conventional:
540
629
  aseatoms = struct.to_conventional().to_ase_atoms()
541
630
  else:
542
631
  aseatoms = struct.to_primitive().to_ase_atoms()
543
632
  structures.append(aseatoms)
544
633
  structure = structures[0]
545
-
634
+
546
635
  if np.prod(self.repeat) == 1:
547
- x = int(np.ceil((self.materials_project.target_natoms/len(structure))**(1/3)))
636
+ x = int(
637
+ np.ceil(
638
+ (self.materials_project.target_natoms / len(structure))
639
+ ** (1 / 3)
640
+ )
641
+ )
548
642
  structure = structure.repeat(x)
549
643
  else:
550
644
  structure = structure.repeat(self.repeat)
551
645
 
646
+ # extract composition
647
+ types, typecounts = np.unique(
648
+ structure.get_chemical_symbols(), return_counts=True
649
+ )
650
+
651
+ for c, t in enumerate(types):
652
+ self._element_dict[t]["count"] = typecounts[c]
653
+ self._element_dict[t]["composition"] = typecounts[c] / np.sum(
654
+ typecounts
655
+ )
656
+
552
657
  self._natoms = len(structure)
553
658
  self._original_lattice = self.lattice.lower()
554
659
  write_structure_file = True
calphy/phase_diagram.py CHANGED
@@ -254,6 +254,141 @@ matcolors = {
254
254
  }
255
255
  }
256
256
 
def read_structure_composition(lattice_file, element_list):
    """
    Read a LAMMPS data file and determine the input chemical composition.

    Parameters
    ----------
    lattice_file : str
        Path to the LAMMPS data file
    element_list : list
        List of element symbols in order (element[0] = type 1, element[1] = type 2, etc.)

    Returns
    -------
    dict
        Dictionary mapping element symbols to atom counts
        Elements not present in the structure will have count 0

    Raises
    ------
    ValueError
        If the structure contains LAMMPS atom types with no corresponding
        entry in ``element_list``. Those atoms cannot be assigned an element,
        and silently dropping them would give a wrong atom count.
    """
    from ase.io import read
    from collections import Counter

    # Read the structure file
    structure = read(lattice_file, format='lammps-data', style='atomic')

    # Per-atom LAMMPS type ids. ASE's lammps-data reader keeps the raw type
    # ids in the "type" array, so prefer it (then "species") over atomic
    # numbers. Atomic numbers equal the type ids only when ASE has not
    # mapped the types onto real elements.
    for key in ('type', 'species'):
        if key in structure.arrays:
            types_in_structure = structure.arrays[key]
            break
    else:
        types_in_structure = structure.get_atomic_numbers()

    # Normalise ids to '1', '2', ... whether they are stored as ints or strings
    type_counts = Counter(str(t) for t in types_in_structure)

    # element[0] corresponds to LAMMPS type '1', element[1] to type '2', etc.
    known_types = {str(idx + 1) for idx in range(len(element_list))}
    unknown_types = sorted(set(type_counts) - known_types)
    if unknown_types:
        raise ValueError(
            f"Structure file {lattice_file} contains LAMMPS atom types "
            f"{unknown_types} that have no matching entry in element list "
            f"{element_list}"
        )

    input_chemical_composition = {
        element: type_counts.get(str(idx + 1), 0)
        for idx, element in enumerate(element_list)
    }
    return input_chemical_composition
299
+
300
+
# Constants for phase diagram preparation
COMPOSITION_TOLERANCE = 1E-5


def _create_composition_array(comp_range, interval, reference):
    """
    Create composition array from range specification.

    Parameters
    ----------
    comp_range : list, tuple or scalar
        Composition range [min, max] or single value
    interval : float
        Composition interval
    reference : float
        Reference composition value

    Returns
    -------
    tuple
        (comp_arr, is_reference) - composition array and boolean array marking
        reference compositions. For a two-value range the end point is always
        the last entry, and it appears exactly once.

    Raises
    ------
    ValueError
        If ``comp_range`` has neither one nor two values.
    """
    # Accept tuples as well as lists; anything else is a single value
    if isinstance(comp_range, tuple):
        comp_range = list(comp_range)
    elif not isinstance(comp_range, list):
        comp_range = [comp_range]

    if len(comp_range) == 2:
        last_val = comp_range[-1]
        comp_arr = np.arange(comp_range[0], last_val, interval)
        # np.arange with a float step can overshoot and end at a value that is
        # only approximately equal to the end point (e.g. 0.8000000000000002
        # instead of 0.8). An exact `in` test would then append a
        # near-duplicate that maps to the same folder prefix. Drop anything
        # within tolerance of the end point and append the exact value once.
        comp_arr = comp_arr[np.abs(comp_arr - last_val) >= COMPOSITION_TOLERANCE]
        comp_arr = np.append(comp_arr, last_val)
        is_reference = np.abs(comp_arr - reference) < COMPOSITION_TOLERANCE
    elif len(comp_range) == 1:
        comp_arr = [comp_range[0]]
        is_reference = [True]
    else:
        raise ValueError("Composition range should be scalar or list of two values!")

    return comp_arr, is_reference
340
+
341
+
def _create_temperature_array(temp_range, interval):
    """
    Create temperature array from range specification.

    For a two-value range, ``int((max - min) / interval) + 1`` points are
    spread evenly (``np.linspace``, end point included) between min and max.

    Parameters
    ----------
    temp_range : list, tuple or scalar
        Temperature range [min, max] or single value
    interval : float
        Temperature interval

    Returns
    -------
    ndarray or list
        Temperature array (a one-element list for a scalar input)

    Raises
    ------
    ValueError
        If ``temp_range`` has neither one nor two values.
    """
    # Accept tuples as well as lists; anything else is a single value
    if isinstance(temp_range, tuple):
        temp_range = list(temp_range)
    elif not isinstance(temp_range, list):
        temp_range = [temp_range]

    if len(temp_range) == 2:
        nsteps = (temp_range[-1] - temp_range[0]) / interval
        # Guard against floating-point round-off: e.g. 0.3 / 0.1 evaluates to
        # 2.9999999999999996, which int() would truncate to 2 and so silently
        # drop one temperature point.
        nearest = round(nsteps)
        if abs(nsteps - nearest) < 1e-8:
            nsteps = nearest
        ntemps = int(nsteps) + 1
        temp_arr = np.linspace(temp_range[0], temp_range[-1], ntemps, endpoint=True)
    elif len(temp_range) == 1:
        temp_arr = [temp_range[0]]
    else:
        raise ValueError("Temperature range should be scalar or list of two values!")

    return temp_arr
371
+
372
+
def _add_temperature_calculations(calc_dict, temp_arr, all_calculations):
    """
    Append one calculation entry per temperature point.

    Each new entry is an independent deep copy of ``calc_dict`` whose
    ``'temperature'`` key is set to ``int(temp)``. ``int`` truncates and
    does not round. ``calc_dict`` itself is not modified; ``all_calculations``
    is extended in place and nothing is returned.

    Parameters
    ----------
    calc_dict : dict
        Base calculation dictionary
    temp_arr : array
        Array of temperatures
    all_calculations : list
        List to append calculations to
    """
    def _with_temperature(value):
        # a fresh deep copy per point so entries never share nested objects
        entry = copy.deepcopy(calc_dict)
        entry['temperature'] = int(value)
        return entry

    all_calculations.extend(map(_with_temperature, temp_arr))
390
+
391
+
257
392
  def fix_data_file(datafile, nelements):
258
393
  """
259
394
  Change the atom types keyword in the structure file
@@ -309,6 +444,31 @@ def prepare_inputs_for_phase_diagram(inputyamlfile, calculation_base_name=None):
309
444
  calculation_base_name = inputyamlfile
310
445
 
311
446
  for phase in data['phases']:
447
+ # Validate binary system assumption
448
+ n_elements = len(phase['element'])
449
+ if n_elements != 2:
450
+ raise ValueError(
451
+ f"Phase diagram preparation currently supports only binary systems. "
452
+ f"Found {n_elements} elements: {phase['element']}"
453
+ )
454
+
455
+ # Validate element ordering consistency with pair_coeff
456
+ # This ensures element[0] -> type 1, element[1] -> type 2
457
+ if 'pair_coeff' in phase:
458
+ from calphy.input import _extract_elements_from_pair_coeff
459
+ # pair_coeff can be a list or a string - handle both
460
+ pair_coeff = phase['pair_coeff']
461
+ if isinstance(pair_coeff, list):
462
+ pair_coeff = pair_coeff[0] if pair_coeff else None
463
+ pair_coeff_elements = _extract_elements_from_pair_coeff(pair_coeff)
464
+ if pair_coeff_elements != phase['element']:
465
+ raise ValueError(
466
+ f"Element ordering mismatch for phase '{phase.get('phase_name', 'unnamed')}'!\n"
467
+ f"Elements in 'element' field: {phase['element']}\n"
468
+ f"Elements from pair_coeff: {pair_coeff_elements}\n"
469
+ f"These must match exactly in order (element[0] -> LAMMPS type 1, element[1] -> type 2)."
470
+ )
471
+
312
472
  phase_reference_state = phase['reference_phase']
313
473
  phase_name = phase['phase_name']
314
474
 
@@ -325,34 +485,18 @@ def prepare_inputs_for_phase_diagram(inputyamlfile, calculation_base_name=None):
325
485
  other_element_list.remove(reference_element)
326
486
  other_element = other_element_list[0]
327
487
 
328
- #convert to list if scalar
329
- if not isinstance(comps['range'], list):
330
- comps["range"] = [comps["range"]]
331
- if len(comps["range"]) == 2:
332
- comp_arr = np.arange(comps['range'][0], comps['range'][-1], comps['interval'])
333
- last_val = comps['range'][-1]
334
- if last_val not in comp_arr:
335
- comp_arr = np.append(comp_arr, last_val)
336
- ncomps = len(comp_arr)
337
- is_reference = np.abs(comp_arr-comps['reference']) < 1E-5
338
- elif len(comps["range"]) == 1:
339
- ncomps = 1
340
- comp_arr = [comps["range"][0]]
341
- is_reference = [True]
342
- else:
343
- raise ValueError("Composition range should be scalar of list of two values!")
488
+ # Create composition array using helper function
489
+ comp_arr, is_reference = _create_composition_array(
490
+ comps['range'],
491
+ comps['interval'],
492
+ comps['reference']
493
+ )
494
+ ncomps = len(comp_arr)
344
495
 
496
+ # Create temperature array using helper function
345
497
  temps = phase["temperature"]
346
- if not isinstance(temps['range'], list):
347
- temps["range"] = [temps["range"]]
348
- if len(temps["range"]) == 2:
349
- ntemps = int((temps['range'][-1]-temps['range'][0])/temps['interval'])+1
350
- temp_arr = np.linspace(temps['range'][0], temps['range'][-1], ntemps, endpoint=True)
351
- elif len(temps["range"]) == 1:
352
- ntemps = 1
353
- temp_arr = [temps["range"][0]]
354
- else:
355
- raise ValueError("Temperature range should be scalar of list of two values!")
498
+ temp_arr = _create_temperature_array(temps['range'], temps['interval'])
499
+ ntemps = len(temp_arr)
356
500
 
357
501
  all_calculations = []
358
502
 
@@ -372,27 +516,30 @@ def prepare_inputs_for_phase_diagram(inputyamlfile, calculation_base_name=None):
372
516
  outfile = fix_data_file(calc['lattice'], len(calc['element']))
373
517
 
374
518
  #add ref phase, needed
375
- calc['reference_phase'] = str(phase_reference_state)
519
+ calc['reference_phase'] = phase_reference_state
376
520
  calc['reference_composition'] = comps['reference']
377
- calc['mode'] = str('fe')
521
+ calc['mode'] = 'fe'
378
522
  calc['folder_prefix'] = f'{phase_name}-{comp:.2f}'
379
- calc['lattice'] = str(outfile)
523
+ calc['lattice'] = outfile
380
524
 
381
- #now we need to run this for different temp
382
- for temp in temp_arr:
383
- calc_for_temp = copy.deepcopy(calc)
384
- calc_for_temp['temperature'] = int(temp)
385
- all_calculations.append(calc_for_temp)
525
+ # Add calculations for each temperature
526
+ _add_temperature_calculations(calc, temp_arr, all_calculations)
386
527
  else:
387
528
  #off stoichiometric
388
529
  #copy the dict
389
530
  calc = copy.deepcopy(phase)
390
531
 
391
- #first thing first, we need to calculate the number of atoms
392
- #we follow the convention that composition is always given with the second species
393
- n_atoms = np.sum(calc['composition']['number_of_atoms'])
532
+ #read the structure file to determine input composition automatically
533
+ input_chemical_composition = read_structure_composition(calc['lattice'], calc['element'])
534
+
535
+ #calculate total number of atoms from structure
536
+ n_atoms = sum(input_chemical_composition.values())
537
+
538
+ if n_atoms == 0:
539
+ raise ValueError(f"No atoms found in structure file {calc['lattice']}")
394
540
 
395
- #find number of atoms of second species
541
+ #find number of atoms of second species based on target composition
542
+ #we follow the convention that composition is always given with the reference element
396
543
  output_chemical_composition = {}
397
544
  n_species_b = int(np.round(comp*n_atoms, decimals=0))
398
545
  output_chemical_composition[reference_element] = n_species_b
@@ -400,11 +547,8 @@ def prepare_inputs_for_phase_diagram(inputyamlfile, calculation_base_name=None):
400
547
  n_species_a = int(n_atoms-n_species_b)
401
548
  output_chemical_composition[other_element] = n_species_a
402
549
 
403
- if n_species_a == 0:
404
- raise ValueError("Please add pure phase as a new entry!")
405
- #create input comp dict and output comp dict
406
- input_chemical_composition = {element:number for element, number in zip(calc['element'],
407
- calc['composition']['number_of_atoms'])}
550
+ # Note: Pure phases (n_species_a == 0 or n_species_b == 0) are allowed
551
+ # Composition transformation can handle 100% replacement
408
552
 
409
553
  #good, now we need to write such a structure out; likely better to use working directory for that
410
554
  folder_prefix = f'{phase_name}-{comp:.2f}'
@@ -421,7 +565,7 @@ def prepare_inputs_for_phase_diagram(inputyamlfile, calculation_base_name=None):
421
565
 
422
566
  #just submit comp scales
423
567
  #add ref phase, needed
424
- calc['mode'] = str('composition_scaling')
568
+ calc['mode'] = 'composition_scaling'
425
569
  calc['folder_prefix'] = folder_prefix
426
570
  calc['composition_scaling'] = {}
427
571
  calc['composition_scaling']['output_chemical_composition'] = output_chemical_composition
@@ -447,22 +591,20 @@ def prepare_inputs_for_phase_diagram(inputyamlfile, calculation_base_name=None):
447
591
  _ = calc.pop(key, None)
448
592
 
449
593
  #add ref phase, needed
450
- calc['mode'] = str('fe')
594
+ calc['mode'] = 'fe'
451
595
  calc['folder_prefix'] = folder_prefix
452
- calc['lattice'] = str(outfile)
596
+ calc['lattice'] = outfile
453
597
 
454
- #now we need to run this for different temp
455
- for temp in temp_arr:
456
- calc_for_temp = copy.deepcopy(calc)
457
- calc_for_temp['temperature'] = int(temp)
458
- all_calculations.append(calc_for_temp)
598
+ # Add calculations for each temperature
599
+ _add_temperature_calculations(calc, temp_arr, all_calculations)
459
600
 
460
601
  #finish and write up the file
461
602
  output_data = {"calculations": all_calculations}
603
+ base_name = os.path.basename(calculation_base_name)
462
604
  for rep in ['.yml', '.yaml']:
463
- calculation_base_name = calculation_base_name.replace(rep, '')
605
+ base_name = base_name.replace(rep, '')
464
606
 
465
- outfile_phase = phase_name + '_' + calculation_base_name + ".yaml"
607
+ outfile_phase = phase_name + '_' + base_name + ".yaml"
466
608
  with open(outfile_phase, 'w') as fout:
467
609
  yaml.safe_dump(output_data, fout)
468
610
  print(f'Total {len(all_calculations)} calculations found for phase {phase_name}, written to {outfile_phase}')