torchani 0.6__tar.gz → 0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26) hide show
  1. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/__init__.py +12 -4
  2. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/__pycache__/__init__.cpython-37.pyc +0 -0
  3. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/__pycache__/_six.cpython-37.pyc +0 -0
  4. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/__pycache__/aev.cpython-37.pyc +0 -0
  5. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/__pycache__/ase.cpython-37.pyc +0 -0
  6. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/__pycache__/ignite.cpython-37.pyc +0 -0
  7. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/__pycache__/models.cpython-37.pyc +0 -0
  8. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/__pycache__/nn.cpython-37.pyc +0 -0
  9. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/__pycache__/optim.cpython-37.pyc +0 -0
  10. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/__pycache__/utils.cpython-37.pyc +0 -0
  11. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/data/__pycache__/__init__.cpython-37.pyc +0 -0
  12. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/data/__pycache__/_pyanitools.cpython-37.pyc +0 -0
  13. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/data/__pycache__/cache_aev.cpython-37.pyc +0 -0
  14. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/ignite.py +8 -3
  15. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/neurochem/__init__.py +93 -53
  16. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/neurochem/__pycache__/__init__.cpython-37.pyc +0 -0
  17. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/neurochem/__pycache__/_six.cpython-37.pyc +0 -0
  18. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/neurochem/__pycache__/trainer.cpython-37.pyc +0 -0
  19. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/optim.py +113 -0
  20. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani/utils.py +5 -3
  21. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/{torchani-0.6-py3.7.egg-info → torchani-0.7-py3.7.egg-info}/PKG-INFO +1 -1
  22. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/{torchani-0.6-py3.7.egg-info → torchani-0.7-py3.7.egg-info}/SOURCES.txt +3 -90
  23. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani-0.7-py3.7.egg-info/requires.txt +2 -0
  24. opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/torchani-0.6-py3.7.egg-info/requires.txt +0 -4
  25. /opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/{torchani-0.6-py3.7.egg-info → torchani-0.7-py3.7.egg-info}/dependency_links.txt +0 -0
  26. /opt/hostedtoolcache/Python/3.7.2/x64/lib/python3.7/site-packages/{torchani-0.6-py3.7.egg-info → torchani-0.7-py3.7.egg-info}/top_level.txt +0 -0
@@ -30,11 +30,9 @@ from .aev import AEVComputer
30
30
  from . import utils
31
31
  from . import neurochem
32
32
  from . import models
33
+ from . import optim
33
34
  from pkg_resources import get_distribution, DistributionNotFound
34
35
  import sys
35
- if sys.version_info[0] > 2:
36
- from . import ignite
37
- from . import data
38
36
 
39
37
  try:
40
38
  __version__ = get_distribution(__name__).version
@@ -43,10 +41,20 @@ except DistributionNotFound:
43
41
  pass
44
42
 
45
43
  __all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble',
46
- 'ignite', 'utils', 'neurochem', 'data', 'models']
44
+ 'utils', 'neurochem', 'models', 'optim']
47
45
 
48
46
  try:
49
47
  from . import ase # noqa: F401
50
48
  __all__.append('ase')
51
49
  except ImportError:
52
50
  pass
51
+
52
+
53
+ if sys.version_info[0] > 2:
54
+ try:
55
+ from . import ignite # noqa: F401
56
+ __all__.append('ignite')
57
+ from . import data # noqa: F401
58
+ __all__.append('data')
59
+ except ImportError:
60
+ pass
@@ -4,7 +4,7 @@ from __future__ import absolute_import
4
4
  import torch
5
5
  from . import utils
6
6
  from torch.nn.modules.loss import _Loss
7
- from ignite.metrics import Metric, RootMeanSquaredError
7
+ from ignite.metrics import Metric, RootMeanSquaredError, MeanAbsoluteError
8
8
  from ignite.contrib.metrics.regression import MaximumAbsoluteError
9
9
 
10
10
 
@@ -111,10 +111,15 @@ def RMSEMetric(key):
111
111
  return DictMetric(key, RootMeanSquaredError())
112
112
 
113
113
 
114
- def MAEMetric(key):
114
+ def MaxAEMetric(key):
115
115
  """Create max absolute error metric on key."""
116
116
  return DictMetric(key, MaximumAbsoluteError())
117
117
 
118
118
 
119
+ def MAEMetric(key):
120
+ """Create max absolute error metric on key."""
121
+ return DictMetric(key, MeanAbsoluteError())
122
+
123
+
119
124
  __all__ = ['Container', 'MSELoss', 'TransformedLoss', 'RMSEMetric',
120
- 'MAEMetric']
125
+ 'MaxAEMetric']
@@ -8,7 +8,6 @@ import bz2
8
8
  import lark
9
9
  import struct
10
10
  import itertools
11
- import ignite
12
11
  import math
13
12
  import timeit
14
13
  from . import _six # noqa:F401
@@ -17,7 +16,9 @@ import sys
17
16
  from ..nn import ANIModel, Ensemble, Gaussian
18
17
  from ..utils import EnergyShifter, ChemicalSymbolsToInts
19
18
  from ..aev import AEVComputer
20
- from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric
19
+ from ..optim import AdamW
20
+ import warnings
21
+ import textwrap
21
22
 
22
23
 
23
24
  class Constants(collections.abc.Mapping):
@@ -380,8 +381,6 @@ def hartree2kcal(x):
380
381
 
381
382
 
382
383
  if sys.version_info[0] > 2:
383
- from ..data import BatchedANIDataset # noqa: E402
384
- from ..data import AEVCacheLoader # noqa: E402
385
384
 
386
385
  class Trainer:
387
386
  """Train with NeuroChem training configurations.
@@ -391,7 +390,7 @@ if sys.version_info[0] > 2:
391
390
  device (:class:`torch.device`): device to train the model
392
391
  tqdm (bool): whether to enable tqdm
393
392
  tensorboard (str): Directory to store tensorboard log file, set to
394
- ``None`` to disable tensorboardX.
393
+ ``None`` to disable tensorboard.
395
394
  aev_caching (bool): Whether to use AEV caching.
396
395
  checkpoint_name (str): Name of the checkpoint file, checkpoints
397
396
  will be stored in the network directory with this file name.
@@ -400,18 +399,46 @@ if sys.version_info[0] > 2:
400
399
  def __init__(self, filename, device=torch.device('cuda'), tqdm=False,
401
400
  tensorboard=None, aev_caching=False,
402
401
  checkpoint_name='model.pt'):
402
+ try:
403
+ import ignite
404
+ from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric, MaxAEMetric
405
+ from ..data import BatchedANIDataset # noqa: E402
406
+ from ..data import AEVCacheLoader # noqa: E402
407
+ except ImportError:
408
+ raise RuntimeError(
409
+ 'NeuroChem Trainer requires ignite,'
410
+ 'please install pytorch-ignite-nightly from PYPI')
411
+
412
+ self.ignite = ignite
413
+
414
+ class dummy:
415
+ pass
416
+
417
+ self.imports = dummy()
418
+ self.imports.Container = Container
419
+ self.imports.MSELoss = MSELoss
420
+ self.imports.TransformedLoss = TransformedLoss
421
+ self.imports.RMSEMetric = RMSEMetric
422
+ self.imports.MaxAEMetric = MaxAEMetric
423
+ self.imports.MAEMetric = MAEMetric
424
+ self.imports.BatchedANIDataset = BatchedANIDataset
425
+ self.imports.AEVCacheLoader = AEVCacheLoader
426
+
427
+ self.warned = False
428
+
403
429
  self.filename = filename
404
430
  self.device = device
405
431
  self.aev_caching = aev_caching
406
432
  self.checkpoint_name = checkpoint_name
433
+ self.parameters = []
407
434
  if tqdm:
408
435
  import tqdm
409
436
  self.tqdm = tqdm.tqdm
410
437
  else:
411
438
  self.tqdm = None
412
439
  if tensorboard is not None:
413
- import tensorboardX
414
- self.tensorboard = tensorboardX.SummaryWriter(
440
+ import torch.utils.tensorboard
441
+ self.tensorboard = torch.utils.tensorboard.SummaryWriter(
415
442
  log_dir=tensorboard)
416
443
  self.training_eval_every = 20
417
444
  else:
@@ -591,7 +618,6 @@ if sys.version_info[0] > 2:
591
618
  input_size, network_setup = network_setup
592
619
  if input_size != self.aev_computer.aev_length:
593
620
  raise ValueError('AEV size and input size does not match')
594
- l2reg = []
595
621
  atomic_nets = {}
596
622
  for atom_type in network_setup:
597
623
  layers = network_setup[atom_type]
@@ -610,19 +636,31 @@ if sys.version_info[0] > 2:
610
636
  modules.append(activation)
611
637
  del layer['activation']
612
638
  if 'l2norm' in layer:
639
+ if not self.warned:
640
+ warnings.warn(textwrap.dedent("""
641
+ Currently TorchANI training with weight decay can not reproduce the training
642
+ result of NeuroChem with the same training setup. If you really want to use
643
+ weight decay, consider smaller rates and and make sure you do enough validation
644
+ to check if you get expected result."""))
645
+ self.warned = True
613
646
  if layer['l2norm'] == 1:
614
- # NB: The "L2" implemented in NeuroChem is actually
615
- # not L2 but weight decay. The difference of these
616
- # two is:
617
- # https://arxiv.org/pdf/1711.05101.pdf
618
- # There is a pull request on github/pytorch
619
- # implementing AdamW, etc.:
620
- # https://github.com/pytorch/pytorch/pull/4429
621
- # There is no plan to support the "L2" settings in
622
- # input file before AdamW get merged into pytorch.
623
- raise NotImplementedError('L2 not supported yet')
647
+ self.parameters.append({
648
+ 'params': [module.weight],
649
+ 'weight_decay': layer['l2valu'],
650
+ })
651
+ self.parameters.append({
652
+ 'params': [module.bias],
653
+ })
654
+ else:
655
+ self.parameters.append({
656
+ 'params': module.parameters(),
657
+ })
624
658
  del layer['l2norm']
625
659
  del layer['l2valu']
660
+ else:
661
+ self.parameters.append({
662
+ 'params': module.parameters(),
663
+ })
626
664
  if layer:
627
665
  raise ValueError(
628
666
  'unrecognized parameter in layer setup')
@@ -634,16 +672,13 @@ if sys.version_info[0] > 2:
634
672
  self.nnp = self.model
635
673
  else:
636
674
  self.nnp = torch.nn.Sequential(self.aev_computer, self.model)
637
- self.container = Container({'energies': self.nnp}).to(self.device)
675
+ self.container = self.imports.Container({'energies': self.nnp}).to(self.device)
638
676
 
639
677
  # losses
640
- def l2():
641
- return sum([c * (m.weight ** 2).sum() for c, m in l2reg])
642
- self.mse_loss = TransformedLoss(MSELoss('energies'),
643
- lambda x: x + l2())
644
- self.exp_loss = TransformedLoss(
645
- MSELoss('energies'),
646
- lambda x: 0.5 * (torch.exp(2 * x) - 1) + l2())
678
+ self.mse_loss = self.imports.MSELoss('energies')
679
+ self.exp_loss = self.imports.TransformedLoss(
680
+ self.imports.MSELoss('energies'),
681
+ lambda x: 0.5 * (torch.exp(2 * x) - 1))
647
682
 
648
683
  if params:
649
684
  raise ValueError('unrecognized parameter')
@@ -653,17 +688,18 @@ if sys.version_info[0] > 2:
653
688
  self.best_validation_rmse = math.inf
654
689
 
655
690
  def evaluate(self, dataset):
656
- """Evaluate on given dataset to compute RMSE and MAE."""
657
- evaluator = ignite.engine.create_supervised_evaluator(
691
+ """Evaluate on given dataset to compute RMSE and MaxAE."""
692
+ evaluator = self.ignite.engine.create_supervised_evaluator(
658
693
  self.container,
659
694
  metrics={
660
- 'RMSE': RMSEMetric('energies'),
661
- 'MAE': MAEMetric('energies'),
695
+ 'RMSE': self.imports.RMSEMetric('energies'),
696
+ 'MAE': self.imports.MAEMetric('energies'),
697
+ 'MaxAE': self.imports.MaxAEMetric('energies'),
662
698
  }
663
699
  )
664
700
  evaluator.run(dataset)
665
701
  metrics = evaluator.state.metrics
666
- return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MAE'])
702
+ return hartree2kcal(metrics['RMSE']), hartree2kcal(metrics['MAE']), hartree2kcal(metrics['MaxAE'])
667
703
 
668
704
  def load_data(self, training_path, validation_path):
669
705
  """Load training and validation dataset from file.
@@ -672,14 +708,14 @@ if sys.version_info[0] > 2:
672
708
  directory, otherwise it should be path to the dataset.
673
709
  """
674
710
  if self.aev_caching:
675
- self.training_set = AEVCacheLoader(training_path)
676
- self.validation_set = AEVCacheLoader(validation_path)
711
+ self.training_set = self.imports.AEVCacheLoader(training_path)
712
+ self.validation_set = self.imports.AEVCacheLoader(validation_path)
677
713
  else:
678
- self.training_set = BatchedANIDataset(
714
+ self.training_set = self.imports.BatchedANIDataset(
679
715
  training_path, self.consts.species_to_tensor,
680
716
  self.training_batch_size, device=self.device,
681
717
  transform=[self.shift_energy.subtract_from_dataset])
682
- self.validation_set = BatchedANIDataset(
718
+ self.validation_set = self.imports.BatchedANIDataset(
683
719
  validation_path, self.consts.species_to_tensor,
684
720
  self.validation_batch_size, device=self.device,
685
721
  transform=[self.shift_energy.subtract_from_dataset])
@@ -690,40 +726,40 @@ if sys.version_info[0] > 2:
690
726
 
691
727
  def decorate(trainer):
692
728
 
693
- @trainer.on(ignite.engine.Events.STARTED)
729
+ @trainer.on(self.ignite.engine.Events.STARTED)
694
730
  def initialize(trainer):
695
731
  trainer.state.no_improve_count = 0
696
732
  trainer.state.epoch += self.global_epoch
697
733
  trainer.state.iteration += self.global_iteration
698
734
 
699
- @trainer.on(ignite.engine.Events.COMPLETED)
735
+ @trainer.on(self.ignite.engine.Events.COMPLETED)
700
736
  def finalize(trainer):
701
737
  self.global_epoch = trainer.state.epoch
702
738
  self.global_iteration = trainer.state.iteration
703
739
 
704
740
  if self.nmax > 0:
705
- @trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
741
+ @trainer.on(self.ignite.engine.Events.EPOCH_COMPLETED)
706
742
  def terminate_when_nmax_reaches(trainer):
707
743
  if trainer.state.epoch >= self.nmax:
708
744
  trainer.terminate()
709
745
 
710
746
  if self.tqdm is not None:
711
- @trainer.on(ignite.engine.Events.EPOCH_STARTED)
747
+ @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
712
748
  def init_tqdm(trainer):
713
749
  trainer.state.tqdm = self.tqdm(
714
750
  total=len(self.training_set), desc='epoch')
715
751
 
716
- @trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
752
+ @trainer.on(self.ignite.engine.Events.ITERATION_COMPLETED)
717
753
  def update_tqdm(trainer):
718
754
  trainer.state.tqdm.update(1)
719
755
 
720
- @trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
756
+ @trainer.on(self.ignite.engine.Events.EPOCH_COMPLETED)
721
757
  def finalize_tqdm(trainer):
722
758
  trainer.state.tqdm.close()
723
759
 
724
- @trainer.on(ignite.engine.Events.EPOCH_STARTED)
760
+ @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
725
761
  def validation_and_checkpoint(trainer):
726
- trainer.state.rmse, trainer.state.mae = \
762
+ trainer.state.rmse, trainer.state.mae, trainer.state.maxae = \
727
763
  self.evaluate(self.validation_set)
728
764
  if trainer.state.rmse < self.best_validation_rmse:
729
765
  trainer.state.no_improve_count = 0
@@ -737,7 +773,7 @@ if sys.version_info[0] > 2:
737
773
  trainer.terminate()
738
774
 
739
775
  if self.tensorboard is not None:
740
- @trainer.on(ignite.engine.Events.EPOCH_STARTED)
776
+ @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
741
777
  def log_per_epoch(trainer):
742
778
  elapsed = round(timeit.default_timer() - start, 2)
743
779
  epoch = trainer.state.epoch
@@ -749,6 +785,8 @@ if sys.version_info[0] > 2:
749
785
  trainer.state.rmse, epoch)
750
786
  self.tensorboard.add_scalar('validation_mae_vs_epoch',
751
787
  trainer.state.mae, epoch)
788
+ self.tensorboard.add_scalar('validation_maxae_vs_epoch',
789
+ trainer.state.maxae, epoch)
752
790
  self.tensorboard.add_scalar(
753
791
  'best_validation_rmse_vs_epoch',
754
792
  self.best_validation_rmse, epoch)
@@ -756,16 +794,18 @@ if sys.version_info[0] > 2:
756
794
  'no_improve_count_vs_epoch',
757
795
  trainer.state.no_improve_count, epoch)
758
796
 
759
- # compute training RMSE and MAE
797
+ # compute training RMSE, MAE and MaxAE
760
798
  if epoch % self.training_eval_every == 1:
761
- training_rmse, training_mae = \
799
+ training_rmse, training_mae, training_maxae = \
762
800
  self.evaluate(self.training_set)
763
801
  self.tensorboard.add_scalar(
764
802
  'training_rmse_vs_epoch', training_rmse, epoch)
765
803
  self.tensorboard.add_scalar(
766
804
  'training_mae_vs_epoch', training_mae, epoch)
805
+ self.tensorboard.add_scalar(
806
+ 'training_mae_vs_epoch', training_maxae, epoch)
767
807
 
768
- @trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
808
+ @trainer.on(self.ignite.engine.Events.ITERATION_COMPLETED)
769
809
  def log_loss(trainer):
770
810
  iteration = trainer.state.iteration
771
811
  loss = trainer.state.output
@@ -776,21 +816,21 @@ if sys.version_info[0] > 2:
776
816
 
777
817
  # training using mse loss first until the validation MAE decrease
778
818
  # to < 1 Hartree
779
- optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
780
- trainer = ignite.engine.create_supervised_trainer(
819
+ optimizer = AdamW(self.parameters, lr=lr)
820
+ trainer = self.ignite.engine.create_supervised_trainer(
781
821
  self.container, optimizer, self.mse_loss)
782
822
  decorate(trainer)
783
823
 
784
- @trainer.on(ignite.engine.Events.EPOCH_STARTED)
824
+ @trainer.on(self.ignite.engine.Events.EPOCH_STARTED)
785
825
  def terminate_if_smaller_enough(trainer):
786
- if trainer.state.mae < 1.0:
826
+ if trainer.state.rmse < 10.0:
787
827
  trainer.terminate()
788
828
 
789
829
  trainer.run(self.training_set, max_epochs=math.inf)
790
830
 
791
831
  while lr > self.min_lr:
792
- optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
793
- trainer = ignite.engine.create_supervised_trainer(
832
+ optimizer = AdamW(self.parameters, lr=lr)
833
+ trainer = self.ignite.engine.create_supervised_trainer(
794
834
  self.container, optimizer, self.exp_loss)
795
835
  decorate(trainer)
796
836
  trainer.run(self.training_set, max_epochs=math.inf)
@@ -0,0 +1,113 @@
1
+ """AdamW implementation"""
2
+ import math
3
+ import torch
4
+ from torch.optim.optimizer import Optimizer
5
+
6
+
7
+ # Copied and modified from: https://github.com/pytorch/pytorch/pull/4429
8
+ class AdamW(Optimizer):
9
+ r"""Implements AdamW algorithm.
10
+
11
+ It has been proposed in `Decoupled Weight Decay Regularization`_.
12
+
13
+ Arguments:
14
+ params (iterable): iterable of parameters to optimize or dicts defining
15
+ parameter groups
16
+ lr (float, optional): learning rate (default: 1e-3)
17
+ betas (Tuple[float, float], optional): coefficients used for computing
18
+ running averages of gradient and its square (default: (0.9, 0.999))
19
+ eps (float, optional): term added to the denominator to improve
20
+ numerical stability (default: 1e-8)
21
+ weight_decay (float, optional): weight decay factor (default: 0)
22
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
23
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
24
+ (default: False)
25
+
26
+ .. _Adam\: A Method for Stochastic Optimization:
27
+ https://arxiv.org/abs/1412.6980
28
+ .. _Decoupled Weight Decay Regularization:
29
+ https://arxiv.org/abs/1711.05101
30
+ .. _On the Convergence of Adam and Beyond:
31
+ https://openreview.net/forum?id=ryQu7f-RZ
32
+ """
33
+
34
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
35
+ weight_decay=0, amsgrad=False):
36
+ if not 0.0 <= lr:
37
+ raise ValueError("Invalid learning rate: {}".format(lr))
38
+ if not 0.0 <= eps:
39
+ raise ValueError("Invalid epsilon value: {}".format(eps))
40
+ if not 0.0 <= betas[0] < 1.0:
41
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
42
+ if not 0.0 <= betas[1] < 1.0:
43
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
44
+ defaults = dict(lr=lr, betas=betas, eps=eps,
45
+ weight_decay=weight_decay, amsgrad=amsgrad)
46
+ super(AdamW, self).__init__(params, defaults)
47
+
48
+ def __setstate__(self, state):
49
+ super(AdamW, self).__setstate__(state)
50
+ for group in self.param_groups:
51
+ group.setdefault('amsgrad', False)
52
+
53
+ def step(self, closure=None):
54
+ """Performs a single optimization step.
55
+
56
+ Arguments:
57
+ closure (callable, optional): A closure that reevaluates the model
58
+ and returns the loss.
59
+ """
60
+ loss = None
61
+ if closure is not None:
62
+ loss = closure()
63
+
64
+ for group in self.param_groups:
65
+ for p in group['params']:
66
+ if p.grad is None:
67
+ continue
68
+ grad = p.grad.data
69
+ if grad.is_sparse:
70
+ raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
71
+ amsgrad = group['amsgrad']
72
+
73
+ state = self.state[p]
74
+
75
+ # State initialization
76
+ if len(state) == 0:
77
+ state['step'] = 0
78
+ # Exponential moving average of gradient values
79
+ state['exp_avg'] = torch.zeros_like(p.data)
80
+ # Exponential moving average of squared gradient values
81
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
82
+ if amsgrad:
83
+ # Maintains max of all exp. moving avg. of sq. grad. values
84
+ state['max_exp_avg_sq'] = torch.zeros_like(p.data)
85
+
86
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
87
+ if amsgrad:
88
+ max_exp_avg_sq = state['max_exp_avg_sq']
89
+ beta1, beta2 = group['betas']
90
+
91
+ state['step'] += 1
92
+
93
+ # Decay the first and second moment running average coefficient
94
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
95
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
96
+ if amsgrad:
97
+ # Maintains the maximum of all 2nd moment running avg. till now
98
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
99
+ # Use the max. for normalizing running avg. of gradient
100
+ denom = max_exp_avg_sq.sqrt().add_(group['eps'])
101
+ else:
102
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
103
+
104
+ bias_correction1 = 1 - beta1 ** state['step']
105
+ bias_correction2 = 1 - beta2 ** state['step']
106
+ step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
107
+
108
+ p.data.addcdiv_(-step_size, exp_avg, denom)
109
+
110
+ if group['weight_decay'] != 0:
111
+ p.data.add_(-group['weight_decay'], p.data)
112
+
113
+ return loss
@@ -1,4 +1,5 @@
1
1
  import torch
2
+ import torch.utils.data
2
3
  import math
3
4
 
4
5
 
@@ -247,18 +248,19 @@ def vibrational_analysis(masses, hessian, unit='cm^-1'):
247
248
  # We solve this eigenvalue problem through Lowdin diagnolization:
248
249
  # Hq = w^2 * Tq ==> Hq = w^2 * T^(1/2) T^(1/2) q
249
250
  # Letting q' = T^(1/2) q, we then have
250
- # T^(-1/2) H T^(1/2) q' = w^2 * q'
251
+ # T^(-1/2) H T^(-1/2) q' = w^2 * q'
251
252
  inv_sqrt_mass = (1 / masses.sqrt()).repeat_interleave(3, dim=1) # shape (molecule, 3 * atoms)
252
253
  mass_scaled_hessian = hessian * inv_sqrt_mass.unsqueeze(1) * inv_sqrt_mass.unsqueeze(2)
253
254
  if mass_scaled_hessian.shape[0] != 1:
254
255
  raise ValueError('The input should contain only one molecule')
255
256
  mass_scaled_hessian = mass_scaled_hessian.squeeze(0)
256
- eigenvalues = torch.symeig(mass_scaled_hessian).eigenvalues
257
+ eigenvalues, eigenvectors = torch.symeig(mass_scaled_hessian, eigenvectors=True)
257
258
  angular_frequencies = eigenvalues.sqrt()
258
259
  frequencies = angular_frequencies / (2 * math.pi)
259
260
  # converting from sqrt(hartree / (amu * angstrom^2)) to cm^-1
260
261
  wavenumbers = frequencies * 17092
261
- return wavenumbers
262
+ modes = (eigenvectors.t() * inv_sqrt_mass).reshape(frequencies.numel(), -1, 3)
263
+ return wavenumbers, modes
262
264
 
263
265
 
264
266
  __all__ = ['pad', 'pad_coordinates', 'present_species', 'hessian',
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 1.0
2
2
  Name: torchani
3
- Version: 0.6
3
+ Version: 0.7
4
4
  Summary: PyTorch implementation of ANI
5
5
  Home-page: https://github.com/zasdfgbnm/torchani
6
6
  Author: Xiang Gao
@@ -42,6 +42,7 @@ examples/energy_force.py
42
42
  examples/load_from_neurochem.py
43
43
  examples/neurochem_trainer.py
44
44
  examples/nnp_training.py
45
+ examples/nnp_training_ignite.py
45
46
  examples/vibration_analysis.py
46
47
  tests/test_aev.py
47
48
  tests/test_ase.py
@@ -159,104 +160,14 @@ tests/test_data/NIST/all
159
160
  tests/test_data/NeuroChemOptimized/all
160
161
  tests/test_data/benzene-md/0.dat
161
162
  tests/test_data/benzene-md/1.dat
162
- tests/test_data/benzene-md/10.dat
163
- tests/test_data/benzene-md/11.dat
164
- tests/test_data/benzene-md/12.dat
165
- tests/test_data/benzene-md/13.dat
166
- tests/test_data/benzene-md/14.dat
167
- tests/test_data/benzene-md/15.dat
168
- tests/test_data/benzene-md/16.dat
169
- tests/test_data/benzene-md/17.dat
170
- tests/test_data/benzene-md/18.dat
171
- tests/test_data/benzene-md/19.dat
172
163
  tests/test_data/benzene-md/2.dat
173
- tests/test_data/benzene-md/20.dat
174
- tests/test_data/benzene-md/21.dat
175
- tests/test_data/benzene-md/22.dat
176
- tests/test_data/benzene-md/23.dat
177
- tests/test_data/benzene-md/24.dat
178
- tests/test_data/benzene-md/25.dat
179
- tests/test_data/benzene-md/26.dat
180
- tests/test_data/benzene-md/27.dat
181
- tests/test_data/benzene-md/28.dat
182
- tests/test_data/benzene-md/29.dat
183
164
  tests/test_data/benzene-md/3.dat
184
- tests/test_data/benzene-md/30.dat
185
- tests/test_data/benzene-md/31.dat
186
- tests/test_data/benzene-md/32.dat
187
- tests/test_data/benzene-md/33.dat
188
- tests/test_data/benzene-md/34.dat
189
- tests/test_data/benzene-md/35.dat
190
- tests/test_data/benzene-md/36.dat
191
- tests/test_data/benzene-md/37.dat
192
- tests/test_data/benzene-md/38.dat
193
- tests/test_data/benzene-md/39.dat
194
165
  tests/test_data/benzene-md/4.dat
195
- tests/test_data/benzene-md/40.dat
196
- tests/test_data/benzene-md/41.dat
197
- tests/test_data/benzene-md/42.dat
198
- tests/test_data/benzene-md/43.dat
199
- tests/test_data/benzene-md/44.dat
200
- tests/test_data/benzene-md/45.dat
201
- tests/test_data/benzene-md/46.dat
202
- tests/test_data/benzene-md/47.dat
203
- tests/test_data/benzene-md/48.dat
204
- tests/test_data/benzene-md/49.dat
205
166
  tests/test_data/benzene-md/5.dat
206
- tests/test_data/benzene-md/50.dat
207
- tests/test_data/benzene-md/51.dat
208
- tests/test_data/benzene-md/52.dat
209
- tests/test_data/benzene-md/53.dat
210
- tests/test_data/benzene-md/54.dat
211
- tests/test_data/benzene-md/55.dat
212
- tests/test_data/benzene-md/56.dat
213
- tests/test_data/benzene-md/57.dat
214
- tests/test_data/benzene-md/58.dat
215
- tests/test_data/benzene-md/59.dat
216
167
  tests/test_data/benzene-md/6.dat
217
- tests/test_data/benzene-md/60.dat
218
- tests/test_data/benzene-md/61.dat
219
- tests/test_data/benzene-md/62.dat
220
- tests/test_data/benzene-md/63.dat
221
- tests/test_data/benzene-md/64.dat
222
- tests/test_data/benzene-md/65.dat
223
- tests/test_data/benzene-md/66.dat
224
- tests/test_data/benzene-md/67.dat
225
- tests/test_data/benzene-md/68.dat
226
- tests/test_data/benzene-md/69.dat
227
168
  tests/test_data/benzene-md/7.dat
228
- tests/test_data/benzene-md/70.dat
229
- tests/test_data/benzene-md/71.dat
230
- tests/test_data/benzene-md/72.dat
231
- tests/test_data/benzene-md/73.dat
232
- tests/test_data/benzene-md/74.dat
233
- tests/test_data/benzene-md/75.dat
234
- tests/test_data/benzene-md/76.dat
235
- tests/test_data/benzene-md/77.dat
236
- tests/test_data/benzene-md/78.dat
237
- tests/test_data/benzene-md/79.dat
238
169
  tests/test_data/benzene-md/8.dat
239
- tests/test_data/benzene-md/80.dat
240
- tests/test_data/benzene-md/81.dat
241
- tests/test_data/benzene-md/82.dat
242
- tests/test_data/benzene-md/83.dat
243
- tests/test_data/benzene-md/84.dat
244
- tests/test_data/benzene-md/85.dat
245
- tests/test_data/benzene-md/86.dat
246
- tests/test_data/benzene-md/87.dat
247
- tests/test_data/benzene-md/88.dat
248
- tests/test_data/benzene-md/89.dat
249
170
  tests/test_data/benzene-md/9.dat
250
- tests/test_data/benzene-md/90.dat
251
- tests/test_data/benzene-md/91.dat
252
- tests/test_data/benzene-md/92.dat
253
- tests/test_data/benzene-md/93.dat
254
- tests/test_data/benzene-md/94.dat
255
- tests/test_data/benzene-md/95.dat
256
- tests/test_data/benzene-md/96.dat
257
- tests/test_data/benzene-md/97.dat
258
- tests/test_data/benzene-md/98.dat
259
- tests/test_data/benzene-md/99.dat
260
171
  tests/test_data/tripeptide-md/0.dat
261
172
  tests/test_data/tripeptide-md/1.dat
262
173
  tests/test_data/tripeptide-md/10.dat
@@ -2193,6 +2104,7 @@ tools/generate-unit-test-expect/nist-dataset/README.md
2193
2104
  tools/generate-unit-test-expect/nist-dataset/nist.py
2194
2105
  tools/generate-unit-test-expect/nist-dataset/result.json
2195
2106
  tools/generate-unit-test-expect/others/Benzene.cif
2107
+ tools/generate-unit-test-expect/others/Benzene.pdb
2196
2108
  tools/generate-unit-test-expect/tripeptide/tripeptide-000.ipt_optimization.xyz
2197
2109
  tools/generate-unit-test-expect/tripeptide/tripeptide-001.ipt_optimization.xyz
2198
2110
  tools/generate-unit-test-expect/tripeptide/tripeptide-002.ipt_optimization.xyz
@@ -2444,6 +2356,7 @@ torchani/ase.py
2444
2356
  torchani/ignite.py
2445
2357
  torchani/models.py
2446
2358
  torchani/nn.py
2359
+ torchani/optim.py
2447
2360
  torchani/utils.py
2448
2361
  torchani.egg-info/PKG-INFO
2449
2362
  torchani.egg-info/SOURCES.txt
@@ -1,4 +0,0 @@
1
- torch-nightly
2
- pytorch-ignite-nightly
3
- lark-parser
4
- h5py