clari 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. clari/__init__.py +0 -0
  2. clari/assets/Elemental_Radii_Alvarez.csv +111 -0
  3. clari/assets/posebusters_no_strain.yml +112 -0
  4. clari/chem/__init__.py +10 -0
  5. clari/chem/common.py +79 -0
  6. clari/chem/crystal.py +585 -0
  7. clari/chem/draw.py +152 -0
  8. clari/chem/featurize.py +55 -0
  9. clari/csd.py +1052 -0
  10. clari/datamodules/__init__.py +1 -0
  11. clari/datamodules/csd.py +258 -0
  12. clari/evaluation/__init__.py +1 -0
  13. clari/evaluation/build_test_cifs_cache.py +38 -0
  14. clari/evaluation/collision.py +96 -0
  15. clari/evaluation/compack.py +340 -0
  16. clari/evaluation/compute_energies.py +492 -0
  17. clari/evaluation/results_utils.py +161 -0
  18. clari/evaluation/sample.py +526 -0
  19. clari/evaluation/summarize.py +276 -0
  20. clari/geometry.py +72 -0
  21. clari/inference/README.md +311 -0
  22. clari/inference/SKILL.md +196 -0
  23. clari/inference/__init__.py +7 -0
  24. clari/inference/cli.py +104 -0
  25. clari/inference/export.py +117 -0
  26. clari/inference/io.py +205 -0
  27. clari/inference/rank.py +71 -0
  28. clari/inference/runner.py +256 -0
  29. clari/inference/sampler.py +722 -0
  30. clari/models/__init__.py +1 -0
  31. clari/models/dit.py +214 -0
  32. clari/models/layers/__init__.py +8 -0
  33. clari/models/layers/embedders.py +102 -0
  34. clari/models/layers/transformer.py +129 -0
  35. clari/paths.py +23 -0
  36. clari/pipelines/base/interfaces.py +288 -0
  37. clari/pipelines/base/lit.py +116 -0
  38. clari/pipelines/base/samplers.py +186 -0
  39. clari/pipelines/base/train.py +53 -0
  40. clari/pipelines/utils/__init__.py +4 -0
  41. clari/pipelines/utils/ema.py +391 -0
  42. clari/pipelines/utils/lit.py +74 -0
  43. clari/pipelines/utils/metrics.py +184 -0
  44. clari/pipelines/utils/muon.py +331 -0
  45. clari/pipelines/utils/trainers.py +143 -0
  46. clari/pipelines/utils/utils.py +32 -0
  47. clari/skill.py +11 -0
  48. clari-0.1.0.dist-info/METADATA +351 -0
  49. clari-0.1.0.dist-info/RECORD +52 -0
  50. clari-0.1.0.dist-info/WHEEL +4 -0
  51. clari-0.1.0.dist-info/entry_points.txt +10 -0
  52. clari-0.1.0.dist-info/licenses/LICENSE +407 -0
clari/__init__.py ADDED
File without changes
@@ -0,0 +1,111 @@
1
+ Element Name,Symbol,Period,Group,Block,Atomic Number,Atomic Weight,Covalent Radius,vdW Radius
2
+ Actinium,Ac,7Act,20,f,89,[227],2.15,2.8
3
+ Aluminium,Al,3,13,p,13,26.982,1.21,2.25
4
+ Americium,Am,7Act,20,f,95,[243],1.8,2.83
5
+ Antimony,Sb,5,15,p,51,121.76,1.39,2.47
6
+ Argon,Ar,3,18,p,18,39.948,1.51,1.83
7
+ Arsenic,As,4,15,p,33,74.922,1.21,1.88
8
+ Astatine,At,6,17,p,85,[210],1.21,2
9
+ Barium,Ba,6,2,s,56,137.327,2.15,3.03
10
+ Berkelium,Bk,7Act,20,f,97,[247],1.54,3.4
11
+ Beryllium,Be,2,2,s,4,9.012,0.96,1.98
12
+ Bismuth,Bi,6,15,p,83,208.98,1.48,2.54
13
+ Bohrium,Bh,7,7,d,107,[264],1.5,2
14
+ Boron,B,2,13,p,5,10.811,0.83,1.91
15
+ Bromine,Br,4,17,p,35,79.904,1.21,1.86
16
+ Cadmium,Cd,5,12,d,48,112.411,1.54,2.49
17
+ Caesium,Cs,6,1,s,55,132.905,2.44,3.48
18
+ Calcium,Ca,4,2,s,20,40.078,1.76,2.62
19
+ Californium,Cf,7Act,20,f,98,[251],1.83,3.05
20
+ Carbon,C,2,14,p,6,12.011,0.68,1.77
21
+ Cerium,Ce,6Lan,19,f,58,140.116,2.04,2.88
22
+ Chlorine,Cl,3,17,p,17,35.453,0.99,1.82
23
+ Chromium,Cr,4,6,d,24,51.996,1.39,2.45
24
+ Cobalt,Co,4,9,d,27,58.933,1.26,2.4
25
+ Copper,Cu,4,11,d,29,63.546,1.32,2.38
26
+ Curium,Cm,7Act,20,f,96,[247],1.69,3.05
27
+ Darmstadtium,Ds,7,10,d,110,[271],1.5,2
28
+ Dubnium,Db,7,5,d,105,[262],1.5,2
29
+ Dysprosium,Dy,6Lan,19,f,66,162.5,1.92,2.87
30
+ Einsteinium,Es,7Act,20,f,99,[252],1.5,2.7
31
+ Erbium,Er,6Lan,19,f,68,167.26,1.89,2.83
32
+ Europium,Eu,6Lan,19,f,63,151.964,1.98,2.87
33
+ Fermium,Fm,7Act,20,f,100,[257],1.5,2
34
+ Fluorine,F,2,17,p,9,18.998,0.64,1.46
35
+ Francium,Fr,7,1,s,87,[223],2.6,2
36
+ Gadolinium,Gd,6Lan,19,f,64,157.25,1.96,2.83
37
+ Gallium,Ga,4,13,p,31,69.723,1.22,2.32
38
+ Germanium,Ge,4,14,p,32,72.61,1.17,2.29
39
+ Gold,Au,6,11,d,79,196.967,1.36,2.32
40
+ Hafnium,Hf,6,4,d,72,178.49,1.75,2.63
41
+ Hassium,Hs,7,8,d,108,[269],1.5,2
42
+ Helium,He,1,18,p,2,4.003,1.5,1.43
43
+ Holmium,Ho,6Lan,19,f,67,164.93,1.92,2.81
44
+ Hydrogen,H,1,1,s,1,1.008,0.23,1.2
45
+ Indium,In,5,13,p,49,114.818,1.42,2.43
46
+ Iodine,I,5,17,p,53,126.904,1.4,2.04
47
+ Iridium,Ir,6,9,d,77,192.217,1.41,2.41
48
+ Iron,Fe,4,8,d,26,55.845,1.52,2.44
49
+ Krypton,Kr,4,18,p,36,83.8,1.5,2.25
50
+ Lanthanum,La,6Lan,19,f,57,138.906,2.07,2.98
51
+ Lawrencium,Lr (Lw),7,3,d,103,[262],1.5,2
52
+ Lead,Pb,6,14,p,82,207.2,1.46,2.6
53
+ Lithium,Li,2,1,s,3,6.941,1.28,2.12
54
+ Lutetium,Lu,6,3,d,71,174.967,1.87,2.74
55
+ Magnesium,Mg,3,2,s,12,24.305,1.41,2.51
56
+ Manganese,Mn,4,7,d,25,54.938,1.61,2.45
57
+ Meitnerium,Mt,7,9,d,109,[268],1.5,2
58
+ Mendelevium,Md,7Act,20,f,101,[258],1.5,2
59
+ Mercury,Hg,6,12,d,80,200.59,1.32,2.45
60
+ Molybdenum,Mo,5,6,d,42,95.94,1.54,2.45
61
+ Neodymium,Nd,6Lan,19,f,60,144.24,2.01,2.95
62
+ Neon,Ne,2,18,p,10,20.18,1.5,1.58
63
+ Neptunium,Np,7Act,20,f,93,[237],1.9,2.82
64
+ Nickel,Ni,4,10,d,28,58.693,1.24,2.4
65
+ Niobium,Nb,5,5,d,41,92.906,1.64,2.56
66
+ Nitrogen,N,2,15,p,7,14.007,0.68,1.66
67
+ Nobelium,No,7Act,20,f,102,[259],1.5,2
68
+ Osmium,Os,6,8,d,76,190.23,1.44,2.48
69
+ Oxygen,O,2,16,p,8,15.999,0.68,1.5
70
+ Palladium,Pd,5,10,d,46,106.42,1.39,2.15
71
+ Phosphorus,P,3,15,p,15,30.974,1.05,1.9
72
+ Platinum,Pt,6,10,d,78,195.078,1.36,2.29
73
+ Plutonium,Pu,7Act,20,f,94,[244],1.87,2.81
74
+ Polonium,Po,6,16,p,84,[210],1.4,2
75
+ Potassium,K,4,1,s,19,39.098,2.03,2.73
76
+ Praseodymium,Pr,6Lan,19,f,59,140.908,2.03,2.92
77
+ Promethium,Pm,6Lan,19,f,61,[145],1.99,2
78
+ Protactinium,Pa,7Act,20,f,91,231.036,2,2.88
79
+ Radium,Ra,7,2,s,88,[226],2.21,2
80
+ Radon,Rn,6,18,p,86,[222],1.5,2
81
+ Rhenium,Re,6,7,d,75,186.207,1.51,2.49
82
+ Rhodium,Rh,5,9,d,45,102.906,1.42,2.44
83
+ Rubidium,Rb,5,1,s,37,85.468,2.2,3.21
84
+ Ruthenium,Ru,5,8,d,44,101.07,1.46,2.46
85
+ Rutherfordium,Rf,7,4,d,104,[261],1.5,2
86
+ Samarium,Sm,6Lan,19,f,62,150.36,1.98,2.9
87
+ Scandium,Sc,4,3,d,21,44.956,1.7,2.58
88
+ Seaborgium,Sg,7,6,d,106,[266],1.5,2
89
+ Selenium,Se,4,16,p,34,78.96,1.22,1.82
90
+ Silicon,Si,3,14,p,14,28.086,1.2,2.19
91
+ Silver,Ag,5,11,d,47,107.868,1.45,2.53
92
+ Sodium,Na,3,1,s,11,22.991,1.66,2.5
93
+ Strontium,Sr,5,2,s,38,87.62,1.95,2.84
94
+ Sulphur,S,3,16,p,16,32.066,1.02,1.89
95
+ Tantalum,Ta,6,5,d,73,180.948,1.7,2.53
96
+ Technetium,Tc,5,7,d,43,[98],1.47,2.44
97
+ Tellurium,Te,5,16,p,52,127.6,1.47,1.99
98
+ Terbium,Tb,6Lan,19,f,65,158.925,1.94,2.79
99
+ Thallium,Tl,6,13,p,81,204.383,1.45,2.47
100
+ Thorium,Th,7Act,20,f,90,232.038,2.06,2.93
101
+ Thulium,Tm,6Lan,19,f,69,168.934,1.9,2.79
102
+ Tin,Sn,5,14,p,50,118.71,1.39,2.42
103
+ Titanium,Ti,4,4,d,22,47.867,1.6,2.46
104
+ Tungsten,W,6,6,d,74,183.84,1.62,2.57
105
+ Uranium,U,7Act,20,f,92,238.029,1.96,2.71
106
+ Vanadium,V,4,5,d,23,50.942,1.53,2.42
107
+ Xenon,Xe,5,18,p,54,131.29,1.5,2.06
108
+ Ytterbium,Yb,6Lan,19,f,70,173.04,1.87,2.8
109
+ Yttrium,Y,5,3,d,39,88.906,1.9,2.75
110
+ Zinc,Zn,4,12,d,30,65.39,1.22,2.39
111
+ Zirconium,Zr,5,4,d,40,91.224,1.75,2.52
@@ -0,0 +1,112 @@
1
+ # Setup test modules
2
+ modules:
3
+ - name: "Loading"
4
+ function: loading
5
+ chosen_binary_test_output:
6
+ - mol_pred_loaded
7
+ rename_outputs:
8
+ mol_pred_loaded: "MOL_PRED loaded"
9
+
10
+ - name: "Chemistry"
11
+ function: rdkit_sanity
12
+ chosen_binary_test_output:
13
+ - passes_rdkit_sanity_checks
14
+ rename_outputs:
15
+ passes_rdkit_sanity_checks: "Sanitization"
16
+
17
+ - name: "Chemistry"
18
+ function: inchi_convertible
19
+ chosen_binary_test_output:
20
+ - inchi_convertible
21
+ rename_outputs:
22
+ inchi_convertible: "InChI convertible"
23
+
24
+ - name: "Chemistry"
25
+ function: atoms_connected
26
+ chosen_binary_test_output:
27
+ - all_atoms_connected
28
+ rename_outputs:
29
+ all_atoms_connected: "All atoms connected"
30
+
31
+ - name: "Geometry"
32
+ function: "distance_geometry"
33
+ parameters:
34
+ bound_matrix_params:
35
+ set15bounds: True # also apply topology-based bounds to 1,5 atom pairs, not only up to 1,4
36
+ scaleVDW: True # scale down lower bounds for atoms less than 5 bonds apart
37
+ doTriangleSmoothing: True
38
+ useMacrocycle14config: False
39
+ threshold_bad_bond_length: 0.25 # widens DG bound by this factor
40
+ threshold_bad_angle: 0.25 # widens DG bound by this factor
41
+ threshold_clash: 0.3 # widens DG bound by this factor
42
+ ignore_hydrogens: False # ignore hydrogens
43
+ sanitize: True # sanitize molecule before running DG module (recommended)
44
+ chosen_binary_test_output:
45
+ - bond_lengths_within_bounds
46
+ - bond_angles_within_bounds
47
+ - no_internal_clash
48
+ rename_outputs:
49
+ bond_lengths_within_bounds: "Bond lengths"
50
+ bond_angles_within_bounds: "Bond angles"
51
+ no_internal_clash: "Internal steric clash"
52
+
53
+ - name: "Ring flatness"
54
+ function: "flatness"
55
+ parameters:
56
+ flat_systems: # list atoms which together should lie on plane as SMARTS matches
57
+ aromatic_5_membered_rings_sp2: "[ar5^2]1[ar5^2][ar5^2][ar5^2][ar5^2]1"
58
+ aromatic_6_membered_rings_sp2: "[ar6^2]1[ar6^2][ar6^2][ar6^2][ar6^2][ar6^2]1"
59
+ threshold_flatness: 0.25 # max distance in A to closest shared plane
60
+ chosen_binary_test_output:
61
+ - flatness_passes
62
+ rename_outputs:
63
+ num_systems_checked: number_aromatic_rings_checked
64
+ num_systems_passed: number_aromatic_rings_pass
65
+ max_distance: aromatic_ring_maximum_distance_from_plane
66
+ flatness_passes: "Aromatic ring flatness"
67
+
68
+ - name: "Double bond flatness"
69
+ function: "flatness"
70
+ parameters:
71
+ flat_systems: # list atoms which together should lie on plane as SMARTS matches
72
+ trigonal_planar_double_bonds: "[C;X3;^2](*)(*)=[C;X3;^2](*)(*)"
73
+ threshold_flatness: 0.25 # max distance in A to closest shared plane
74
+ chosen_binary_test_output:
75
+ - flatness_passes
76
+ rename_outputs:
77
+ num_systems_checked: number_double_bonds_checked
78
+ num_systems_passed: number_double_bonds_pass
79
+ max_distance: double_bond_maximum_distance_from_plane
80
+ flatness_passes: "Double bond flatness"
81
+
82
+ # - name: "Energy ratio"
83
+ # function: energy_ratio
84
+ # parameters:
85
+ # threshold_energy_ratio: 100.0
86
+ # ensemble_number_conformations: 50
87
+ # inchi_strict: False
88
+ # chosen_binary_test_output:
89
+ # - energy_ratio_passes
90
+ # rename_outputs:
91
+ # energy_ratio_passes: "Internal energy"
92
+
93
+ # Options for loading molecule files with RDKit
94
+ loading:
95
+ mol_pred:
96
+ cleanup: False
97
+ sanitize: False
98
+ add_hs: False
99
+ assign_stereo: False
100
+ load_all: True
101
+ mol_true:
102
+ cleanup: False
103
+ sanitize: False
104
+ add_hs: False
105
+ assign_stereo: False
106
+ load_all: True
107
+ mol_cond:
108
+ cleanup: False
109
+ sanitize: False
110
+ add_hs: False
111
+ assign_stereo: False
112
+ proximityBonding: False
clari/chem/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from clari.chem.common import (
2
+ BOND_TO_INDEX,
3
+ INDEX_TO_BOND,
4
+ PTABLE,
5
+ distance_lbound,
6
+ element_radii,
7
+ silenced_rdlogger,
8
+ )
9
+ from clari.chem.crystal import Crystal
10
+ from clari.chem.featurize import ATOM_FEATURES, featurize
clari/chem/common.py ADDED
@@ -0,0 +1,79 @@
1
+ import contextlib
2
+ import pathlib
3
+
4
+ import polars as pl
5
+ import torch
6
+ from rdkit import Chem, RDLogger
7
+ from rdkit.Chem.rdchem import GetPeriodicTable
8
+
9
+ from clari.paths import ASSETS_DIR
10
+
11
# Periodic table object obtained from rdkit.Chem.rdchem.
PTABLE = GetPeriodicTable()


# Integer bond codes -> RDKit bond types. Codes start at 1; 0 means no bonds!
INDEX_TO_BOND = dict(
    enumerate(
        (
            Chem.BondType.SINGLE,
            Chem.BondType.DOUBLE,
            Chem.BondType.TRIPLE,
            Chem.BondType.AROMATIC,
            Chem.BondType.UNSPECIFIED,
        ),
        start=1,
    )
)

# Reverse table addressed by int(bond type). Every slot starts at code 5
# and only the bond types listed in INDEX_TO_BOND are overwritten with
# their own code.
BOND_TO_INDEX = [5] * 22
for i, btype in INDEX_TO_BOND.items():
    BOND_TO_INDEX[int(btype)] = i
26
+
27
+
28
@contextlib.contextmanager
def silenced_rdlogger():
    """Raise the RDKit logger level to CRITICAL for the duration of a ``with`` block.

    When the block ends, whether normally or through an exception, the
    level is set to INFO. The level that was active before entering is not
    recorded, so INFO is what callers get back regardless.
    """
    rd_log = RDLogger.logger()
    rd_log.setLevel(RDLogger.CRITICAL)
    with contextlib.ExitStack() as cleanup:
        # Registered after the level was raised, so the reset only runs
        # once silencing has actually taken effect.
        cleanup.callback(rd_log.setLevel, RDLogger.INFO)
        yield
36
+
37
+
38
def xyzfile(atoms, coords):
    """Format atoms and their coordinates as XYZ-style text.

    Args:
        atoms: Sized iterable of atomic numbers; each entry is passed
            through ``int`` before the element symbol is looked up.
        coords: Iterable of per-atom positions; each entry must provide
            ``tolist()`` returning three values.

    Returns:
        A string whose first line is ``len(atoms)``, followed by an empty
        line and one ``"<symbol> <x> <y> <z>"`` line per atom, with no
        trailing newline.
    """

    def _atom_line(number, position):
        symbol = PTABLE.GetElementSymbol(int(number))
        x, y, z = position.tolist()
        return f"{symbol} {x:f} {y:f} {z:f}"

    # The header keeps its own newline, so joining leaves an empty second
    # line between the atom count and the first atom record.
    lines = [f"{len(atoms)}\n"]
    # zip is non-strict: if the inputs differ in length, pairing stops at
    # the shorter one while the header still reports len(atoms).
    lines.extend(_atom_line(a, p) for a, p in zip(atoms, coords, strict=False))
    return "\n".join(lines)
45
+
46
+
47
# Downloaded from: https://www.ccdc.cam.ac.uk/media/Elemental_Radii_Alvarez.xlsx
def load_radii():
    """Build radius lookup tensors from the bundled Alvarez radii CSV.

    Returns:
        A dict with keys ``"vdw"`` and ``"cov"``. Each value is a
        118-element tensor indexed by the CSV's "Atomic Number" column.
        Slots without a CSV row keep the fill values 2.0 (vdw) and
        1.5 (cov).
    """
    csv_path = pathlib.Path(ASSETS_DIR / "Elemental_Radii_Alvarez.csv")
    # NOTE(review): 118 slots means valid indices are 0..117, so an atomic
    # number of 118 cannot be looked up.
    vdw = torch.full([118], 2.0)
    cov = torch.full([118], 1.5)
    for record in pl.read_csv(csv_path).iter_rows(named=True):
        vdw[record["Atomic Number"]] = record["vdW Radius"]
        cov[record["Atomic Number"]] = record["Covalent Radius"]
    return {"vdw": vdw, "cov": cov}


# Tables are read once, when this module is imported.
RADII_CACHE = load_radii()
58
+
59
+
60
def element_radii(z, rtype):
    """Look up tabulated radii for one or many atomic numbers.

    Args:
        z: Either a Python ``int`` or an integer index tensor of atomic
            numbers.
        rtype: Key into ``RADII_CACHE`` selecting the table to read
            (``load_radii`` creates ``"vdw"`` and ``"cov"``).

    Returns:
        A Python float when ``z`` is an ``int``. Otherwise a tensor with the
        shape of ``z``, on the device of ``z``, in the table's floating
        dtype.
    """
    radii = RADII_CACHE[rtype]
    if isinstance(z, int):
        return radii[z].item()
    # Follow z's device only. ``radii.to(z)`` would also adopt z's dtype,
    # and since an index tensor is integer-typed that cast truncates every
    # radius (0.68 -> 0, 1.21 -> 1) before the lookup happens.
    return radii.to(z.device)[z]
67
+
68
+
69
def distance_lbound(z1, z2, bond_mask):
    """Compute pairwise lower bounds on distances from summed covalent radii.

    For each (atom of ``z1``, atom of ``z2``) pair the two "cov" radii are
    added. The sum is scaled by 0.6 when ``bond_mask`` is set for the pair,
    or when either atomic number is missing from the nonmetal list below;
    otherwise it is scaled by 1.0. The result is floored at 0.5.

    Args:
        z1: Integer tensor of atomic numbers; becomes the second-to-last
            axis of the result.
        z2: Integer tensor of atomic numbers; becomes the last axis of the
            result.
        bond_mask: Boolean tensor combined with the pairwise metal mask.
            Assumed to broadcast against the (..., len(z1), len(z2)) grid;
            TODO confirm with callers.

    Returns:
        Tensor of pairwise lower bounds.
    """
    nonmetal_z = torch.tensor([1, 2, 6, 7, 8, 9, 10, 15, 16, 17, 18, 34, 35, 36, 53, 54]).to(z1)

    def _is_metal(z):
        # Anything not in the nonmetal list counts as a metal here.
        return ~torch.isin(z, nonmetal_z)

    any_metal = _is_metal(z1).unsqueeze(-1) | _is_metal(z2).unsqueeze(-2)
    radius_sum = element_radii(z1, "cov").unsqueeze(-1) + element_radii(z2, "cov").unsqueeze(-2)
    scale = torch.where(bond_mask | any_metal, 0.6, 1.0)
    return torch.clip(scale * radius_sum, min=0.5)