servalcat 0.4.88__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of servalcat might be problematic. Click here for more details.

Files changed (45) hide show
  1. servalcat/__init__.py +10 -0
  2. servalcat/__main__.py +120 -0
  3. servalcat/ext.cp313-win_amd64.pyd +0 -0
  4. servalcat/refine/__init__.py +0 -0
  5. servalcat/refine/cgsolve.py +100 -0
  6. servalcat/refine/refine.py +823 -0
  7. servalcat/refine/refine_geom.py +220 -0
  8. servalcat/refine/refine_spa.py +345 -0
  9. servalcat/refine/refine_xtal.py +268 -0
  10. servalcat/refine/spa.py +136 -0
  11. servalcat/refine/xtal.py +273 -0
  12. servalcat/refmac/__init__.py +0 -0
  13. servalcat/refmac/exte.py +182 -0
  14. servalcat/refmac/refmac_keywords.py +639 -0
  15. servalcat/refmac/refmac_wrapper.py +403 -0
  16. servalcat/spa/__init__.py +0 -0
  17. servalcat/spa/fofc.py +473 -0
  18. servalcat/spa/fsc.py +387 -0
  19. servalcat/spa/localcc.py +188 -0
  20. servalcat/spa/realspcc_from_var.py +128 -0
  21. servalcat/spa/run_refmac.py +972 -0
  22. servalcat/spa/shift_maps.py +293 -0
  23. servalcat/spa/shiftback.py +137 -0
  24. servalcat/spa/translate.py +129 -0
  25. servalcat/utils/__init__.py +35 -0
  26. servalcat/utils/commands.py +1397 -0
  27. servalcat/utils/fileio.py +737 -0
  28. servalcat/utils/generate_operators.py +296 -0
  29. servalcat/utils/hkl.py +712 -0
  30. servalcat/utils/logger.py +116 -0
  31. servalcat/utils/maps.py +345 -0
  32. servalcat/utils/model.py +782 -0
  33. servalcat/utils/refmac.py +760 -0
  34. servalcat/utils/restraints.py +782 -0
  35. servalcat/utils/symmetry.py +295 -0
  36. servalcat/xtal/__init__.py +0 -0
  37. servalcat/xtal/french_wilson.py +256 -0
  38. servalcat/xtal/run_refmac_small.py +240 -0
  39. servalcat/xtal/sigmaa.py +1622 -0
  40. servalcat/xtal/twin.py +115 -0
  41. servalcat-0.4.88.dist-info/METADATA +55 -0
  42. servalcat-0.4.88.dist-info/RECORD +45 -0
  43. servalcat-0.4.88.dist-info/WHEEL +5 -0
  44. servalcat-0.4.88.dist-info/entry_points.txt +4 -0
  45. servalcat-0.4.88.dist-info/licenses/LICENSE +373 -0
@@ -0,0 +1,823 @@
1
+ """
2
+ Author: "Keitaro Yamashita, Garib N. Murshudov"
3
+ MRC Laboratory of Molecular Biology
4
+
5
+ This software is released under the
6
+ Mozilla Public License, version 2.0; see LICENSE.
7
+ """
8
+ from __future__ import absolute_import, division, print_function, generators
9
+ import os
10
+ import re
11
+ import gemmi
12
+ import numpy
13
+ import json
14
+ import pandas
15
+ import scipy.sparse
16
+ import servalcat # for version
17
+ from servalcat.utils import logger
18
+ from servalcat import utils
19
+ from servalcat.refmac import exte
20
+ from servalcat import ext
21
+ from . import cgsolve
22
+ u_to_b = utils.model.u_to_b
23
+ b_to_u = utils.model.b_to_u
24
+
25
+ #import line_profiler
26
+ #import atexit
27
+ #profile = line_profiler.LineProfiler()
28
+ #atexit.register(profile.print_stats)
29
+
30
class Geom:
    """Geometry restraints of a structure, backed by ``ext.Geometry``.

    Keeps the structure, its atoms ordered by serial number, the mapping from
    atoms to refinement-parameter positions (``atom_pos``), the keyword
    arguments forwarded to the extension's ``calc`` and the grouped-occupancy
    definition (``group_occ``).
    """

    def __init__(self, st, topo, monlib, adpr_w=1, shake_rms=0,
                 params=None, unrestrained=False, use_nucleus=False,
                 ncslist=None, atom_pos=None):
        """Set up restraints for the first model of *st*.

        Parameters
        ----------
        st : structure whose first model (``st[0]``) is used.
        topo : topology handed to ``ext.Geometry.load_topo`` (skipped when
            *unrestrained* is true).
        monlib : monomer library; ``monlib.ener_lib`` is passed to
            ``ext.Geometry`` and the library is used to select torsions.
        adpr_w : weight forwarded to ``calc_adp_restraint``.
        shake_rms : if > 0, the NumPy global RNG is seeded with 0 and the
            structure is perturbed in place by ``utils.model.shake_structure``.
        params : optional dict; the keys read here are the weights
            ``wbond``, ``wangle``, ``wtors``, ``wplane``, ``wchir``, ``wvdw``,
            ``wncs`` and the entries ``restr``, ``occu`` and ``exte``.
        unrestrained : when true, topology and external restraints are not loaded.
        use_nucleus : forwarded to the extension calls that accept it.
        ncslist : if truthy, passed to ``setup_ncsr`` in ``setup_nonbonded``.
        atom_pos : optional per-atom parameter positions; defaults to
            ``0 .. n_atoms-1``.

        Raises
        ------
        SystemExit
            If ``ext.Geometry`` raises ``TypeError`` on construction.
        """
        self.st = st
        # Atoms are stored at index (serial - 1).
        self.atoms = [None] * self.st[0].count_atom_sites()
        for cra in self.st[0].all():
            self.atoms[cra.atom.serial-1] = cra.atom
        self.atom_pos = atom_pos if atom_pos is not None else list(range(len(self.atoms)))
        self.n_refine_atoms = max(self.atom_pos) + 1
        self.lookup = {cra.atom: cra for cra in self.st[0].all()}
        try:
            self.geom = ext.Geometry(self.st, self.atom_pos, monlib.ener_lib)
        except TypeError as e:
            raise SystemExit(f"An error occurred while creating the Geometry object:\n{e}\n\n"
                             "This likely indicates an installation issue. "
                             "Please verify that you have the correct version of gemmi installed and that both gemmi and servalcat were compiled in the same environment.")
        self.specs = utils.model.find_special_positions(self.st)
        #cs_count = len(self.st.find_spacegroup().operations())
        for atom, images, matp, mata in self.specs:
            #n_sym = len([x for x in images if x < cs_count]) + 1
            special = ext.Geometry.Special(atom, matp, mata, len(images) + 1)
            self.geom.specials.append(special)
        self.adpr_w = adpr_w
        self.occr_w = 1.
        self.unrestrained = unrestrained
        if shake_rms > 0:
            numpy.random.seed(0)
            utils.model.shake_structure(self.st, shake_rms, copy=False)
            #utils.fileio.write_model(self.st, "shaken", pdb=True, cif=True)
        self.use_nucleus = use_nucleus
        self.calc_kwds = {"use_nucleus": self.use_nucleus}
        if params is None:
            params = {}
        # Only explicitly given geometry weights are forwarded to calc().
        for key in ("wbond", "wangle", "wtors", "wplane", "wchir", "wvdw", "wncs"):
            if key not in params:
                continue
            self.calc_kwds[key] = params[key]
            logger.writeln("setting geometry weight {}= {}".format(key, params[key]))
        inc_tors, exc_tors = utils.restraints.make_torsion_rules(params.get("restr", {}))
        rtors = utils.restraints.select_restrained_torsions(monlib, inc_tors, exc_tors)
        self.geom.mon_tors_names = rtors["monomer"]
        self.geom.link_tors_names = rtors["link"]
        self.group_occ = GroupOccupancy(self.st, params.get("occu"))
        if not self.unrestrained:
            self.geom.load_topo(topo)
            exte.read_external_restraints(params.get("exte", []), self.st, self.geom)
        self.geom.finalize_restraints()
        # Z-score thresholds used when listing outliers.
        self.outlier_sigmas = dict(bond=5, angle=5, torsion=5, vdw=5, ncs=5, chir=5,
                                   plane=5, staca=5, stacd=5, per_atom=5)
        self.parents = {}
        self.ncslist = ncslist

    def set_h_parents(self):
        """Fill ``self.parents`` with hydrogen -> bonded partner, taken from the bond restraints."""
        self.parents = {}
        for bond in self.geom.bonds:
            if bond.atoms[0].is_hydrogen():
                hydrogen, partner = bond.atoms[0], bond.atoms[1]
            elif bond.atoms[1].is_hydrogen():
                hydrogen, partner = bond.atoms[1], bond.atoms[0]
            else:
                continue
            self.parents[hydrogen] = partner

    def setup_nonbonded(self, refine_xyz):
        """(Re)build non-bonded restraints and, if an NCS list was given, the NCS restraints."""
        skip = self.unrestrained or not refine_xyz
        self.geom.setup_nonbonded(skip_critical_dist=skip,
                                  group_idxes=self.group_occ.group_idxes)
        if self.ncslist:
            self.geom.setup_ncsr(self.ncslist)

    def calc(self, target_only):
        """Return the value of the extension's ``calc`` using the stored weights."""
        return self.geom.calc(check_only=target_only, **self.calc_kwds)

    def calc_adp_restraint(self, target_only):
        """Return the ADP restraint term computed with weight ``adpr_w``."""
        return self.geom.calc_adp_restraint(target_only, self.adpr_w)

    def calc_occ_restraint(self, target_only):
        """Return the occupancy restraint term computed with weight ``occr_w``."""
        return self.geom.calc_occ_restraint(target_only, self.occr_w)

    def calc_target(self, target_only, refine_xyz, adp_mode, use_occr):
        """Return the sum of the enabled restraint terms, logging each of them.

        The extension target is cleared first; when *target_only* is false,
        ``spec_correction`` is applied afterwards.
        """
        self.geom.clear_target()
        geom_x = 0
        geom_a = 0
        geom_q = 0
        if refine_xyz:
            geom_x = self.calc(target_only)
        if adp_mode > 0:
            geom_a = self.calc_adp_restraint(target_only)
        if use_occr > 0:
            geom_q = self.calc_occ_restraint(target_only)
        for label, value in (("geom_x", geom_x), ("geom_a", geom_a), ("geom_q", geom_q)):
            logger.writeln(" {} = {}".format(label, value))
        total = geom_x + geom_a + geom_q
        if not target_only:
            self.geom.spec_correction()
        return total

    def show_model_stats(self, refine_xyz=True, adp_mode=1, use_occr=False, show_outliers=True):
        """Log restraint statistics and return them.

        Returns a dict with ``"summary"`` (DataFrame indexed by restraint type)
        and ``"outliers"`` (dict of DataFrames keyed by restraint kind, filled
        only when *show_outliers* is true and outliers exist).
        """
        # Re-evaluate the requested terms in target-only mode before reporting.
        if refine_xyz:
            self.calc(True)
        if adp_mode > 0:
            self.calc_adp_restraint(True)
        if use_occr:
            self.calc_occ_restraint(True)
        ret = {"outliers": {}}
        fmt3 = "{:.3f}".format
        if show_outliers:
            get_table = dict(bond=self.geom.reporting.get_bond_outliers,
                             angle=self.geom.reporting.get_angle_outliers,
                             torsion=self.geom.reporting.get_torsion_outliers,
                             chir=self.geom.reporting.get_chiral_outliers,
                             plane=self.geom.reporting.get_plane_outliers,
                             staca=self.geom.reporting.get_stacking_angle_outliers,
                             stacd=self.geom.reporting.get_stacking_dist_outliers,
                             vdw=self.geom.reporting.get_vdw_outliers,
                             #ncs=self.geom.reporting.get_ncsr_outliers, # not useful?
                             )
            labs = dict(bond="Bond distances",
                        angle="Bond angles",
                        torsion="Torsion angles",
                        chir="Chiral centres",
                        plane="Planar groups",
                        staca="Stacking plane angles",
                        stacd="Stacking plane distances",
                        vdw="VDW repulsions",
                        ncs="Local NCS restraints")

            for kind, getter in get_table.items():
                min_z = self.outlier_sigmas[kind]
                kwgs = {"min_z": min_z}
                if kind == "bond":
                    kwgs["use_nucleus"] = self.use_nucleus
                table = getter(**kwgs)
                if not table["z"]:
                    continue
                # Replace atom objects by their printable residue/atom description.
                for col in table:
                    if col.startswith(("atom", "plane", "1_atom", "2_atom")):
                        table[col] = [str(self.lookup[x]) for x in table[col]]
                df = pandas.DataFrame(table)
                # Largest |Z| first.
                df = df.reindex(df.z.abs().sort_values(ascending=False).index)
                ret["outliers"][kind] = df
                header = " *** {} outliers (Z >= {}) ***\n".format(labs[kind], min_z)
                if kind != "bond":
                    logger.writeln(header)
                    logger.writeln(df.to_string(float_format=fmt3, index=False) + "\n")
                    continue
                # Bonds with type < 2 and type == 2 (external) are listed separately.
                df0 = df[df.type < 2].drop(columns=["type", "alpha"])
                if len(df0.index) > 0:
                    logger.writeln(header)
                    logger.writeln(df0.to_string(float_format=fmt3, index=False) + "\n")
                df0 = df[df.type == 2].drop(columns=["type"])
                if len(df0.index) > 0:
                    logger.writeln(" *** External bond outliers (Z >= {}) ***\n".format(min_z))
                    logger.writeln(df0.to_string(float_format=fmt3, index=False) + "\n")

        # Per-atom score (disabled)
        if 0:
            peratom = self.geom.reporting.per_atom_score(len(self.atoms), self.use_nucleus, "mean")
            df = pandas.DataFrame(peratom)
            df.insert(0, "atom", [str(self.lookup[x]) for x in self.atoms])
            df = df[df["total"] >= self.outlier_sigmas["per_atom"]]
            if show_outliers and len(df.index) > 0:
                df.sort_values("total", ascending=False, inplace=True)
                ret["outliers"]["per_atom"] = df
                logger.writeln(" *** Per-atom violations (Z >= {}) ***\n".format(self.outlier_sigmas["per_atom"]))
                logger.writeln(df.to_string(float_format="{:.2f}".format, index=False) + "\n")

        summary = pandas.DataFrame(self.geom.reporting.get_summary_table(self.use_nucleus))
        summary = summary.set_index("Restraint type").rename_axis(index=None)
        ret["summary"] = summary
        logger.writeln(summary.to_string(float_format=fmt3) + "\n")
        return ret
186
+
187
def show_binstats(df, cycle_number):
    """Log a loggraph block with the binned data statistics of one cycle.

    *df* is the per-bin statistics table; it must contain a ``d_min`` column,
    which is used to build the ``s2`` (1/d^2) axis. Plots are only added for
    the column families that are present.
    """
    def columns_starting_with(prefix):
        return [name for name in df if name.startswith(prefix)]

    r_cols = columns_starting_with("R")
    fsc_cols = columns_starting_with("fsc")
    cc_cols = columns_starting_with("CC")
    d_cols = [name for name in df if re.search("^D[0-9]*", name)]

    plots = []
    if "fsc_model" in df:
        plots.append(["FSC", ["fsc_model"]])
    for title, cols in (("R", r_cols),
                        ("FSC", fsc_cols),
                        ("CC", cc_cols),
                        ("ML parameters - D", d_cols)):
        if cols:
            plots.append([title, cols])
    if "S" in df:
        plots.append(["ML parameters - Sigma", ["S"]])

    title = "Data stats in cycle {}".format(cycle_number)
    text = utils.make_loggraph_str(df, title, plots,
                                   s2=1/df["d_min"]**2,
                                   float_format="{:.4f}".format)
    logger.writeln(text)
204
+
205
def convert_stats_to_dicts(stats):
    """Return a copy of *stats* in which the geometry DataFrames are plain dicts.

    *stats* must be a list of dicts. For each entry the ``"geom"`` value is
    replaced by ``{"summary": <DataFrame.to_dict()>}`` plus, if any outlier
    tables exist, ``"outliers"`` mapping each kind to a list of row records.
    All other keys are copied by reference.
    """
    converted = []
    for entry in stats:
        item = {}
        for key, value in entry.items():
            if key != "geom":
                item[key] = value
                continue
            geom_out = {"summary": value["summary"].to_dict()}
            for kind, table in value["outliers"].items():
                # "outliers" is created lazily so it is absent when there are none.
                geom_out.setdefault("outliers", {})[kind] = table.to_dict(orient="records")
            item["geom"] = geom_out
        converted.append(item)
    return converted
218
+
219
def write_stats_json_safe(stats, json_out):
    """Write *stats* as JSON to *json_out* without exposing a partial file.

    The data are first written to ``json_out + ".part"`` and then moved over
    the destination with ``os.replace``. If serialisation or the move fails,
    the temporary file is removed and the exception is re-raised, so an
    existing *json_out* is left untouched and no stale ``.part`` file remains.
    """
    tmp = convert_stats_to_dicts(stats)
    out_tmp = json_out + ".part"
    try:
        with open(out_tmp, "w") as ofs:
            json.dump(tmp, ofs, indent=2)
        os.replace(out_tmp, json_out)
    except BaseException:
        # Best-effort cleanup; never mask the original error with a cleanup failure.
        try:
            os.remove(out_tmp)
        except OSError:
            pass
        raise
    logger.writeln(f"Refinement statistics saved: {json_out}")
227
+
228
class GroupOccupancy:
    """Grouped occupancy parameters with sum constraints.

    Attributes
    ----------
    groups : list of ``[indexes, atoms]``; all atoms of a group share one
        occupancy parameter. ``indexes`` are the positions assigned to the
        atoms in ``atom_pos``.
    consts : list of ``(is_comp, group_indexes)``. For ``is_comp`` true the
        occupancies of the listed groups must sum to 1; otherwise the sum is
        only constrained when it exceeds 1.
    group_idxes : per-atom (index ``serial - 1``) 1-based group number, 0 if
        the atom is in no group.
    atom_pos : per-atom position of its occupancy parameter, -1 if none.
    ncycle : number of cycles run by :meth:`refine`.
    """
    # TODO max may not be one. should check multiplicity

    def __init__(self, st, params):
        """Select the atoms of each group from the first model of *st*.

        *params* may be ``None``/empty. Otherwise ``params["groups"]`` maps a
        group id to a list of selections (dicts with optional keys ``chains``,
        ``resi_from``, ``resi_to``, ``resi``, ``atom``, ``alt``),
        ``params["const"]`` (optional) lists ``(is_comp, group ids)`` and
        ``params["ncycle"]`` defaults to 5.
        """
        n_sites = st[0].count_atom_sites()
        self.groups = []
        self.consts = []
        self.group_idxes = [0] * n_sites
        # Defined even without groups so the attribute always exists.
        self.atom_pos = [-1] * n_sites
        self.ncycle = 0
        if not params or not params.get("groups"):
            return
        logger.writeln("Occupancy groups:")
        count = 0
        for igr in params["groups"]:
            self.groups.append([[], []]) # list of [indexes, atoms]
            n_curr = count
            for sel in params["groups"][igr]:
                sel_chains = sel.get("chains")
                sel_from = sel.get("resi_from")
                sel_to = sel.get("resi_to")
                sel_seq = sel.get("resi")
                sel_atom = sel.get("atom")
                sel_alt = sel.get("alt")
                for chain in st[0]:
                    if sel_chains and chain.name not in sel_chains:
                        continue
                    # flag is true while inside the resi_from..resi_to range
                    flag = False
                    for res in chain:
                        if sel_seq and res.seqid != sel_seq:
                            continue
                        if sel_from and res.seqid == sel_from:
                            flag = True
                        if sel_from and not flag:
                            continue
                        for atom in res:
                            if sel_atom and atom.name != sel_atom:
                                continue
                            if sel_alt and atom.altloc != sel_alt:
                                continue
                            self.atom_pos[atom.serial-1] = count
                            self.groups[-1][0].append(count)
                            self.groups[-1][1].append(atom)
                            self.group_idxes[atom.serial-1] = len(self.groups)
                            count += 1
                        if sel_to and res.seqid == sel_to:
                            flag = False
            logger.writeln(" id= {} atoms= {}".format(igr, count - n_curr))

        igr_idxes = {igr: i for i, igr in enumerate(params["groups"])}
        # Constraints are optional: groups may be refined without any.
        self.consts = [(is_comp, [igr_idxes[g] for g in gids])
                       for is_comp, gids in (params.get("const") or [])]
        self.ncycle = params.get("ncycle", 5)

    def constraint(self, x):
        """Return the constraint values c_j(x) as an array (one per constraint).

        c_j = sum(x) - 1 for complete sets, and for incomplete sets only when
        the sum exceeds 1; otherwise 0.
        """
        # x: occupancy parameters
        x = numpy.asarray(x)
        ret = []
        for is_comp, ids in self.consts:
            x_sum = numpy.sum(x[ids])
            if is_comp or x_sum > 1:
                ret.append(x_sum - 1)
            else:
                ret.append(0.)
        return numpy.array(ret)

    def ensure_constraints(self):
        """Give each group its mean occupancy, rescaled to satisfy the constraints."""
        vals = [numpy.mean([a.occ for a in atoms]) for _, atoms in self.groups]
        for is_comp, idxes in self.consts:
            sum_occ = sum(vals[i] for i in idxes)
            if not is_comp and sum_occ < 1:
                sum_occ = 1. # do nothing
            if not sum_occ > 0:
                # A zero (or undefined) sum cannot be rescaled; dividing would
                # turn every occupancy of the set into NaN.
                continue
            for i in idxes:
                vals[i] /= sum_occ
        for occ, (_, atoms) in zip(vals, self.groups):
            for a in atoms:
                a.occ = occ

    def get_x(self):
        """Return the group occupancies, read from the first atom of each group."""
        return numpy.array([atoms[0].occ for _, atoms in self.groups])

    def set_x(self, x):
        """Assign occupancy ``x[i]`` to every atom of group ``i``."""
        for p, (_, atoms) in zip(x, self.groups):
            for a in atoms:
                a.occ = p

    def target(self, x, ll, ls, u):
        """Return the augmented Lagrangian at *x* (atoms are updated to *x*).

        f(x) = LL(x) - sum_j ls_j c_j(x) + u/2 sum_j c_j(x)^2
        """
        self.set_x(x)
        ll.update_fc()
        c = self.constraint(x)
        return ll.calc_target() - numpy.dot(ls, c) + 0.5 * u * numpy.sum(c**2)

    def grad(self, x, ll, ls, u, refine_h):
        """Return (gradient, diagonal second derivative) per group.

        Per-atom terms from ``ll.calc_grad`` (``ll.ll.vn`` / ``ll.ll.am``) are
        summed over each group; hydrogens are left out unless *refine_h*.
        The constraint terms of the augmented Lagrangian are then added.
        """
        c = self.constraint(x)
        ll.calc_grad(self.atom_pos, refine_xyz=False, adp_mode=0, refine_occ=True,
                     refine_h=refine_h, specs=None)
        assert len(ll.ll.vn) == len(ll.ll.am)
        # Convert once instead of once per group.
        vn_all = numpy.array(ll.ll.vn)
        am_all = numpy.array(ll.ll.am)
        vn = []
        diag = []
        for idxes, atoms in self.groups:
            if not refine_h:
                idxes = [i for i, a in zip(idxes, atoms) if not a.is_hydrogen()]
            vn.append(numpy.sum(vn_all[idxes]))
            diag.append(numpy.sum(am_all[idxes]))
        vn, diag = numpy.array(vn), numpy.array(diag)
        for i, (is_comp, idxes) in enumerate(self.consts):
            dcdx = numpy.zeros(len(self.groups))
            dcdx[idxes] = 1.
            # Inactive inequality constraints (c == 0) contribute nothing.
            if is_comp or c[i] != 0:
                vn -= (ls[i] - u * c[i]) * dcdx
                diag += u * dcdx**2

        return vn, diag

    def refine(self, ll, refine_h, alpha=1.1):
        """Refine the grouped occupancies and return per-cycle statistics.

        Returns ``None`` when there are no groups, otherwise a list with one
        dict per cycle (``Ncyc``, ``f0``, ``f1``, ``shift_scale``,
        ``const_viol``, ``lambda_new``). *alpha* is the factor by which the
        penalty parameter grows after each cycle.
        """
        # Refinement of grouped occupancies using augmented Lagrangian
        # f(x) = LL(x) - sum_j (lambda_j c_j(x)) + u/2 sum_j (c_j(x))^2
        # with c_j(x) = 0 constraints
        if not self.groups:
            return
        logger.writeln("\n== Group occupancy refinement ==")
        self.ensure_constraints() # make sure constrained groups have the same occupancies.
        ls = numpy.zeros(len(self.consts)) # Lagrange multiplier
        u = 10000. # penalty parameter. in Refmac 1/0.01**2
        x0 = self.get_x()
        f0 = self.target(x0, ll, ls, u)
        ret = []
        for cyc in range(self.ncycle):
            ret.append({"Ncyc": cyc+1, "f0": f0})
            logger.writeln("occ_{}_f0= {:.4e}".format(cyc, f0))
            vn, diag = self.grad(x0, ll, ls, u, refine_h)
            diag[diag < 1e-6] = 1.
            dx = -vn / diag

            # Try the full shift, then half and quarter of it.
            scale = 1
            for i in range(3):
                scale = 1/2**i
                f1 = self.target(x0 + dx * scale, ll, ls, u)
                logger.writeln("occ_{}_f1, {}= {:.4e}".format(cyc, i, f1))
                if f1 < f0: break
            else:
                logger.writeln("WARNING: function not minimised")
                #self.set_x(x0) # Refmac accepts it even when function increases
            c = self.constraint(x0 + dx * scale)
            ret[-1]["f1"] = f1
            ret[-1]["shift_scale"] = scale
            f0 = f1
            x0 = x0 + dx * scale
            ls -= u * c
            u = alpha * u
            ret[-1]["const_viol"] = list(c)
            ret[-1]["lambda_new"] = list(ls)
        self.ensure_constraints()
        ll.update_fc()
        f = ll.calc_target()
        logger.writeln("final -LL= {}".format(f))
        return ret
404
+
405
+
406
class Refine:
    """Runs refinement cycles on a structure using geometry and (optionally) data terms.

    The parameter vector is laid out as
    ``[xyz (3 per atom)] [ADP (1 or 6 per atom)] [occupancy (1 per atom)]``,
    each section being present only when the corresponding refinement is
    enabled; the atom index within a section is given by ``geom.atom_pos``.
    """

    def __init__(self, st, geom, ll=None, refine_xyz=True, adp_mode=1, refine_h=False, refine_occ=False,
                 unrestrained=False, params=None):
        """Store the refinement settings.

        Parameters
        ----------
        st : structure being refined (not copied).
        geom : ``Geom`` instance; must not be ``None``.
        ll : data-term object or ``None`` for geometry-only optimisation, in
            which case ADPs are not refined (``adp_mode`` is forced to 0).
        refine_xyz, refine_occ, refine_h : what to refine.
        adp_mode : 0=fix, 1=iso, 2=aniso.
        unrestrained : skip the coordinate geometry term in the target.
        params : optional dict; ``write_trajectory`` enables keeping a copy
            of the model after every cycle in ``self.st_traj``.
        """
        assert adp_mode in (0, 1, 2) # 0=fix, 1=iso, 2=aniso
        assert geom is not None
        self.st = st # clone()?
        self.st_traj = None
        self.atoms = geom.atoms # not a copy
        self.geom = geom
        self.ll = ll
        self.gamma = 0
        self.adp_mode = 0 if self.ll is None else adp_mode
        self.refine_xyz = refine_xyz
        self.refine_occ = refine_occ
        self.use_occr = self.refine_occ # for now?
        self.unrestrained = unrestrained
        self.refine_h = refine_h
        # Hydrogens copy the ADP of their bonded partner when they are not refined themselves.
        self.h_inherit_parent_adp = self.adp_mode > 0 and not self.refine_h and self.st[0].has_hydrogen()
        if self.h_inherit_parent_adp:
            self.geom.set_h_parents()
        if params and params.get("write_trajectory"):
            self.st_traj = self.st.clone()
            self.st_traj[-1].name = "0"
        assert self.geom.group_occ.groups or self.n_params() > 0

    def _offsets(self):
        """Return (offset_b, offset_q): start of the ADP and occupancy sections of x."""
        n_atoms = self.geom.n_refine_atoms
        offset_b = n_atoms * 3 if self.refine_xyz else 0
        offset_q = offset_b + n_atoms * {0: 0, 1: 1, 2: 6}[self.adp_mode]
        return offset_b, offset_q

    def _setup_restraint_target(self):
        """Rebuild non-bonded restraints and the extension target, logging their sizes."""
        self.geom.setup_nonbonded(self.refine_xyz) # if refine_xyz=False, no need to do it every time
        self.geom.geom.setup_target(self.refine_xyz, self.adp_mode, self.refine_occ, self.use_occr)
        logger.writeln("vdws = {}".format(len(self.geom.geom.vdws)))
        logger.writeln(f"atoms = {self.geom.geom.target.n_atoms()}")
        logger.writeln(f"pairs = {self.geom.geom.target.n_pairs()}")

    def print_weights(self): # TODO unfinished
        """Log the ADP restraint weight, mode and sigmas (only when ADPs are refined).

        Raises ``LookupError`` if the extension reports an ADP restraint mode
        other than ``"diff"`` or ``"kldiv"``.
        """
        logger.writeln("Geometry weights")
        g = self.geom.geom
        if self.adp_mode > 0:
            logger.writeln(" ADP restraints")
            logger.writeln(" weight: {}".format(self.geom.adpr_w))
            logger.writeln(" mode: {}".format(g.adpr_mode))
            if g.adpr_mode == "diff":
                logger.writeln(" sigmas: {}".format(" ".join("{:.2f}".format(x) for x in g.adpr_diff_sigs)))
            elif g.adpr_mode == "kldiv":
                logger.writeln(" sigmas: {}".format(" ".join("{:.2f}".format(x) for x in g.adpr_kl_sigs)))
            else:
                raise LookupError("unknown adpr_mode")

    def scale_shifts(self, dx, scale):
        """Return ``scale * dx`` with each section clipped to its allowed range.

        Coordinates are limited to +-1.0, B values to +-30.0 (for anisotropic
        ADPs the eigenvalues of the shift tensor are limited) and occupancies
        to +-0.5. The input array is not modified.
        """
        shift_allow_high = 1.0
        shift_allow_low = -1.0
        shift_max_allow_B = 30.0
        shift_min_allow_B = -30.0
        shift_max_allow_q = 0.5
        shift_min_allow_q = -0.5
        dx = scale * dx
        offset_b, offset_q = self._offsets()
        # The slices below are views, so clipping them modifies dx in place.
        if self.refine_xyz:
            dxx = dx[:offset_b]
            logger.writeln("min(dx) = {}".format(numpy.min(dxx)))
            logger.writeln("max(dx) = {}".format(numpy.max(dxx)))
            logger.writeln("mean(dx)= {}".format(numpy.mean(dxx)))
            dxx[dxx > shift_allow_high] = shift_allow_high
            dxx[dxx < shift_allow_low] = shift_allow_low
        if self.adp_mode == 1:
            dxb = dx[offset_b:offset_q]
            logger.writeln("min(dB) = {}".format(numpy.min(dxb)))
            logger.writeln("max(dB) = {}".format(numpy.max(dxb)))
            logger.writeln("mean(dB)= {}".format(numpy.mean(dxb)))
            dxb[dxb > shift_max_allow_B] = shift_max_allow_B
            dxb[dxb < shift_min_allow_B] = shift_min_allow_B
        elif self.adp_mode == 2:
            dxb = dx[offset_b:offset_q]
            # TODO this is misleading
            logger.writeln("min(dB) = {}".format(numpy.min(dxb)))
            logger.writeln("max(dB) = {}".format(numpy.max(dxb)))
            logger.writeln("mean(dB)= {}".format(numpy.mean(dxb)))
            for i in range(len(dxb)//6):
                j = i * 6
                # Stored order is 11, 22, 33, 12, 13, 23.
                a = numpy.array([[dxb[j],   dxb[j+3], dxb[j+4]],
                                 [dxb[j+3], dxb[j+1], dxb[j+5]],
                                 [dxb[j+4], dxb[j+5], dxb[j+2]]])
                v, Q = numpy.linalg.eigh(a)
                v[v > shift_max_allow_B] = shift_max_allow_B
                v[v < shift_min_allow_B] = shift_min_allow_B
                a = Q.dot(numpy.diag(v)).dot(Q.T)
                dxb[j:j+6] = a[0,0], a[1,1], a[2,2], a[0,1], a[0,2], a[1,2]
        if self.refine_occ:
            dxq = dx[offset_q:]
            logger.writeln("min(dq) = {}".format(numpy.min(dxq)))
            logger.writeln("max(dq) = {}".format(numpy.max(dxq)))
            logger.writeln("mean(dq)= {}".format(numpy.mean(dxq)))
            dxq[dxq > shift_max_allow_q] = shift_max_allow_q
            dxq[dxq < shift_min_allow_q] = shift_min_allow_q

        return dx

    def n_params(self):
        """Return the length of the parameter vector for the current settings."""
        n_atoms = self.geom.n_refine_atoms
        n_params = 0
        if self.refine_xyz: n_params += 3 * n_atoms
        if self.adp_mode == 1:
            n_params += n_atoms
        elif self.adp_mode == 2:
            n_params += 6 * n_atoms
        if self.refine_occ:
            n_params += n_atoms
        return n_params

    def set_x(self, x):
        """Write parameter vector *x* into the atoms and refresh dependent state.

        B values (or the eigenvalues of anisotropic B) are kept >= 0.5 and
        occupancies within [1e-3, max], where max is 1 or 1/(n_images+1) for
        atoms listed in ``geom.specs``. Afterwards hydrogen ADPs are copied
        from their partners if enabled, ``ll.update_fc()`` is called and the
        restraint target is rebuilt.
        """
        offset_b, offset_q = self._offsets()
        max_occ = {}
        if self.refine_occ and self.geom.specs:
            max_occ = {atom: 1./(len(images)+1) for atom, images, _, _ in self.geom.specs}
        for i, j in enumerate(self.geom.atom_pos):
            if j < 0: continue
            atom = self.atoms[i]
            if self.refine_xyz:
                atom.pos.fromlist(x[3*j:3*j+3]) # faster than substituting pos.x,pos.y,pos.z
            if self.adp_mode == 1:
                atom.b_iso = max(0.5, x[offset_b + j]) # minimum B = 0.5
            elif self.adp_mode == 2:
                a = x[offset_b + 6 * j: offset_b + 6 * (j+1)]
                a = gemmi.SMat33d(*a)
                M = numpy.array(a.as_mat33())
                v, Q = numpy.linalg.eigh(M) # eig() may return complex due to numerical precision?
                v = numpy.maximum(v, 0.5) # avoid NPD with minimum B = 0.5
                M2 = Q.dot(numpy.diag(v)).dot(Q.T)
                atom.b_iso = M2.trace() / 3
                M2 *= b_to_u
                atom.aniso = gemmi.SMat33f(M2[0,0], M2[1,1], M2[2,2], M2[0,1], M2[0,2], M2[1,2])
            if self.refine_occ:
                atom.occ = min(max_occ.get(atom, 1), max(1e-3, x[offset_q + j]))

        # Copy B of hydrogen from parent
        if self.h_inherit_parent_adp:
            for h in self.geom.parents:
                p = self.geom.parents[h]
                h.b_iso = p.b_iso
                h.aniso = p.aniso

        if self.ll is not None:
            self.ll.update_fc()

        self._setup_restraint_target()

    def get_x(self):
        """Return the current parameter vector read from the atoms."""
        offset_b, offset_q = self._offsets()
        x = numpy.zeros(self.n_params())
        for i, j in enumerate(self.geom.atom_pos):
            if j < 0: continue
            a = self.atoms[i]
            if self.refine_xyz:
                x[3*j:3*(j+1)] = a.pos.tolist()
            if self.adp_mode == 1:
                x[offset_b + j] = a.b_iso
            elif self.adp_mode == 2:
                x[offset_b + 6*j : offset_b + 6*(j+1)] = a.aniso.elements_pdb()
                x[offset_b + 6*j : offset_b + 6*(j+1)] *= u_to_b
            if self.refine_occ:
                x[offset_q + j] = a.occ

        return x

    def calc_target(self, w=1, target_only=False):
        """Return ``w * data_term + geometry_term``.

        Unless *target_only*, the gradient of the data term is also computed
        via ``ll.calc_grad``. Without a data term only the geometry term is
        returned.
        """
        geom = self.geom.calc_target(target_only,
                                     not self.unrestrained and self.refine_xyz,
                                     self.adp_mode, self.use_occr)
        if self.ll is not None:
            ll = self.ll.calc_target()
            logger.writeln(" ll= {}".format(ll))
            if not target_only:
                self.ll.calc_grad(self.geom.atom_pos, self.refine_xyz, self.adp_mode, self.refine_occ,
                                  self.refine_h, self.geom.geom.specials)
        else:
            ll = 0

        return w * ll + geom

    def _solve_shifts(self, weight, use_cpp=True, use_ic=False):
        """Solve the normal equations for the shift vector and update ``self.gamma``.

        *use_cpp* selects the extension solver (the default and the only path
        used by :meth:`run_cycle`); the Python solver is kept as an alternative.
        *use_ic* enables incomplete Cholesky in the extension solver, which is
        noted as problematic at least for geometry optimisation.
        """
        if use_cpp:
            logger.writeln("using cgsolve in c++, ic={}".format(use_ic))
            cgsolver = ext.CgSolve(self.geom.geom.target, None if self.ll is None else self.ll.ll)
            if use_ic:
                cgsolver.gamma = 0
                cgsolver.max_gamma_cyc = 1
            else:
                cgsolver.gamma = self.gamma
            dx = cgsolver.solve(weight, logger, use_ic)
            self.gamma = cgsolver.gamma
            return dx

        logger.writeln("using cgsolve in py")
        am = self.geom.geom.target.am_spmat
        vn = numpy.array(self.geom.geom.target.vn)
        if self.ll is not None:
            am += self.ll.ll.fisher_spmat * weight
            vn += numpy.array(self.ll.ll.vn) * weight
        diag = am.diagonal()
        diag[diag<=0] = 1.
        diag = numpy.sqrt(diag)
        rdiag = 1./diag # sk
        M = scipy.sparse.diags(rdiag)
        dx, self.gamma = cgsolve.cgsolve_rm(A=am, v=vn, M=M, gamma=self.gamma)
        return dx

    def run_cycle(self, weight=1):
        """Run one minimisation cycle.

        Returns ``(ok, shift_scale, f1)``: whether the target decreased, the
        scale applied to the shift (1, 1/2 or 1/4) and the last target value.
        The last tried shift stays applied even if the target did not decrease.
        """
        f0 = self.calc_target(weight)
        x0 = self.get_x()
        logger.writeln("f0= {:.4e}".format(f0))
        dx = self._solve_shifts(weight)

        ret = True # success
        shift_scale = 1
        for i in range(3):
            shift_scale = 1/2**i
            dx2 = self.scale_shifts(dx, shift_scale)
            self.set_x(x0 - dx2)
            f1 = self.calc_target(weight, target_only=True)
            logger.writeln("f1, {}= {:.4e}".format(i, f1))
            if f1 < f0: break
        else:
            ret = False
            logger.writeln("WARNING: function not minimised")
            #self.set_x(x0) # Refmac accepts it even when function increases

        return ret, shift_scale, f1

    def _geom_stats(self, show_outliers):
        """Return ``geom.show_model_stats`` for the current refinement settings."""
        return self.geom.show_model_stats(refine_xyz=self.refine_xyz and not self.unrestrained,
                                          adp_mode=self.adp_mode,
                                          use_occr=self.refine_occ,
                                          show_outliers=show_outliers)

    def _store_data_stats(self, entry, llstats, cycle_number):
        """Copy data statistics into the cycle record *entry* and log the binned table."""
        entry["data"] = {"summary": llstats["summary"],
                         "binned": llstats["bin_stats"].to_dict(orient="records")}
        if "twin_alpha" in llstats:
            entry["twin_alpha"] = llstats["twin_alpha"]
        show_binstats(llstats["bin_stats"], cycle_number)

    def run_cycles(self, ncycles, weight=1, weight_adjust=False, debug=False,
                   weight_adjust_bond_rmsz_range=(0.5, 1.), stats_json_out=None):
        """Run *ncycles* refinement cycles and return the list of per-cycle statistics.

        The first list element (``Ncyc`` 0) describes the starting model.
        With *weight_adjust*, the data weight is changed between cycles based
        on the bond r.m.s.Z relative to *weight_adjust_bond_rmsz_range*.
        With *debug*, the model is written after every cycle. If
        *stats_json_out* is given, the statistics are saved there after the
        initial evaluation and after every cycle.
        """
        self.print_weights()
        stats = [{"Ncyc": 0}]
        self._setup_restraint_target()
        stats[-1]["geom"] = self._geom_stats(show_outliers=True)
        if self.ll is not None:
            self.ll.update_fc()
            self.ll.overall_scale()
            self.ll.update_ml_params()
            self.ll.prepare_target()
            llstats = self.ll.calc_stats(bin_stats=True)
            self._store_data_stats(stats[-1], llstats, 0)
        if self.adp_mode > 0:
            utils.model.adp_analysis(self.st)
        if stats_json_out:
            write_stats_json_safe(stats, stats_json_out)
        occ_refine_flag = self.ll is not None and self.geom.group_occ.groups and self.geom.group_occ.ncycle > 0

        for i in range(ncycles):
            logger.writeln("\n====== CYCLE {:2d} ======\n".format(i+1))
            logger.writeln(f" weight = {weight:.4e}")
            if self.refine_xyz or self.adp_mode > 0 or self.refine_occ:
                is_ok, shift_scale, fval = self.run_cycle(weight=weight)
                stats.append({"Ncyc": len(stats), "shift_scale": shift_scale, "fval": fval, "fval_decreased": is_ok,
                              "weight": weight})
            elif occ_refine_flag:
                stats.append({"Ncyc": len(stats)})
            if occ_refine_flag:
                stats[-1]["occ_refine"] = self.geom.group_occ.refine(self.ll, self.refine_h)
            if debug: utils.fileio.write_model(self.st, "refined_{:02d}".format(i+1), pdb=True)#, cif=True)
            # Outliers are listed only for the last cycle.
            stats[-1]["geom"] = self._geom_stats(show_outliers=(i==ncycles-1))
            if self.ll is not None:
                self.ll.overall_scale()
                f0 = self.ll.calc_target()
                self.ll.update_ml_params()
                self.ll.prepare_target()
                llstats = self.ll.calc_stats(bin_stats=True)#(i==ncycles-1))
                if llstats["summary"]["-LL"] > f0:
                    logger.writeln("WARNING: -LL has increased after ML parameter optimization:"
                                   "{} to {}".format(f0, llstats["summary"]["-LL"]))
                self._store_data_stats(stats[-1], llstats, i+1)
            if self.adp_mode > 0:
                utils.model.adp_analysis(self.st)
            if (weight_adjust and self.refine_xyz and not self.unrestrained and self.ll is not None and
                len(stats) > 2 and "Bond distances, non H" in stats[-1]["geom"]["summary"].index):
                rmsz = stats[-1]["geom"]["summary"]["r.m.s.Z"]["Bond distances, non H"]
                rmsz0 = stats[-2]["geom"]["summary"]["r.m.s.Z"]["Bond distances, non H"]
                if rmsz > weight_adjust_bond_rmsz_range[1] and rmsz > rmsz0:
                    weight /= 1.1
                elif rmsz < weight_adjust_bond_rmsz_range[0] and rmsz0 < weight_adjust_bond_rmsz_range[0] and rmsz < rmsz0:
                    weight *= 1.3
                elif rmsz > 1.5 * rmsz0:
                    weight /= 1.1
            if self.st_traj is not None:
                self.st_traj.add_model(self.st[0])
                self.st_traj[-1].name = str(len(self.st_traj))
            if stats_json_out:
                write_stats_json_safe(stats, stats_json_out)

        logger.writeln("")

        # Make table
        data_keys, geom_keys = set(), set()
        tmp = []
        for d in stats:
            x = {"Ncyc": d["Ncyc"]}
            if "data" in d and "summary" in d["data"]:
                x.update(d["data"]["summary"])
                data_keys.update(d["data"]["summary"])
            if "geom" in d:
                for k, n, l in (("r.m.s.d.", "Bond distances, non H", "rmsBOND"),
                                ("r.m.s.Z", "Bond distances, non H", "zBOND"),
                                ("r.m.s.d.", "Bond angles, non H", "rmsANGL"),
                                ("r.m.s.Z", "Bond angles, non H", "zANGL")):
                    if k in d["geom"]["summary"] and n in d["geom"]["summary"][k]:
                        x[l] = d["geom"]["summary"][k].get(n)
                        geom_keys.add(l)
            tmp.append(x)
        df = pandas.DataFrame(tmp)
        forplot = []
        if "FSCaverage" in data_keys:
            forplot.append(["FSC", ["Ncyc", "FSCaverage"]])
        r_keys = [x for x in data_keys if x.startswith("R")]
        if r_keys:
            forplot.append(["R", ["Ncyc"] + r_keys])
        cc_keys = [x for x in data_keys if x.startswith("CC")]
        if cc_keys:
            forplot.append(["CC", ["Ncyc"] + cc_keys])
        if "-LL" in data_keys:
            forplot.append(["-LL", ["Ncyc", "-LL"]])
        rms_keys = [x for x in geom_keys if x.startswith("rms")]
        if rms_keys:
            forplot.append(["Geometry", ["Ncyc"] + rms_keys])
        z_keys = [x for x in geom_keys if x.startswith("z")]
        if z_keys:
            forplot.append(["Geometry Z", ["Ncyc"] + z_keys])

        lstr = utils.make_loggraph_str(df, "stats vs cycle", forplot,
                                       float_format="{:.4f}".format)
        logger.writeln(lstr)
        self.update_meta(stats[-1])
        return stats

    def update_meta(self, stats):
        """Replace the structure's software/refinement metadata and raw remarks.

        Helices and sheets are cleared. *stats* is one per-cycle record; if it
        has a ``"geom"`` entry, bond and angle r.m.s.d. values are added.
        """
        # TODO write stats. probably geom.reporting.get_summary_table should return with _refine_ls_restr.type names
        # should remove st.mod_residues?
        self.st.helices.clear()
        self.st.sheets.clear()
        raw_remarks = ['REMARK 3',
                       'REMARK 3 REFINEMENT.',
                       f'REMARK 3 PROGRAM : SERVALCAT {servalcat.__version__}',
                       'REMARK 3 AUTHORS : YAMASHITA,MURSHUDOV',
                       'REMARK 3',
                       ]
        si = gemmi.SoftwareItem()
        si.classification = gemmi.SoftwareItem.Classification.Refinement
        si.name = "Servalcat"
        si.version = servalcat.__version__
        si.date = servalcat.__date__
        self.st.meta.software = [si]

        ri = gemmi.RefinementInfo()
        if "geom" in stats:
            restr_stats = []
            raw_remarks.append("REMARK 3 RMS DEVIATIONS FROM IDEAL VALUES COUNT RMS WEIGHT")
            for k, n, l, pl in (("r.m.s.d.", "Bond distances, non H", "s_bond_nonh_d", "BOND LENGTHS REFINED ATOMS (A)"),
                                ("r.m.s.d.", "Bond angles, non H", "s_angle_nonh_d", "BOND ANGLES REFINED ATOMS (DEGREES)")):
                if k in stats["geom"]["summary"] and n in stats["geom"]["summary"][k]:
                    rr = gemmi.RefinementInfo.Restr(l)
                    rr.dev_ideal = stats["geom"]["summary"][k].get(n)
                    rr.count = stats["geom"]["summary"]["N restraints"].get(n)
                    rr.weight = stats["geom"]["summary"]["Mn(sigma)"].get(n)
                    restr_stats.append(rr)
                    raw_remarks.append(f"REMARK 3 {pl}:{rr.count:6d} ;{rr.dev_ideal:6.3f} ;{rr.weight:6.3f}")
            ri.restr_stats = restr_stats
            raw_remarks.append("REMARK 3")
        self.st.meta.refinement = [ri]
        self.st.raw_remarks = raw_remarks

# class Refine