servalcat 0.4.131__cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. servalcat/__init__.py +10 -0
  2. servalcat/__main__.py +120 -0
  3. servalcat/ext.cpython-314t-x86_64-linux-gnu.so +0 -0
  4. servalcat/refine/__init__.py +0 -0
  5. servalcat/refine/cgsolve.py +100 -0
  6. servalcat/refine/refine.py +1162 -0
  7. servalcat/refine/refine_geom.py +245 -0
  8. servalcat/refine/refine_spa.py +400 -0
  9. servalcat/refine/refine_xtal.py +339 -0
  10. servalcat/refine/spa.py +151 -0
  11. servalcat/refine/xtal.py +312 -0
  12. servalcat/refmac/__init__.py +0 -0
  13. servalcat/refmac/exte.py +191 -0
  14. servalcat/refmac/refmac_keywords.py +660 -0
  15. servalcat/refmac/refmac_wrapper.py +423 -0
  16. servalcat/spa/__init__.py +0 -0
  17. servalcat/spa/fofc.py +488 -0
  18. servalcat/spa/fsc.py +391 -0
  19. servalcat/spa/localcc.py +197 -0
  20. servalcat/spa/realspcc_from_var.py +128 -0
  21. servalcat/spa/run_refmac.py +979 -0
  22. servalcat/spa/shift_maps.py +293 -0
  23. servalcat/spa/shiftback.py +137 -0
  24. servalcat/spa/translate.py +129 -0
  25. servalcat/utils/__init__.py +35 -0
  26. servalcat/utils/commands.py +1629 -0
  27. servalcat/utils/fileio.py +836 -0
  28. servalcat/utils/generate_operators.py +296 -0
  29. servalcat/utils/hkl.py +811 -0
  30. servalcat/utils/logger.py +140 -0
  31. servalcat/utils/maps.py +345 -0
  32. servalcat/utils/model.py +933 -0
  33. servalcat/utils/refmac.py +759 -0
  34. servalcat/utils/restraints.py +888 -0
  35. servalcat/utils/symmetry.py +298 -0
  36. servalcat/xtal/__init__.py +0 -0
  37. servalcat/xtal/french_wilson.py +262 -0
  38. servalcat/xtal/run_refmac_small.py +240 -0
  39. servalcat/xtal/sigmaa.py +1954 -0
  40. servalcat/xtal/twin.py +316 -0
  41. servalcat-0.4.131.dist-info/METADATA +60 -0
  42. servalcat-0.4.131.dist-info/RECORD +45 -0
  43. servalcat-0.4.131.dist-info/WHEEL +6 -0
  44. servalcat-0.4.131.dist-info/entry_points.txt +4 -0
  45. servalcat-0.4.131.dist-info/licenses/LICENSE +373 -0
@@ -0,0 +1,1162 @@
1
+ """
2
+ Author: "Keitaro Yamashita, Garib N. Murshudov"
3
+ MRC Laboratory of Molecular Biology
4
+
5
+ This software is released under the
6
+ Mozilla Public License, version 2.0; see LICENSE.
7
+ """
8
+ from __future__ import absolute_import, division, print_function, generators
9
+ import os
10
+ import time
11
+ import re
12
+ import gemmi
13
+ import numpy
14
+ import json
15
+ import pandas
16
+ import scipy.sparse
17
+ from dataclasses import dataclass, field
18
+ from typing import List, Dict
19
+ import omegaconf
20
+ import servalcat # for version
21
+ from servalcat.utils import logger
22
+ from servalcat import utils
23
+ from servalcat.refmac import exte
24
+ from servalcat import ext
25
+ from . import cgsolve
# Short aliases used throughout this module.
u_to_b = utils.model.u_to_b
b_to_u = utils.model.b_to_u
Type = ext.RefineParams.Type  # parameter kinds used below: Type.X, Type.B, Type.Q, Type.D

# Uncomment to enable line-level profiling of functions decorated with @profile.
#import line_profiler
#import atexit
#profile = line_profiler.LineProfiler()
#atexit.register(profile.print_stats)

# Example of the YAML configuration understood by load_config() / RefineConfig.
# NOTE(review): load_config() merges the YAML into {"refine": RefineConfig()},
# so these keys presumably have to sit under a top-level `refine:` key -- confirm.
#
# atom_selection:
#   xyz:
#     include: []
#     exclude: []
#     exclude_restraint: []
#   adp:
#     include: []
#     exclude: []
#     exclude_restraint: []
#   occ:
#     include: []
#     exclude: []
#     exclude_restraint: []
#   dfrac:
#     include: []
#     exclude: []
#
# occ_groups:
#   - id: 1
#     selections:
#       - sel_1
#       - sel_2
# occ_group_constraints:
#   - ids: [1, 2]
#     complete: true
# local_geom_weights:
#   - sel: ..
#     w: 1
# local_adpr_weights:
#   - sel: ..
#     w: 1
# initialisation:
#   adp:
#     '*': 50
#   occ: {}
#   dfrac:
#     '[H]': 1.0
@dataclass
class SelectionConfig:
    """Atom selections (gemmi Selection strings) for one refinement parameter type."""
    # selections of atoms whose parameters are refined
    include: List[str] = field(metadata={"help": "List of gemmi Selection to include"},
                               default_factory=list)
    # selections of atoms to leave out
    exclude: List[str] = field(metadata={"help": "List of gemmi Selection to exclude"},
                               default_factory=list)
    # selections of atoms excluded from the geometry restraints of this type
    exclude_restraint: List[str] = field(metadata={"help": "List of gemmi Selection to exclude from restraints"},
                                         default_factory=list)
@dataclass
class OccGroupItem:
    """One occupancy group: an integer id and the gemmi selections of its atoms."""
    id: int  # the id that OccGroupConstItem.ids refers to
    selections: List[str] = field(default_factory=list)
@dataclass
class OccGroupConstItem:
    """A constraint tying together the occupancies of several occupancy groups."""
    ids: List[int] = field(default_factory=list)  # OccGroupItem ids involved
    # presumably: True = occupancies must sum to exactly 1, False = at most 1 (confirm in ext)
    complete: bool = True
@dataclass
class SelectionAndWeightItem:
    """A gemmi selection paired with a local restraint weight for its atoms."""
    sel: str = omegaconf.MISSING  # mandatory; no default in the config
    w: float = 1.0
def _default_atom_selection():
    """Default selections: refine coordinates and ADPs of all atoms, no occupancy/dfrac."""
    return {
        "xyz": SelectionConfig(include=["*"]),
        "adp": SelectionConfig(include=["*"]),
        "occ": SelectionConfig(),
        "dfrac": SelectionConfig(),
    }


def _default_initialisation():
    """Empty per-selection tables for ADP, occupancy and deuterium fraction."""
    return {"adp": {}, "occ": {}, "dfrac": {}}


@dataclass
class RefineConfig:
    """Schema of the ``refine`` section of the refinement configuration."""
    atom_selection: Dict[str, SelectionConfig] = field(
        default_factory=_default_atom_selection,
        metadata={"help": "Configuration for atom selection during refinement"},
    )
    occ_groups: List[OccGroupItem] = field(default_factory=list)
    occ_group_constraints: List[OccGroupConstItem] = field(default_factory=list)
    # augmented-Lagrangian penalty settings for the occupancy group constraints
    occ_group_const_mu: float = 10
    occ_group_const_mu_update_factor: float = 1.1
    occ_group_const_mu_update_tol_rel: float = 0.25
    occ_group_const_mu_update_tol_abs: float = 0.01
    # per-selection weights for geometry / ADP restraints
    local_geom_weights: List[SelectionAndWeightItem] = field(default_factory=list)
    local_adpr_weights: List[SelectionAndWeightItem] = field(default_factory=list)
    # presumably initial values keyed by selection, e.g. {"adp": {"*": 50}} -- confirm
    initialisation: Dict[str, Dict[str, float]] = field(
        default_factory=_default_initialisation,
        metadata={"help": ""},
    )
    write_trajectory: bool = False
def load_config(yaml_file, args, refmac_params):
    """Build the refinement configuration and return its ``refine`` section.

    Starts from the RefineConfig defaults, merges an optional YAML file, then
    applies overrides from command-line ``args`` and occupancy groups/constraints
    from parsed Refmac keywords. The final configuration is logged.

    Args:
        yaml_file: path to a YAML config file, or a falsy value to skip it.
        args: argparse-like namespace; missing attributes count as "not set".
        refmac_params: dict of parsed Refmac keywords (its "occu" entry is used);
            None is accepted and treated as empty.

    Returns:
        The ``refine`` sub-config (an omegaconf DictConfig).

    Raises:
        SystemExit: if the YAML file has a wrong value type or an unknown key.
    """
    cfg = omegaconf.OmegaConf.create({"refine": RefineConfig()})
    if yaml_file:
        conf = omegaconf.OmegaConf.load(yaml_file)
        try:
            cfg = omegaconf.OmegaConf.merge(cfg, conf)
        except (omegaconf.errors.ValidationError, omegaconf.errors.ConfigKeyError) as e:
            # ValidationError: wrong value type; ConfigKeyError: unknown key (e.g. a typo).
            # e.msg can be None, so fall back to the full exception text.
            raise SystemExit(f"Error while parsing {yaml_file}.\n{e.msg or e}")

    # load from args
    rcfg = cfg.refine
    if getattr(args, "write_trajectory", False):
        rcfg.write_trajectory = True
    if getattr(args, "fix_xyz", False):
        rcfg.atom_selection.xyz.include = []
        rcfg.atom_selection.xyz.exclude = []
    if getattr(args, "adp", None) == "fix":
        rcfg.atom_selection.adp.include = []
        rcfg.atom_selection.adp.exclude = []
    if getattr(args, "refine_all_occ", False):
        rcfg.atom_selection.occ.include = ["*"]
        rcfg.atom_selection.occ.exclude = []
    if getattr(args, "refine_dfrac", False):
        rcfg.atom_selection.dfrac.include = ["*"] # or H?
        rcfg.atom_selection.dfrac.exclude = []

    # load Refmac params (unfinished)
    occu = (refmac_params or {}).get("occu") or {}
    if occu.get("groups"):
        rgroups = occu["groups"]
        for igr in rgroups:
            group = OccGroupItem(igr)
            for sel in rgroups[igr]:
                if "resi_from" in sel and "resi_to" in sel:
                    resi = f"{sel['resi_from']}-{sel['resi_to']}"
                else:
                    resi = sel.get("resi", "")
                atom = sel.get("atom", "*")
                alt = sel.get("alt", "")
                if alt: alt = ":" + alt
                for chain in sel.get("chains", ["*"]):
                    group.selections.append(f"//{chain}/{resi}/{atom}{alt}")
            # The selections are not added to atom_selection.occ.include yet:
            # GroupOccupancy still handles Refmac occupancy groups.
            rcfg.occ_groups.append(group)
        for cmpl, ids in occu.get("const", []):
            rcfg.occ_group_constraints.append(OccGroupConstItem(ids, cmpl))

    logger.writeln("Config loaded")
    logger.writeln("--")
    logger.write(omegaconf.OmegaConf.to_yaml(cfg))
    logger.writeln("--")
    return cfg.refine
# load_config()
def iterate_selection(sel_str, st):
    """Lazily yield (model, chain, residue, atom) for every atom matched by a selection.

    Args:
        sel_str: gemmi selection string, e.g. "//A/10-20/CA".
        st: gemmi.Structure to search.
    """
    selection = gemmi.Selection(sel_str)
    for mdl in selection.models(st):
        for ch in selection.chains(mdl):
            for res in selection.residues(ch):
                yield from ((mdl, ch, res, at) for at in selection.atoms(res))
# iterate_selection()
def RefineParams(st, refine_xyz=False, adp_mode=0, refine_occ=False,
                 refine_dfrac=False, use_q_b_mixed=True, cfg=None,
                 exclude_h_ll=True): # FIXME refine_dfrac/exclude_h_ll and cfg
    """Create an ext.RefineParams for the first model of ``st``.

    With ``cfg`` (the ``refine`` config), the following are taken from it:
    occupancy groups and their constraints, per-type atom selections,
    restraint exclusions, and local geometry/ADP restraint weights.
    Without ``cfg``, refine_xyz / adp_mode / refine_occ / refine_dfrac apply to
    the whole model. In both cases, if ``exclude_h_ll`` is set,
    ``ret.exclude_h_ll`` is called. It is called per refined type (X, B, Q)
    when refine_dfrac is set, and without arguments otherwise.

    Args:
        st: gemmi.Structure; only model 0 is used.
        adp_mode: 0 = fixed, 1 = isotropic, 2 = anisotropic ADPs.
        use_q_b_mixed: passed through to ext.RefineParams.
        cfg: refine config, or None.

    Returns:
        The configured ext.RefineParams. A parameter-count summary is logged.

    Raises:
        RuntimeError: if deuterium fraction refinement is requested for a model
            without hydrogens, or if an occupancy group constraint refers to a
            group id that is not defined in ``cfg.occ_groups``.
    """
    assert adp_mode in (0, 1, 2) # 0=fix, 1=iso, 2=aniso
    if refine_dfrac and not st[0].has_hydrogen():
        raise RuntimeError("Hydrogen must be present when deuterium fraction refinement is requested")
    ret = ext.RefineParams(use_aniso=(adp_mode == 2), use_q_b_mixed=use_q_b_mixed)
    ret.set_model(st[0])
    if cfg:
        # occupancy groups: one list of atoms per group; group_ids maps config id -> list index
        occ_groups = []
        group_ids = {}
        for occ_gr in cfg.occ_groups:
            occ_groups.append([])
            group_ids[occ_gr.id] = len(occ_groups) - 1
            for s in occ_gr.selections:
                nsel = 0
                for _, _, _, atom in iterate_selection(s, st):
                    occ_groups[-1].append(atom)
                    nsel += 1
                if nsel == 0:
                    logger.writeln(f"Warning: no atom found for the selection {s}")
        ret.set_occ_groups(occ_groups)
        for o in cfg.occ_group_constraints:
            undefined = [x for x in o.ids if x not in group_ids]
            if undefined:
                raise RuntimeError(f"occ_group_constraints refers to undefined occupancy group id(s): {undefined}")
            ret.occ_group_constraints.append((o.complete, [group_ids[x] for x in o.ids]))

        # selections
        sele = cfg.atom_selection
        ext.set_refine_flags(st[0],
                             sele.xyz.include, sele.xyz.exclude,
                             sele.adp.include, sele.adp.exclude,
                             sele.occ.include, sele.occ.exclude,
                             sele.dfrac.include, sele.dfrac.exclude)
        ret.set_params_from_flags()

        # atoms excluded from geometry restraints, per parameter type
        for t, p in ((Type.X, sele.xyz), (Type.B, sele.adp), (Type.Q, sele.occ)):
            for ex_sel in p.exclude_restraint:
                for _, _, _, atom in iterate_selection(ex_sel, st):
                    ret.add_geom_exclusion(atom.serial-1, t)

        # local restraint weights, indexed by atom serial - 1
        for specs, weights in ((cfg.local_geom_weights, ret.geom_weights),
                               (cfg.local_adpr_weights, ret.adpr_weights)):
            for spec in specs:
                for _, _, _, atom in iterate_selection(spec.sel, st):
                    weights[atom.serial-1] = spec.w
    else:
        ret.set_params(refine_xyz=refine_xyz, refine_adp=adp_mode > 0,
                       refine_occ=refine_occ, refine_dfrac=refine_dfrac)
    if exclude_h_ll:
        if refine_dfrac:
            for t in (Type.X, Type.B, Type.Q):
                if ret.is_refined(t):
                    ret.exclude_h_ll(t)
        else:
            ret.exclude_h_ll()

    logger.writeln("Number of refinement parameters:")
    df = pandas.DataFrame(ret.params_summary())
    logger.writeln(df.to_string() + "\n")
    return ret
class Geom:
    """Geometry restraint target built on ext.Geometry.

    Loads restraints (topology and external), registers special-position
    corrections, and computes the coordinate, ADP and occupancy restraint
    terms plus the occupancy group constraint term.
    """
    def __init__(self, st, topo, monlib, refine_params, adpr_w=1, occr_w=1, shake_rms=0,
                 params=None, unrestrained=False, use_nucleus=False,
                 ncslist=None):
        """
        Args:
            st: gemmi.Structure; model 0 is used. Shaken in place if shake_rms > 0.
            topo: restraint topology; loaded unless ``unrestrained``.
            monlib: monomer library; its ``ener_lib`` is passed to ext.Geometry.
            refine_params: ext.RefineParams describing the refined parameters.
            adpr_w, occr_w: weights of the ADP and occupancy restraints.
            shake_rms: if > 0, perturb coordinates (the global numpy RNG is reseeded with 0).
            params: dict of Refmac-style keywords. Uses "wbond", "wangle",
                "wtors", "wplane", "wchir", "wvdw", "wncs" weights, "restr"
                torsion rules, "occu" groups and "exte" external restraints.
            unrestrained: if True, no topology/external restraints are loaded and
                calc() returns 0.
            use_nucleus: forwarded to the geometry calculations and reports.
            ncslist: local NCS information passed to ``setup_ncsr``, or None.

        Raises:
            SystemExit: if ext.Geometry cannot be constructed (TypeError),
                which usually indicates a gemmi/servalcat build mismatch.
        """
        self.st = st
        self.params = refine_params
        self.lookup = {x.atom: x for x in self.st[0].all()}  # atom -> CRA, for labels
        try:
            self.geom = ext.Geometry(self.st, self.params, monlib.ener_lib)
        except TypeError as e:
            raise SystemExit(f"An error occurred while creating the Geometry object:\n{e}\n\n"
                             "This likely indicates an installation issue. "
                             "Please verify that you have the correct version of gemmi installed and that both gemmi and servalcat were compiled in the same environment.")
        self.specs = utils.model.find_special_positions(self.st)
        for atom, images, matp, mata in self.specs:
            n_sym = len(images) + 1  # the atom itself plus its symmetry images
            self.geom.specials.append(ext.Geometry.Special(atom, matp, mata, n_sym))
        self.adpr_w = adpr_w
        self.occr_w = occr_w
        self.unrestrained = unrestrained
        if shake_rms > 0:
            numpy.random.seed(0)
            utils.model.shake_structure(self.st, shake_rms, copy=False)
        self.use_nucleus = use_nucleus
        self.calc_kwds = {"use_nucleus": self.use_nucleus}
        if params is None:
            params = {}
        for k in ("wbond", "wangle", "wtors", "wplane", "wchir", "wvdw", "wncs"):
            if k in params:
                self.calc_kwds[k] = params[k]
                logger.writeln("setting geometry weight {}= {}".format(k, params[k]))
        inc_tors, exc_tors = utils.restraints.make_torsion_rules(params.get("restr", {}))
        rtors = utils.restraints.select_restrained_torsions(monlib, inc_tors, exc_tors)
        self.geom.mon_tors_names = rtors["monomer"]
        self.geom.link_tors_names = rtors["link"]
        self.group_occ = GroupOccupancy(self.st, params.get("occu"))
        if not self.unrestrained:
            self.geom.load_topo(topo)
            exte.read_external_restraints(params.get("exte", []), self.st, self.geom)
        self.geom.finalize_restraints()
        # Z-score thresholds used when reporting outliers
        self.outlier_sigmas = dict(bond=5, angle=5, torsion=5, vdw=5, ncs=5, chir=5, plane=5, staca=5, stacd=5, per_atom=5, interval=5)
        self.parents = {}
        self.ncslist = ncslist
        self.const_ls, self.const_u = [], []
    # __init__()

    def set_h_parents(self):
        """Map each hydrogen to the atom it is bonded to, using the bond restraints."""
        self.parents = {}
        for bond in self.geom.bonds:
            if bond.atoms[0].is_hydrogen():
                self.parents[bond.atoms[0]] = bond.atoms[1]
            elif bond.atoms[1].is_hydrogen():
                self.parents[bond.atoms[1]] = bond.atoms[0]
    # set_h_parents()

    def setup_target(self):
        self.geom.setup_target(self.params.is_refined(Type.Q))

    def setup_nonbonded(self):
        """Set up non-bonded (and, if given, local NCS) restraints."""
        skip_critical_dist = not self.params.is_refined(Type.X) or self.unrestrained
        self.geom.setup_nonbonded(skip_critical_dist=skip_critical_dist)
        if self.ncslist:
            self.geom.setup_ncsr(self.ncslist)

    def setup_occ_constraint(self, lambda_ini=0., u_ini=100.):
        """Initialise one Lagrange multiplier and one penalty per occupancy group constraint."""
        self.const_ls = [lambda_ini for _ in self.params.occ_group_constraints]
        self.const_u = [u_ini for _ in self.params.occ_group_constraints]

    def calc(self, target_only):
        """Coordinate restraint term; 0 if coordinates are not refined or unrestrained."""
        if self.params.is_refined(Type.X) and not self.unrestrained:
            return self.geom.calc(check_only=target_only, **self.calc_kwds)
        return 0

    def calc_adp_restraint(self, target_only):
        """ADP restraint term; 0 if ADPs are not refined."""
        if self.params.is_refined(Type.B):
            return self.geom.calc_adp_restraint(target_only, self.adpr_w)
        return 0

    def calc_occ_restraint(self, target_only):
        """Occupancy restraint term; 0 if occupancies are not refined."""
        if self.params.is_refined(Type.Q):
            return self.geom.calc_occ_restraint(target_only, self.occr_w)
        return 0

    def update_occ_consts(self, consts_prev, alpha=1.1, eta=0.25, tol=0.01):
        """Augmented-Lagrangian update of the occupancy group constraints.

        Each multiplier becomes lambda_i - u_i * c_i. Each penalty u_i is kept
        if |c_i| < max(tol, eta * |c_prev_i|), otherwise multiplied by ``alpha``.

        Args:
            consts_prev: constraint values returned by the previous call.

        Returns:
            The current constraint values c_i (from params.occ_constraints()).
        """
        consts = self.params.occ_constraints()
        self.const_ls = [self.const_ls[i] - self.const_u[i] * consts[i]
                         for i in range(len(consts))]
        self.const_u = [u * (1 if abs(c) < max(tol, eta * abs(c_prev)) else alpha)
                        for u, c, c_prev in zip(self.const_u, consts, consts_prev)]
        return consts

    def calc_target(self, target_only):
        """Sum of all geometry terms; also applies spec_correction when derivatives are needed."""
        self.geom.clear_target()
        geom_x = self.calc(target_only)
        geom_a = self.calc_adp_restraint(target_only)
        geom_q = self.calc_occ_restraint(target_only)
        geom_c = self.geom.calc_occ_constraint(target_only, self.const_ls, self.const_u)
        logger.writeln(" geom_x = {}".format(geom_x))
        logger.writeln(" geom_a = {}".format(geom_a))
        logger.writeln(" geom_q = {}".format(geom_q))
        logger.writeln(" geom_c = {}".format(geom_c))
        geom = geom_x + geom_a + geom_q + geom_c
        if not target_only:
            self.geom.spec_correction()
        return geom

    def show_model_stats(self, show_outliers=True):
        """Log outlier tables (optional) and the restraint summary.

        Returns:
            dict with "outliers" (restraint kind -> DataFrame sorted by |z|) and
            "summary" (DataFrame indexed by restraint type).
        """
        self.calc(True)
        self.calc_adp_restraint(True)
        self.calc_occ_restraint(True)
        ret = {"outliers": {}}
        if show_outliers:
            get_table = dict(bond=self.geom.reporting.get_bond_outliers,
                             angle=self.geom.reporting.get_angle_outliers,
                             torsion=self.geom.reporting.get_torsion_outliers,
                             chir=self.geom.reporting.get_chiral_outliers,
                             plane=self.geom.reporting.get_plane_outliers,
                             staca=self.geom.reporting.get_stacking_angle_outliers,
                             stacd=self.geom.reporting.get_stacking_dist_outliers,
                             vdw=self.geom.reporting.get_vdw_outliers,
                             interval=self.geom.reporting.get_interval_outliers,
                             #ncs=self.geom.reporting.get_ncsr_outliers, # not useful?
                             )
            labs = dict(bond="Bond distances",
                        angle="Bond angles",
                        torsion="Torsion angles",
                        chir="Chiral centres",
                        plane="Planar groups",
                        staca="Stacking plane angles",
                        stacd="Stacking plane distances",
                        vdw="VDW repulsions",
                        interval="Interval",
                        ncs="Local NCS restraints")

            def atomlabel(r, i):
                # Label of the i-th atom of restraint r; symmetry-related partners get
                # " (operator index;pbc shift)" appended.
                symstr = lambda idx, s: f" ({idx+1};{s[0]},{s[1]},{s[2]})"
                label = str(self.lookup[r.atoms[i]])
                if type(r) in (ext.Geometry.Bond, ext.Geometry.Interval, ext.Geometry.Vdw) and i > 0 and not r.same_asu():
                    return label + symstr(r.sym_idx, r.pbc_shift)
                if type(r) is ext.Geometry.Angle and not r.same_asu(i): # always true if i==1
                    return label + symstr(r.sym_idx_1 if i == 0 else r.sym_idx_2,
                                          r.pbc_shift_1 if i == 0 else r.pbc_shift_2)
                return label

            for k in get_table:
                kwgs = {"min_z": self.outlier_sigmas[k]}
                if k == "bond": kwgs["use_nucleus"] = self.use_nucleus
                table = get_table[k](**kwgs)
                if not table["z"]:
                    continue
                if "restr" in table:
                    tmp = {}
                    for i in range(3 if k == "angle" else 2): # only bond/angle/interval/vdw return restr
                        tmp[f"atom{i+1}"] = [atomlabel(r, i) for r in table["restr"]]
                    del table["restr"]
                    table = {**tmp, **table}
                else:
                    for kk in table:
                        if kk.startswith(("atom", "plane", "1_atom", "2_atom")):
                            table[kk] = [str(self.lookup[x]) for x in table[kk]]
                df = pandas.DataFrame(table)
                df = df.reindex(df.z.abs().sort_values(ascending=False).index)
                ret["outliers"][k] = df
                if k == "bond":
                    # type < 2: regular bonds; type == 2: external bonds
                    df0 = df[df.type < 2].drop(columns=["type", "alpha"])
                    if len(df0.index) > 0:
                        logger.writeln(" *** {} outliers (Z >= {}) ***\n".format(labs[k], self.outlier_sigmas[k]))
                        logger.writeln(df0.to_string(float_format="{:.3f}".format, index=False) + "\n")
                    df0 = df[df.type == 2].drop(columns=["type"])
                    if len(df0.index) > 0:
                        logger.writeln(" *** External bond outliers (Z >= {}) ***\n".format(self.outlier_sigmas[k]))
                        logger.writeln(df0.to_string(float_format="{:.3f}".format, index=False) + "\n")
                else:
                    logger.writeln(" *** {} outliers (Z >= {}) ***\n".format(labs[k], self.outlier_sigmas[k]))
                    logger.writeln(df.to_string(float_format="{:.3f}".format, index=False) + "\n")

        df = pandas.DataFrame(self.geom.reporting.get_summary_table(self.use_nucleus))
        df = df.set_index("Restraint type").rename_axis(index=None)
        ret["summary"] = df
        logger.writeln(df.to_string(float_format="{:.3f}".format) + "\n")
        return ret
def show_binstats(df, cycle_number):
    """Log per-resolution-bin statistics of ``df`` as a loggraph table.

    Related columns are grouped into plots: mean I/F, FSC, R, CC, ML
    parameters D and S, and completeness. Groups whose columns are absent from
    ``df`` are skipped. ``df`` must have a "d_min" column, used for the 1/d^2 axis.
    """
    def present(*names):
        return [c for c in names if c in df]

    def prefixed(prefix):
        return [c for c in df if c.startswith(prefix)]

    candidates = [
        ("Mean I/F vs. Resolution", present("Mn(Io)", "Mn(Ic)", "Mn(Fo)", "Mn(Fc)")),
        ("FSC", present("fsc_model")),
        ("R", prefixed("R")),
        ("FSC", prefixed("fsc")),
        ("CC", prefixed("CC")),
        ("ML parameters - D", prefixed("D")),  # same as matching ^D[0-9]*
        ("ML parameters - Sigma", present("S")),
        ("Data completeness", present("Cmpl")),
    ]
    forplot = [[title, cols] for title, cols in candidates if cols]
    title = f"Data stats in cycle {cycle_number}"
    lstr = utils.make_loggraph_str(df, title, forplot,
                                   s2=1/df["d_min"]**2,
                                   float_format="{:.4f}".format)
    logger.writeln(lstr)
# show_binstats()
def convert_stats_to_dicts(stats):
    """Return a JSON-serialisable copy of per-cycle refinement statistics.

    ``stats`` is a list of dicts. Entries are copied unchanged except "geom",
    whose DataFrames are converted. "summary" uses ``DataFrame.to_dict()``.
    Each "outliers" table uses ``to_dict(orient="records")``. The "outliers"
    key is omitted when there are no outlier tables.
    """
    def convert_geom(geom):
        out = {"summary": geom["summary"].to_dict()}
        records = {name: table.to_dict(orient="records")
                   for name, table in geom["outliers"].items()}
        if records:
            out["outliers"] = records
        return out

    return [{key: (convert_geom(val) if key == "geom" else val) for key, val in entry.items()}
            for entry in stats]
# convert_stats_to_dicts()
def write_stats_json_safe(stats, json_out):
    """Write refinement statistics to ``json_out`` via a temporary file.

    The JSON is written to ``json_out + ".part"`` and then moved into place with
    os.replace(). The move is retried up to 10 times, 0.5 s apart, because on
    Windows it fails while another process has the file open.

    Raises:
        RuntimeError: if the file could not be replaced after all retries; the
            last PermissionError is chained as the cause.
    """
    tmp = convert_stats_to_dicts(stats)
    out_tmp = json_out + ".part"
    with open(out_tmp, "w") as ofs:
        json.dump(tmp, ofs, indent=2)
    last_error = None
    for _ in range(10):
        try:
            # On Windows, this fails when another process open the file
            os.replace(out_tmp, json_out)
            break
        except PermissionError as e:
            last_error = e  # keep it: `e` is unbound after the except clause
            logger.writeln(f"{json_out} locked. retrying..")
            time.sleep(0.5)
    else:
        raise RuntimeError(f"Cannot write {json_out}") from last_error
    logger.writeln(f"Refinement statistics saved: {json_out}")
# write_stats_json_safe()
def print_h_options(h_change, h_present, refine_h, hout, geom_only):
    """Log how hydrogen atoms are treated in this run.

    Args:
        h_change: gemmi.HydrogenChange applied to the input model. It is
            reported as Remove when no hydrogens are present.
        h_present: whether the model contains hydrogen atoms.
        refine_h: whether hydrogens are refined against experimental data.
        hout: whether hydrogens are written to the output model.
        geom_only: True when refining against geometry only (no map calculation).
    """
    descriptions = {
        gemmi.HydrogenChange.ReAddButWater: "have been (re)generated",
        gemmi.HydrogenChange.ReAdd: "(including water) have been (re)generated",
        gemmi.HydrogenChange.ReAddKnown: "(except for rotatable) have been (re) generated",
        gemmi.HydrogenChange.NoChange: "from the input model have been retained",
        gemmi.HydrogenChange.Remove: "have either been removed or were not present",
    }
    if not h_present:
        h_change = gemmi.HydrogenChange.Remove
    usage = "" if geom_only else "/map calculation"
    logger.writeln("Hydrogen related options")
    logger.write(" use in refinement{}: hydrogen atoms ".format(usage))
    logger.writeln(descriptions[h_change])
    if h_present:
        logger.write(" target: hydrogen atoms will be ")
        against_data = refine_h and not geom_only
        logger.writeln("refined against experimental data" if against_data
                       else "just optimized according to geometric restraints")
    written = hout and h_present
    logger.writeln(" in output model: " + ("written" if written else "not written"))
    logger.writeln("")
# print_h_options()
class GroupOccupancy:
    """Refinement of grouped occupancies from Refmac-style "occu" keywords.

    All atoms of a group share one occupancy parameter. Constraints tie groups
    together: a "complete" constraint requires their occupancies to sum to 1,
    an incomplete one only penalises sums above 1 (see constraint()).
    Optimised with an augmented Lagrangian (see refine()).
    """
    # TODO max may not be one. should check multiplicity
    def __init__(self, st, params):
        """
        Args:
            st: gemmi.Structure; atoms are selected from model 0.
            params: parsed "occu" keywords, or None. Fields used:
              * "groups": group id -> list of selection dicts with optional keys
                "chains", "resi", "resi_from", "resi_to", "atom", "alt";
              * "const" (optional): list of (is_complete, [group ids]);
              * "ncycle" (optional, default 5).
            Without groups, nothing is refined.
        """
        self.groups = []  # list of atom lists, one per group
        self.consts = []  # list of (is_complete, [group indices])
        self.group_idxes = [0 for _ in range(st[0].count_atom_sites())]  # atom index -> group number (1-based, 0 = none)
        self.ncycle = 0
        if not params or not params.get("groups"):
            return
        logger.writeln("Occupancy groups:")
        atom_idxes = []
        for igr in params["groups"]:
            self.groups.append([]) # list of atoms
            n_curr = len(atom_idxes)
            for sel in params["groups"][igr]:
                sel_chains = sel.get("chains")
                sel_from = sel.get("resi_from")
                sel_to = sel.get("resi_to")
                sel_seq = sel.get("resi")
                sel_atom = sel.get("atom")
                sel_alt = sel.get("alt")
                for chain in st[0]:
                    if sel_chains and chain.name not in sel_chains:
                        continue
                    flag = False  # True while inside the resi_from..resi_to range
                    for res in chain:
                        if sel_seq and res.seqid != sel_seq:
                            continue
                        if sel_from and res.seqid == sel_from:
                            flag = True
                        if sel_from and not flag:
                            continue
                        for atom in res:
                            if sel_atom and atom.name != sel_atom:
                                continue
                            if sel_alt and atom.altloc != sel_alt:
                                continue
                            atom_idxes.append(atom.serial-1)
                            self.groups[-1].append(atom)
                            self.group_idxes[atom.serial-1] = len(self.groups)
                        if sel_to and res.seqid == sel_to:
                            flag = False
            logger.writeln(" id= {} atoms= {}".format(igr, len(atom_idxes) - n_curr))

        igr_idxes = {igr:i for i, igr in enumerate(params["groups"])}
        # "const" is optional (load_config also treats it that way)
        self.consts = [(is_comp, [igr_idxes[g] for g in gids])
                       for is_comp, gids in params.get("const", [])]
        self.ncycle = params.get("ncycle", 5)
        self.params = ext.RefineParams()
        self.params.set_model(st[0])
        self.params.set_params_selected(atom_idxes, refine_occ=True)
        self.params.exclude_h_ll() # should be reasonable
    # __init__()

    def constraint(self, x):
        """Constraint values c_j(x) for group occupancies ``x`` (numpy array).

        c_j = sum - 1 for complete constraints. For incomplete ones it is
        sum - 1 only when the sum exceeds 1, otherwise 0.
        """
        ret = []
        for is_comp, ids in self.consts:
            x_sum = numpy.sum(x[ids])
            if is_comp or x_sum > 1:
                ret.append(x_sum - 1)
            else:
                ret.append(0.)
        return numpy.array(ret)

    def ensure_constraints(self):
        """Make atoms of each group share their mean occupancy (clamped to [1e-3, 1]),
        then rescale constrained groups so they satisfy their constraints."""
        vals = []
        for atoms in self.groups:
            occ = numpy.mean([a.occ for a in atoms])
            occ = min(1, max(1e-3, occ))
            vals.append(occ)
        for is_comp, idxes in self.consts:
            sum_occ = sum(vals[i] for i in idxes)
            if not is_comp and sum_occ < 1:
                sum_occ = 1. # do nothing
            for i in idxes:
                logger.writeln("Imposing constraints: {} {}".format(vals[i], vals[i]/sum_occ))
                vals[i] /= sum_occ
        for occ, atoms in zip(vals, self.groups):
            for a in atoms: a.occ = occ

    def get_x(self):
        """Current group occupancies (taken from the first atom of each group)."""
        return numpy.array([atoms[0].occ for atoms in self.groups])

    def set_x(self, x):
        """Assign occupancy x[i] to every atom of group i."""
        for p, atoms in zip(x, self.groups):
            for a in atoms:
                a.occ = p

    def target(self, x, ll, ls, u):
        """Augmented Lagrangian LL(x) - sum_j ls_j c_j(x) + u/2 sum_j c_j(x)^2."""
        self.set_x(x)
        ll.update_fc()
        c = self.constraint(x)
        f = ll.calc_target() - numpy.dot(ls, c) + 0.5 * u * numpy.sum(c**2)
        return f

    def grad(self, x, ll, ls, u):
        """Gradient and diagonal curvature of target() with respect to group occupancies."""
        c = self.constraint(x)
        ll.calc_grad(self.params, specs=None)
        assert len(ll.ll.vn) == len(ll.ll.am)
        ll_vn = numpy.array(ll.ll.vn)
        ll_am = numpy.array(ll.ll.am)
        vn = []
        diag = []
        atom_to_param = self.params.atom_to_param(Type.Q)
        for atoms in self.groups:
            # per-atom occupancy derivatives summed over the group (atoms excluded from LL skipped)
            idxes = [atom_to_param[a.serial-1] for a in atoms if not self.params.is_excluded_ll(a.serial-1, Type.Q)]
            vn.append(numpy.sum(ll_vn[idxes]))
            diag.append(numpy.sum(ll_am[idxes]))
        vn, diag = numpy.array(vn), numpy.array(diag)
        for i, (is_comp, idxes) in enumerate(self.consts):
            dcdx = numpy.zeros(len(self.groups))
            dcdx[idxes] = 1.
            if is_comp or c[i] != 0:
                vn -= (ls[i] - u * c[i]) * dcdx
                diag += u * dcdx**2
        return vn, diag

    def refine(self, ll, alpha=1.1):
        """Refine group occupancies with an augmented Lagrangian.

        f(x) = LL(x) - sum_j (lambda_j c_j(x)) + u/2 sum_j (c_j(x))^2, with
        c_j(x) = 0 constraints. Each cycle takes a diagonal-Newton step with
        up to 3 step halvings, then updates lambda_j -= u c_j and u *= alpha.

        Returns:
            A list of per-cycle statistics dicts, or None if there are no groups.
        """
        if not self.groups:
            return
        logger.writeln("\n== Group occupancy refinement ==")
        self.ensure_constraints() # make sure constrained groups have the same occupancies.
        ls = numpy.zeros(len(self.consts)) # Lagrange multipliers
        u = 10000. # penalty parameter. in Refmac 1/0.01**2
        x0 = self.get_x()
        f0 = self.target(x0, ll, ls, u)
        ret = []
        for cyc in range(self.ncycle):
            ret.append({"Ncyc": cyc+1, "f0": f0})
            logger.writeln("occ_{}_f0= {:.4e}".format(cyc, f0))
            vn, diag = self.grad(x0, ll, ls, u)
            diag[diag < 1e-6] = 1.  # guard against division by ~0
            dx = -vn / diag
            scale = 1
            for i in range(3):
                scale = 1/2**i
                f1 = self.target(x0 + dx * scale, ll, ls, u)
                logger.writeln("occ_{}_f1, {}= {:.4e}".format(cyc, i, f1))
                if f1 < f0: break
            else:
                # Refmac accepts the step even when the function increases
                logger.writeln("WARNING: function not minimised")
            c = self.constraint(x0 + dx * scale)
            ret[-1]["f1"] = f1
            ret[-1]["shift_scale"] = scale
            f0 = f1
            x0 = x0 + dx * scale
            ls -= u * c
            u = alpha * u
            ret[-1]["const_viol"] = list(c)
            ret[-1]["lambda_new"] = list(ls)
        self.ensure_constraints()
        ll.update_fc()
        f = ll.calc_target()
        logger.writeln("final -LL= {}".format(f))
        return ret
692
+ class Refine:
693
+ def __init__(self, st, geom, cfg, refine_params, ll=None, unrestrained=False):
694
+ assert geom is not None
695
+ self.st = st # clone()?
696
+ self.st_traj = None
697
+ self.params = refine_params
698
+ self.geom = geom
699
+ self.ll = ll
700
+ self.gamma = 0
701
+ self.unrestrained = unrestrained
702
+ #self.h_inherit_parent_adp = self.params.is_refined(Type.B) and not self.refine_h and self.st[0].has_hydrogen()
703
+ #if self.h_inherit_parent_adp:
704
+ # self.geom.set_h_parents()
705
+ self.cfg = cfg
706
+ if self.cfg.write_trajectory:
707
+ self.st_traj = self.st.clone()
708
+ self.st_traj[-1].num = 0
709
+ assert self.geom.group_occ.groups or self.params.n_params() > 0
710
+ # __init__()
711
+
712
+ def print_weights(self): # TODO unfinished
713
+ logger.writeln("Geometry weights")
714
+ g = self.geom.geom
715
+ if self.params.is_refined(Type.B):
716
+ logger.writeln(" ADP restraints")
717
+ logger.writeln(" weight: {}".format(self.geom.adpr_w))
718
+ logger.writeln(" mode: {}".format(g.adpr_mode))
719
+ if g.adpr_mode == "diff":
720
+ logger.writeln(" sigmas: {}".format(" ".join("{:.2f}".format(x) for x in g.adpr_diff_sigs)))
721
+ elif g.adpr_mode == "kldiv":
722
+ logger.writeln(" sigmas: {}".format(" ".join("{:.2f}".format(x) for x in g.adpr_kl_sigs)))
723
+ else:
724
+ raise LookupError("unknown adpr_mode")
725
+ if self.params.is_refined(Type.Q):
726
+ logger.writeln(" Occupancy restraints")
727
+ logger.writeln(" weight: {}".format(self.geom.occr_w))
728
+
729
def scale_shifts(self, dx, scale):
    """Return ``scale * dx`` with the shifts of each parameter type clamped.

    For every parameter type that has entries, the min/max/mean of its scaled
    shifts are logged. The shifts are then clamped to these ranges:

    * Type.X: [-1, 1]
    * Type.B: [-30, 30]. When ``self.params.aniso`` is set, the eigenvalues of
      each symmetric 3x3 shift tensor are clamped instead of its 6 components.
    * Type.Q: [-0.5, 0.5]
    * Type.D: [-0.5, 0.5]

    Clamped values are assigned back through ``self.params.vec_selection``
    explicitly. This is correct whether the selection is a slice (a view) or
    an index/boolean array (a copy).

    Args:
        dx: full shift vector (numpy array).
        scale: scalar factor applied before clamping.

    Returns:
        A new array. The caller's ``dx`` is not modified.
    """
    # Each entry: (parameter type, log label, lower limit, upper limit,
    #              clamp eigenvalues of ADP tensors when anisotropic?)
    limits = ((Type.X, "dx", -1.0, 1.0, False),
              (Type.B, "dB", -30.0, 30.0, True),
              (Type.Q, "dq", -0.5, 0.5, False),
              (Type.D, "dd", -0.5, 0.5, False))
    # The 6 stored tensor components are ordered (00, 11, 22, 01, 02, 12).
    # These indices expand them to a symmetric 3x3 matrix and read them back.
    sym_index = numpy.array([[0, 3, 4], [3, 1, 5], [4, 5, 2]])
    upper_rows, upper_cols = [0, 1, 2, 0, 0, 1], [0, 1, 2, 1, 2, 2]

    def clamp_tensor_eigenvalues(part, lo, hi):
        # Clamp the eigenvalues of every 6-component shift tensor in one
        # batched eigh call. A trailing remainder shorter than 6 is left as is,
        # as in the former per-atom loop.
        n = len(part) // 6
        if n == 0:
            return part
        part = numpy.array(part)  # work on our own copy
        tensors = part[:6 * n].reshape(n, 6)[:, sym_index]  # (n, 3, 3)
        v, Q = numpy.linalg.eigh(tensors)  # batched over the first axis
        v = numpy.clip(v, lo, hi)
        clamped = (Q * v[:, None, :]) @ Q.transpose(0, 2, 1)  # Q diag(v) Q^T
        part[:6 * n] = clamped[:, upper_rows, upper_cols].ravel()
        return part

    dx = scale * dx  # new array, so the caller's vector is left untouched
    for ptype, label, lo, hi, is_adp in limits:
        sel = self.params.vec_selection(ptype)
        part = dx[sel]
        if len(part) == 0:
            continue
        # TODO for anisotropic B these statistics mix all tensor components
        # and are therefore misleading.
        logger.writeln("min({}) = {}".format(label, numpy.min(part)))
        logger.writeln("max({}) = {}".format(label, numpy.max(part)))
        logger.writeln("mean({})= {}".format(label, numpy.mean(part)))
        if is_adp and self.params.aniso:
            # FIXME we shouldn't apply eigen decomp to dxb
            part = clamp_tensor_eigenvalues(part, lo, hi)
        else:
            part = numpy.clip(part, lo, hi)
        dx[sel] = part  # explicit write-back; works for views and copies
    return dx
783
+
784
def set_x(self, x):
    """Apply parameter vector ``x`` to the model and refresh dependent state.

    Steps:
    1. Call ``params.set_x(x, min_b=0.5)``.
    2. Clamp every atom's occupancy to [0, cap]. The cap is 1, or
       1/(len(images)+1) for atoms listed in ``self.geom.specs`` while
       occupancy is refined.
    3. Recompute Fc if there is a data target.
    4. Rebuild the nonbonded list and the geometry target.
    5. Log the vdw, atom and pair counts.
    """
    self.params.set_x(x, min_b=0.5)

    # Per-atom occupancy upper bounds; atoms that are not listed default to 1.
    # NOTE(review): specs entries look like (atom, images, _, _), presumably
    # atoms on special positions -- confirm.
    occ_cap = dict()
    if self.params.is_refined(Type.Q) and self.geom.specs:
        for atom, images, _, _ in self.geom.specs:
            occ_cap[atom] = 1. / (len(images) + 1)
    for atom in self.params.atoms:
        # Keep this min/max argument order: it determines how NaN is handled.
        atom.occ = min(occ_cap.get(atom, 1), max(0, atom.occ))

    # NOTE: copying hydrogen ADPs from their parent atoms is disabled here.

    if self.ll is not None:
        self.ll.update_fc()

    self.geom.setup_nonbonded()  # if refine_xyz=False, no need to do it every time
    self.geom.setup_target()
    restraints = self.geom.geom
    logger.writeln("vdws = {}".format(len(restraints.vdws)))
    logger.writeln(f"atoms = {len(self.params.atoms)}")
    logger.writeln(f"pairs = {restraints.target.n_pairs()}")
806
+
807
def get_x(self):
    """Return the current refinement parameter vector as a new numpy array."""
    current = self.params.get_x()
    return numpy.array(current)  # numpy.array copies, so callers own the result
809
+
810
+ #@profile
811
def calc_target(self, w=1, target_only=False):
    """Return the total target ``w * data_term + geometry_term``.

    Args:
        w: weight of the data (likelihood) term.
        target_only: passed to ``self.geom.calc_target``. When False,
            ``self.ll.calc_grad`` is also called, presumably to prepare
            derivatives for the solver.

    Returns:
        The combined target. The data term is 0 when ``self.ll`` is None.
    """
    geom_term = self.geom.calc_target(target_only)
    data_term = 0
    if self.ll is not None:
        data_term = self.ll.calc_target()
        logger.writeln(" ll= {}".format(data_term))
        if not target_only:
            self.ll.calc_grad(self.params, self.geom.geom.specials)
    return w * data_term + geom_term
823
+
824
+ #@profile
825
def run_cycle(self, weight=1):
    """Run one refinement cycle and return ``(success, shift_scale, f1)``.

    The shift ``dx`` is solved with ``ext.CgSolve`` (conjugate gradients in
    C++). It uses the geometry target and, if present, the data target.
    Trial steps ``x0 - scale_shifts(dx, s)`` are tried for s = 1, 1/2, 1/4
    until the target falls below its starting value ``f0``.

    Args:
        weight: weight of the data term relative to the geometry term.

    Returns:
        success: True if a trial step decreased the target.
        shift_scale: scale of the last trial step, which is the one kept.
        f1: target value after that step.

    Note:
        If no trial decreases the target, the last trial (s = 1/4) is kept
        rather than restoring ``x0``; Refmac also accepts such a step.
    """
    f0 = self.calc_target(weight)  # target_only=False: derivatives are requested too
    x0 = self.get_x()
    logger.writeln("f0= {:.4e}".format(f0))

    use_ic = False  # incomplete cholesky. problematic at least in geometry optimisation case
    logger.writeln("using cgsolve in c++, ic={}".format(use_ic))
    cgsolver = ext.CgSolve(self.geom.geom.target, None if self.ll is None else self.ll.ll)
    if use_ic:
        cgsolver.gamma = 0
        cgsolver.max_gamma_cyc = 1
    else:
        cgsolver.gamma = self.gamma
    dx = cgsolver.solve(weight, logger, use_ic)
    self.gamma = cgsolver.gamma  # carried over to the next cycle

    success = False
    for i in range(3):
        shift_scale = 1 / 2**i
        self.set_x(x0 - self.scale_shifts(dx, shift_scale))
        f1 = self.calc_target(weight, target_only=True)
        logger.writeln("f1, {}= {:.4e}".format(i, f1))
        if f1 < f0:
            success = True
            break
    else:
        logger.writeln("WARNING: function not minimised")
        # The model is intentionally not reset to x0: Refmac accepts the
        # step even when the function increases.
    return success, shift_scale, f1
902
+
903
def run_cycles(self, ncycles, weight=1, weight_adjust=False, debug=False,
               weight_adjust_bond_rmsz_range=(0.5, 1.), stats_json_out=None):
    """Run ``ncycles`` refinement cycles and return per-cycle statistics.

    Entry 0 of the returned list describes the starting model. Entry k
    (k >= 1) describes the model after cycle k. Every cycle appends exactly
    one entry, even when nothing is refined by ``run_cycle``. This keeps a
    cycle's results from overwriting the previous entry and guarantees that
    ``stats[-2]`` is the previous cycle.

    Args:
        ncycles: number of cycles.
        weight: initial weight of the data term relative to geometry.
        weight_adjust: adjust ``weight`` between cycles from the r.m.s.Z of
            non-H bond distances. Only applies when coordinates are refined,
            restraints are used and there is data.
        debug: write the model after every cycle as ``refined_NN`` (PDB).
        weight_adjust_bond_rmsz_range: (low, high) bond r.m.s.Z range used by
            the weight adjustment.
        stats_json_out: if given, statistics are rewritten to this path after
            the initial evaluation and after each cycle.

    Returns:
        The list of per-cycle statistics dicts.
    """
    self.print_weights()
    stats = [{"Ncyc": 0}]
    self.params.ensure_occ_constraints()
    self.geom.setup_nonbonded()
    self.geom.setup_target()
    self.geom.setup_occ_constraint(u_ini=self.cfg.occ_group_const_mu)
    logger.writeln("vdws = {}".format(len(self.geom.geom.vdws)))
    logger.writeln(f"atoms = {len(self.params.atoms)}")
    logger.writeln(f"pairs = {self.geom.geom.target.n_pairs()}")
    stats[-1]["geom"] = self.geom.show_model_stats(show_outliers=True)
    if self.params.occ_group_constraints:
        stats[-1]["occ_const"] = {"lambda": self.geom.const_ls,
                                  "mu": self.geom.const_u,
                                  "violation": self.params.occ_constraints(),
                                  "occ": self.params.constrained_occ_values()
                                  }
    if self.ll is not None:
        self.ll.update_fc()
        self.ll.overall_scale()
        self.ll.update_ml_params()
        self.ll.prepare_target()
        self._store_data_stats(stats[-1], self.ll.calc_stats(bin_stats=True), 0)
    if stats_json_out:
        write_stats_json_safe(stats, stats_json_out)
    occ_refine_flag = (self.ll is not None and self.geom.group_occ.groups
                       and self.geom.group_occ.ncycle > 0)

    for i in range(ncycles):
        logger.writeln("\n====== CYCLE {:2d} ======\n".format(i+1))
        logger.writeln(f" weight = {weight:.4e}")
        last_cycle = (i == ncycles - 1)
        if self.params.is_refined_any():
            is_ok, shift_scale, fval = self.run_cycle(weight=weight)
            stats.append({"Ncyc": len(stats), "shift_scale": shift_scale, "fval": fval,
                          "fval_decreased": is_ok, "weight": weight})
        else:
            # Always start a fresh entry, so this cycle's results do not
            # overwrite the previous cycle's entry.
            stats.append({"Ncyc": len(stats)})
        if occ_refine_flag:
            stats[-1]["occ_refine"] = self.geom.group_occ.refine(self.ll)
        if debug:
            utils.fileio.write_model(self.st, "refined_{:02d}".format(i+1), pdb=True)
        stats[-1]["geom"] = self.geom.show_model_stats(show_outliers=last_cycle)
        if self.params.occ_group_constraints:
            viols = self.geom.update_occ_consts(consts_prev=stats[-2]["occ_const"]["violation"],
                                                alpha=self.cfg.occ_group_const_mu_update_factor,
                                                eta=self.cfg.occ_group_const_mu_update_tol_rel,
                                                tol=self.cfg.occ_group_const_mu_update_tol_abs)
            stats[-1]["occ_const"] = {"lambda": self.geom.const_ls,
                                      "mu": self.geom.const_u,
                                      "violation": viols,
                                      "occ": self.params.constrained_occ_values()
                                      }
        if self.ll is not None:
            if last_cycle:
                self.params.ensure_occ_constraints()
            self.ll.overall_scale()
            f0 = self.ll.calc_target()
            self.ll.update_ml_params()
            self.ll.prepare_target()
            llstats = self.ll.calc_stats(bin_stats=True)
            if llstats["summary"]["-LL"] > f0:
                logger.writeln("WARNING: -LL has increased after ML parameter optimization:"
                               "{} to {}".format(f0, llstats["summary"]["-LL"]))
            self._store_data_stats(stats[-1], llstats, i + 1)
        if (weight_adjust and self.params.is_refined(Type.X) and not self.unrestrained and
                self.ll is not None and len(stats) > 2 and
                "Bond distances, non H" in stats[-1]["geom"]["summary"].index):
            rmsz = stats[-1]["geom"]["summary"]["r.m.s.Z"]["Bond distances, non H"]
            rmsz0 = stats[-2]["geom"]["summary"]["r.m.s.Z"]["Bond distances, non H"]
            if rmsz > weight_adjust_bond_rmsz_range[1] and rmsz > rmsz0:
                weight /= 1.1
            elif (rmsz < weight_adjust_bond_rmsz_range[0] and
                  rmsz0 < weight_adjust_bond_rmsz_range[0] and rmsz < rmsz0):
                weight *= 1.3
            elif rmsz > 1.5 * rmsz0:
                weight /= 1.1
        if self.st_traj is not None:
            self.st_traj.add_model(self.st[0])
            self.st_traj[-1].num = len(self.st_traj) - 1
        if stats_json_out:
            write_stats_json_safe(stats, stats_json_out)

    logger.writeln("")
    if self.params.occ_group_constraints:
        self._log_occ_const_table(stats)
    self._log_cycle_stats_table(stats)
    return stats

def _store_data_stats(self, entry, llstats, ncyc):
    """Record ``ll.calc_stats`` results in ``entry`` and print the bin table."""
    entry["data"] = {"summary": llstats["summary"],
                     "binned": llstats["bin_stats"].to_dict(orient="records"),
                     "ml": llstats["ml"].to_dict(orient="records")}
    if "twin_alpha" in llstats:
        entry["twin_alpha"] = llstats["twin_alpha"]
    show_binstats(llstats["bin_stats"], ncyc)
    if self.params.is_refined(Type.B):
        utils.model.adp_analysis(self.st)

def _log_occ_const_table(self, stats):
    """Log a loggraph table of occupancy-constraint quantities vs cycle."""
    rows = []
    for icyc, entry in enumerate(stats):
        con = entry["occ_const"]
        row = {"Ncyc": icyc}
        row.update({f"lambda_{k+1}": v for k, v in enumerate(con["lambda"])})
        row.update({f"mu_{k+1}": v for k, v in enumerate(con["mu"])})
        row.update({f"violation_{k+1}": v for k, v in enumerate(con["violation"])})
        row.update({f"occ_{k+1}_{j+1}": q
                    for k, group in enumerate(con["occ"])
                    for j, q in enumerate(group)})
        rows.append(row)
    df = pandas.DataFrame(rows)
    forplot = [[title, ["Ncyc"] + [c for c in df if c.startswith(prefix)]]
               for title, prefix in (("Lagrange multiplier", "lambda_"),
                                     ("Penalty parameter mu", "mu_"),
                                     ("Group constrained occupancies", "occ_"),
                                     ("Constraint violations", "violation_"))]
    lstr = utils.make_loggraph_str(df, "group occupancies vs cycle", forplot,
                                   float_format="{:.4f}".format)
    logger.writeln(lstr)

def _log_cycle_stats_table(self, stats):
    """Log a loggraph table of data and geometry statistics vs cycle.

    Column groups keep the order in which keys first appear. A set would make
    the order depend on string hashing and vary between runs.
    """
    data_keys, geom_keys = {}, {}  # dicts used as insertion-ordered sets
    rows = []
    for entry in stats:
        row = {"Ncyc": entry["Ncyc"]}
        if "data" in entry and "summary" in entry["data"]:
            row.update(entry["data"]["summary"])
            data_keys.update(dict.fromkeys(entry["data"]["summary"]))
        if "geom" in entry:
            summary = entry["geom"]["summary"]
            for col, restr, name in (("r.m.s.d.", "Bond distances, non H", "rmsBOND"),
                                     ("r.m.s.Z", "Bond distances, non H", "zBOND"),
                                     ("r.m.s.d.", "Bond angles, non H", "rmsANGL"),
                                     ("r.m.s.Z", "Bond angles, non H", "zANGL")):
                if col in summary and restr in summary[col]:
                    row[name] = summary[col].get(restr)
                    geom_keys[name] = None
        rows.append(row)
    df = pandas.DataFrame(rows)
    forplot = []
    if "FSCaverage" in data_keys:
        forplot.append(["FSC", ["Ncyc", "FSCaverage"]])
    r_keys = [x for x in data_keys if x.startswith("R")]
    if r_keys:
        forplot.append(["R", ["Ncyc"] + r_keys])
    cc_keys = [x for x in data_keys if x.startswith("CC")]
    if cc_keys:
        forplot.append(["CC", ["Ncyc"] + cc_keys])
    if "-LL" in data_keys:
        forplot.append(["-LL", ["Ncyc", "-LL"]])
    rms_keys = [x for x in geom_keys if x.startswith("rms")]
    if rms_keys:
        forplot.append(["Geometry", ["Ncyc"] + rms_keys])
    z_keys = [x for x in geom_keys if x.startswith("z")]
    if z_keys:
        forplot.append(["Geometry Z", ["Ncyc"] + z_keys])

    lstr = utils.make_loggraph_str(df, "stats vs cycle", forplot,
                                   float_format="{:.4f}".format)
    logger.writeln(lstr)
1066
+
1067
+ # class Refine
1068
+
1069
def update_meta(st, stats, ll=None):
    """Update the metadata of ``st`` after refinement.

    Clears helices and sheets, records Servalcat as the refinement software,
    fills a gemmi.RefinementInfo from ``stats`` and replaces
    ``st.raw_remarks`` with a REMARK 3 block.

    Args:
        st: gemmi.Structure, modified in place.
        stats: statistics of one refinement cycle (presumably an entry
            returned by ``Refine.run_cycles``). ``stats["geom"]`` is used if
            present; ``stats["data"]`` is used when ``ll`` is given.
        ll: data target object. If None, no data statistics are written.
    """
    # TODO write stats. probably geom.reporting.get_summary_table should return with _refine_ls_restr.type names
    # should remove st.mod_residues?
    st.helices.clear()
    st.sheets.clear()
    raw_remarks = ['REMARK 3',
                   'REMARK 3 REFINEMENT.',
                   f'REMARK 3 PROGRAM : SERVALCAT {servalcat.__version__}',
                   'REMARK 3 AUTHORS : YAMASHITA,MURSHUDOV',
                   'REMARK 3',
                   ]
    si = gemmi.SoftwareItem()
    si.classification = gemmi.SoftwareItem.Classification.Refinement
    si.name = "Servalcat"
    si.version = servalcat.__version__
    si.date = servalcat.__date__
    st.meta.software = [si]

    ri = gemmi.RefinementInfo()
    if "geom" in stats:
        restr_stats, restr_remarks = _restraint_stats(stats["geom"]["summary"])
        ri.restr_stats = restr_stats
        raw_remarks.extend(restr_remarks)
        raw_remarks.append("REMARK 3")
    if ll is not None:
        _set_data_stats(ri, st, stats["data"], ll)
    st.meta.refinement = [ri]
    st.raw_remarks = raw_remarks

# Each row: (row name in the geometry summary, restraint type name passed to
# gemmi.RefinementInfo.Restr, REMARK 3 label; "" means no REMARK line).
_RESTRAINT_STAT_ROWS = (
    ("Bond distances, non H", "s_bond_nonh_d", "BOND LENGTHS REFINED ATOMS (A)"),
    ("Bond angles, non H", "s_angle_nonh_deg", "BOND ANGLES REFINED ATOMS (DEGREES)"),
    ("Torsion angles, period 1", "s_dihedral_angle_1_deg", "TORSION ANGLES, PERIOD 1 (DEGREES)"),
    ("Torsion angles, period 2", "s_dihedral_angle_2_deg", "TORSION ANGLES, PERIOD 2 (DEGREES)"),
    ("Torsion angles, period 3", "s_dihedral_angle_3_deg", "TORSION ANGLES, PERIOD 3 (DEGREES)"),
    ("Torsion angles, period 6", "s_dihedral_angle_6_deg", "TORSION ANGLES, PERIOD 6 (DEGREES)"),
    ("Chiral centres", "s_chiral_restr", "CHIRAL-CENTER RESTRAINTS (A**3)"),
    ("Planar groups", "s_planes", "GENERAL PLANES REFINED ATOMS (A)"),
    ("VDW nonbonded", "s_nbd", ""),
    ("VDW torsion", "s_nbtor", ""),
    ("VDW hbond", "s_hbond_nbd", ""),
    ("VDW metal", "s_metal_ion", ""),
    ("VDW dummy", "s_dummy_nbd", ""),
    ("VDW nonbonded, symmetry", "s_symmetry_nbd", ""),
    ("VDW torsion, symmetry", "s_symmetry_nbtor", ""),
    ("VDW hbond, symmetry", "s_symmetry_hbond_nbd", ""),
    ("VDW metal, symmetry", "s_symmetry_metal_ion", ""),
    ("VDW dummy, symmetry", "s_symmetry_dummy_nbd", ""),
)

def _restraint_stats(summary):
    """Build gemmi restraint statistics and REMARK 3 lines from a geometry summary.

    ``summary`` is accessed as ``summary[column].get(row)`` with the columns
    "r.m.s.d.", "N restraints" and "Mn(sigma)". Rows missing from the
    "r.m.s.d." column are skipped.

    Returns:
        (list of gemmi.RefinementInfo.Restr, list of REMARK 3 lines)
    """
    restr_stats = []
    remarks = ["REMARK 3 RMS DEVIATIONS FROM IDEAL VALUES COUNT RMS WEIGHT"]
    if "r.m.s.d." not in summary:
        return restr_stats, remarks
    rmsd = summary["r.m.s.d."]
    for row, restr_type, label in _RESTRAINT_STAT_ROWS:
        if row not in rmsd:
            continue
        rr = gemmi.RefinementInfo.Restr(restr_type)
        rr.dev_ideal = round(rmsd.get(row), 4)
        rr.count = summary["N restraints"].get(row)
        rr.weight = round(summary["Mn(sigma)"].get(row), 4)
        restr_stats.append(rr)
        if label:
            remarks.append(f"REMARK 3 {label}:{rr.count:6d} ;{rr.dev_ideal:6.3f} ;{rr.weight:6.3f}")
    return restr_stats, remarks

def _total_count(counts):
    """Sum per-bin counts, skipping unset (negative) values. Return -1 if none is set."""
    valid = [c for c in counts if c >= 0]
    return sum(valid) if valid else -1

def _set_data_stats(ri, st, data, ll):
    """Fill ``ri`` with overall and per-resolution-bin data statistics."""
    ri.id = ll.refine_id()
    ri.mean_b = round(numpy.mean([cra.atom.b_iso for cra in st[0].all()]), 2)
    if ll.b_aniso is not None:
        ri.aniso_b = ll.b_aniso
    summary = data["summary"]
    for key, attr, ndigits in (("Rwork", "r_work", 4), ("Rfree", "r_free", 4), ("R", "r_all", 4),
                               ("FSCaverage", "fsc_work", 4),
                               ("FSCaverage_half1", "fsc_work", 4), ("FSCaverage_half2", "fsc_free", 4)):
        if key in summary:
            setattr(ri, attr, round(summary[key], ndigits))
    bins = []
    n_all = 0
    for b in data["binned"]:
        bri = gemmi.BasicRefinementInfo()
        bri.resolution_high = round(b["d_min"], 3)
        bri.resolution_low = round(b["d_max"], 3)
        # Later keys that map to the same attribute take precedence.
        for key, attr, ndigits in (("Rwork", "r_work", 4), ("Rfree", "r_free", 4),
                                   ("R1work", "r_work", 4), ("R1free", "r_free", 4),
                                   ("R", "r_all", 4), ("R1", "r_all", 4),
                                   ("CCI", "cc_intensity_work", 4), ("CCF", "cc_fo_fc_work", 4),
                                   ("CCIwork", "cc_intensity_work", 4), ("CCIfree", "cc_intensity_free", 4),
                                   ("CCFwork", "cc_fo_fc_work", 4), ("CCFfree", "cc_fo_fc_free", 4),
                                   ("fsc_FC_full", "fsc_work", 4), ("fsc_model", "fsc_work", 4),
                                   ("fsc_model_half1", "fsc_work", 4), ("fsc_model_half2", "fsc_free", 4),
                                   ("n_work", "work_set_count", 0), ("n_free", "rfree_set_count", 0),
                                   ("n_obs", "reflection_count", 0), ("ncoeffs", "reflection_count", 0)):
            if key in b:
                setattr(bri, attr, round(b[key], ndigits))
        if "n_all" in b and "n_obs" in b:
            if b["n_all"] > 0:  # avoid dividing by zero for an empty bin
                bri.completeness = round(b["n_obs"] / b["n_all"] * 100, 2)
            n_all += b["n_all"]
        bins.append(bri)
    # Unset bin counts are -1, so they must not be added into the totals.
    ri.rfree_set_count = _total_count(x.rfree_set_count for x in bins)
    ri.work_set_count = _total_count(x.work_set_count for x in bins)
    ri.reflection_count = _total_count(x.reflection_count for x in bins)
    if bins:
        ri.resolution_high = round(min(x.resolution_high for x in bins), 3)
        ri.resolution_low = round(max(x.resolution_low for x in bins), 3)
    if ri.reflection_count > 0 and n_all > 0:
        ri.completeness = round(ri.reflection_count / n_all * 100, 2)
    ri.bins = bins
    if ri.rfree_set_count > 0:
        ri.cross_validation_method = "THROUGHOUT"
1162
+ # update_meta()