torchref 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchref-0.3.0/LICENSE +21 -0
- torchref-0.3.0/PKG-INFO +116 -0
- torchref-0.3.0/README.md +85 -0
- torchref-0.3.0/pyproject.toml +86 -0
- torchref-0.3.0/setup.cfg +4 -0
- torchref-0.3.0/torchref/__init__.py +134 -0
- torchref-0.3.0/torchref/_bootstrap.py +49 -0
- torchref-0.3.0/torchref/config.py +147 -0
- torchref-0.3.0/torchref.egg-info/PKG-INFO +116 -0
- torchref-0.3.0/torchref.egg-info/SOURCES.txt +12 -0
- torchref-0.3.0/torchref.egg-info/dependency_links.txt +1 -0
- torchref-0.3.0/torchref.egg-info/entry_points.txt +9 -0
- torchref-0.3.0/torchref.egg-info/requires.txt +22 -0
- torchref-0.3.0/torchref.egg-info/top_level.txt +1 -0
torchref-0.3.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Hans Peter Seidel
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
torchref-0.3.0/PKG-INFO
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torchref
|
|
3
|
+
Version: 0.3.0
|
|
4
|
+
Summary: Tools for multicopy refinement of crystallographic models
|
|
5
|
+
Author: HansPeterSeidel
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Requires-Python: >=3.9
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Requires-Dist: numpy<2.4.0,>=1.24.0
|
|
11
|
+
Requires-Dist: pandas<2.4.0,>=2.0.0
|
|
12
|
+
Requires-Dist: torch<2.10.0,>=2.0.0
|
|
13
|
+
Requires-Dist: tqdm<4.68.0,>=4.61.0
|
|
14
|
+
Requires-Dist: numba<0.64.0,>=0.59.0
|
|
15
|
+
Requires-Dist: gemmi<0.8.0,>=0.5.0
|
|
16
|
+
Requires-Dist: scipy<1.18.0,>=1.10.0
|
|
17
|
+
Requires-Dist: matplotlib<3.11.0,>=3.7.0
|
|
18
|
+
Requires-Dist: reciprocalspaceship<1.1.0,>=0.9.18
|
|
19
|
+
Requires-Dist: pyarrow<23.0.0,>=12.0.0
|
|
20
|
+
Provides-Extra: dev
|
|
21
|
+
Requires-Dist: pytest>=6.0.0; extra == "dev"
|
|
22
|
+
Requires-Dist: pytest-cov>=2.12.0; extra == "dev"
|
|
23
|
+
Requires-Dist: black>=21.5b2; extra == "dev"
|
|
24
|
+
Requires-Dist: isort>=5.9.0; extra == "dev"
|
|
25
|
+
Requires-Dist: flake8>=3.9.0; extra == "dev"
|
|
26
|
+
Provides-Extra: docs
|
|
27
|
+
Requires-Dist: sphinx>=4.0.0; extra == "docs"
|
|
28
|
+
Requires-Dist: sphinx-rtd-theme>=1.0.0; extra == "docs"
|
|
29
|
+
Requires-Dist: numpydoc>=1.1.0; extra == "docs"
|
|
30
|
+
Dynamic: license-file
|
|
31
|
+
|
|
32
|
+
# TorchRef
|
|
33
|
+
|
|
34
|
+
**A PyTorch-based crystallographic refinement library**
|
|
35
|
+
|
|
36
|
+
[](https://www.python.org/downloads/)
|
|
37
|
+
[](https://pytorch.org/)
|
|
38
|
+
[](https://opensource.org/licenses/MIT)
|
|
39
|
+
|
|
40
|
+
TorchRef is a crystallographic refinement package built entirely on PyTorch. By leveraging PyTorch's automatic differentiation and GPU acceleration, TorchRef enables seamless integration with machine learning workflows and provides a flexible, extensible framework for crystallographic structure refinement.
|
|
41
|
+
|
|
42
|
+
## Key Features
|
|
43
|
+
|
|
44
|
+
- **Native PyTorch Integration**: Built on PyTorch's `nn.Module` architecture, TorchRef integrates naturally with the PyTorch ecosystem, including machine learning models, optimizers, and GPU acceleration.
|
|
45
|
+
|
|
46
|
+
- **Automatic Differentiation**: Dynamic computational graphs eliminate the need for manually implemented gradient calculations. Define new refinement targets directly—PyTorch handles the derivatives automatically.
|
|
47
|
+
|
|
48
|
+
- **Modular Architecture**: Following PyTorch's module pattern, components are easily composable and extensible. Add custom targets, restraints, or optimizers without modifying core code.
|
|
49
|
+
|
|
50
|
+
- **GPU Acceleration**: Leverage CUDA for structure factor calculations, scaling, and optimization—achieving significant speedups for large structures.
|
|
51
|
+
|
|
52
|
+
- **FFT-based Structure Factors**: Efficient structure factor calculation using Fast Fourier Transform (FFT) methods, enabling rapid F_calc computation even for large unit cells.
|
|
53
|
+
|
|
54
|
+
- **State Management**: Full `state_dict` support enables saving and loading complete refinement states, including model parameters, scaler settings, and restraints.
|
|
55
|
+
|
|
56
|
+
## Installation
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
# Clone the repository
|
|
60
|
+
git clone https://github.com/HatPdotS/TorchRef.git
|
|
61
|
+
cd torchref
|
|
62
|
+
|
|
63
|
+
# Install with pip
|
|
64
|
+
pip install -e .
|
|
65
|
+
|
|
66
|
+
# Or install with development dependencies
|
|
67
|
+
pip install -e ".[dev]"
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
### Dependencies
|
|
71
|
+
|
|
72
|
+
- Python ≥ 3.8
|
|
73
|
+
- PyTorch ≥ 1.9
|
|
74
|
+
- NumPy ≥ 1.20
|
|
75
|
+
- Gemmi ≥ 0.5
|
|
76
|
+
- reciprocalspaceship ≥ 0.9
|
|
77
|
+
- SciPy ≥ 1.7
|
|
78
|
+
|
|
79
|
+
## Getting Started
|
|
80
|
+
|
|
81
|
+
For demonstrations and usage examples, see the example notebooks in [`example_notebooks/`](example_notebooks/):
|
|
82
|
+
|
|
83
|
+
- [`basic_usage.ipynb`](example_notebooks/basic_usage.ipynb) - Getting started tutorial
|
|
84
|
+
- [`code_examples.ipynb`](example_notebooks/code_examples.ipynb) - Code examples and patterns
|
|
85
|
+
- [`target_exploration.ipynb`](example_notebooks/target_exploration.ipynb) - Exploring refinement targets
|
|
86
|
+
|
|
87
|
+
## Testing
|
|
88
|
+
|
|
89
|
+
```bash
|
|
90
|
+
# Run all tests
|
|
91
|
+
pytest tests/
|
|
92
|
+
|
|
93
|
+
# Run with coverage
|
|
94
|
+
pytest tests/ --cov=torchref
|
|
95
|
+
|
|
96
|
+
# Run specific test categories
|
|
97
|
+
pytest tests/unit/ # Fast unit tests
|
|
98
|
+
pytest tests/integration/ # Integration tests
|
|
99
|
+
pytest tests/functional/ # Full workflow tests
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
## Contributing
|
|
103
|
+
|
|
104
|
+
Contributions are welcome! Please follow these guidelines:
|
|
105
|
+
|
|
106
|
+
1. Follow the [NumPy docstring style](https://numpydoc.readthedocs.io/en/latest/format.html)
|
|
107
|
+
2. Add tests for new functionality
|
|
108
|
+
3. Ensure all tests pass before submitting
|
|
109
|
+
|
|
110
|
+
## License
|
|
111
|
+
|
|
112
|
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
|
torchref-0.3.0/README.md
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
# TorchRef
|
|
2
|
+
|
|
3
|
+
**A PyTorch-based crystallographic refinement library**
|
|
4
|
+
|
|
5
|
+
[](https://www.python.org/downloads/)
|
|
6
|
+
[](https://pytorch.org/)
|
|
7
|
+
[](https://opensource.org/licenses/MIT)
|
|
8
|
+
|
|
9
|
+
TorchRef is a crystallographic refinement package built entirely on PyTorch. By leveraging PyTorch's automatic differentiation and GPU acceleration, TorchRef enables seamless integration with machine learning workflows and provides a flexible, extensible framework for crystallographic structure refinement.
|
|
10
|
+
|
|
11
|
+
## Key Features
|
|
12
|
+
|
|
13
|
+
- **Native PyTorch Integration**: Built on PyTorch's `nn.Module` architecture, TorchRef integrates naturally with the PyTorch ecosystem, including machine learning models, optimizers, and GPU acceleration.
|
|
14
|
+
|
|
15
|
+
- **Automatic Differentiation**: Dynamic computational graphs eliminate the need for manually implemented gradient calculations. Define new refinement targets directly—PyTorch handles the derivatives automatically.
|
|
16
|
+
|
|
17
|
+
- **Modular Architecture**: Following PyTorch's module pattern, components are easily composable and extensible. Add custom targets, restraints, or optimizers without modifying core code.
|
|
18
|
+
|
|
19
|
+
- **GPU Acceleration**: Leverage CUDA for structure factor calculations, scaling, and optimization—achieving significant speedups for large structures.
|
|
20
|
+
|
|
21
|
+
- **FFT-based Structure Factors**: Efficient structure factor calculation using Fast Fourier Transform (FFT) methods, enabling rapid F_calc computation even for large unit cells.
|
|
22
|
+
|
|
23
|
+
- **State Management**: Full `state_dict` support enables saving and loading complete refinement states, including model parameters, scaler settings, and restraints.
|
|
24
|
+
|
|
25
|
+
## Installation
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
# Clone the repository
|
|
29
|
+
git clone https://github.com/HatPdotS/TorchRef.git
|
|
30
|
+
cd TorchRef
|
|
31
|
+
|
|
32
|
+
# Install with pip
|
|
33
|
+
pip install -e .
|
|
34
|
+
|
|
35
|
+
# Or install with development dependencies
|
|
36
|
+
pip install -e ".[dev]"
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
### Dependencies
|
|
40
|
+
|
|
41
|
+
- Python ≥ 3.9
|
|
42
|
+
- PyTorch ≥ 2.0
|
|
43
|
+
- NumPy ≥ 1.24
|
|
44
|
+
- Gemmi ≥ 0.5
|
|
45
|
+
- reciprocalspaceship ≥ 0.9.18
|
|
46
|
+
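- SciPy ≥ 1.10

See `pyproject.toml` for the complete dependency list (which also includes pandas, numba, matplotlib, tqdm and pyarrow) together with the tested upper version bounds.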
|
|
47
|
+
|
|
48
|
+
## Getting Started
|
|
49
|
+
|
|
50
|
+
For demonstrations and usage examples, see the example notebooks in [`example_notebooks/`](example_notebooks/):
|
|
51
|
+
|
|
52
|
+
- [`basic_usage.ipynb`](example_notebooks/basic_usage.ipynb) - Getting started tutorial
|
|
53
|
+
- [`code_examples.ipynb`](example_notebooks/code_examples.ipynb) - Code examples and patterns
|
|
54
|
+
- [`target_exploration.ipynb`](example_notebooks/target_exploration.ipynb) - Exploring refinement targets
|
|
55
|
+
|
|
56
|
+
## Testing
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
# Run all tests
|
|
60
|
+
pytest tests/
|
|
61
|
+
|
|
62
|
+
# Run with coverage
|
|
63
|
+
pytest tests/ --cov=torchref
|
|
64
|
+
|
|
65
|
+
# Run specific test categories
|
|
66
|
+
pytest tests/unit/ # Fast unit tests
|
|
67
|
+
pytest tests/integration/ # Integration tests
|
|
68
|
+
pytest tests/functional/ # Full workflow tests
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
## Contributing
|
|
72
|
+
|
|
73
|
+
Contributions are welcome! Please follow these guidelines:
|
|
74
|
+
|
|
75
|
+
1. Follow the [NumPy docstring style](https://numpydoc.readthedocs.io/en/latest/format.html)
|
|
76
|
+
2. Add tests for new functionality
|
|
77
|
+
3. Ensure all tests pass before submitting
|
|
78
|
+
|
|
79
|
+
## License
|
|
80
|
+
|
|
81
|
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
[build-system]
# setuptools >= 61 is required for the [project] table, and >= 77 for the
# PEP 639 license expression (`license = "MIT"`) used below; older releases
# reject the plain-string license during pyproject validation.
requires = ["setuptools>=77.0", "wheel"]
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "torchref"
|
|
7
|
+
version = "0.3.0"
|
|
8
|
+
description = "Tools for multicopy refinement of crystallographic models"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.9"
|
|
11
|
+
license = "MIT"
|
|
12
|
+
authors = [
|
|
13
|
+
{name = "HansPeterSeidel"},
|
|
14
|
+
]
|
|
15
|
+
# Version requirements validated via tox compatibility testing (2026-01-13)
|
|
16
|
+
# Tested combinations:
|
|
17
|
+
# - Python 3.9: numpy 1.24-1.26, pandas 2.0.x, torch 2.0-2.4, numba 0.59-0.60, scipy 1.10-1.13
|
|
18
|
+
# - Python 3.11: numpy 1.24-2.3, pandas 2.0-2.3, torch 2.0-2.9, numba 0.59-0.63, scipy 1.10-1.17
|
|
19
|
+
# - Python 3.12: numpy 2.0-2.3, pandas 2.2-2.3, torch 2.4-2.9, numba 0.61-0.63, scipy 1.13-1.17
|
|
20
|
+
# Upper bounds set to next minor version above tested maximum
|
|
21
|
+
dependencies = [
|
|
22
|
+
"numpy>=1.24.0,<2.4.0",
|
|
23
|
+
"pandas>=2.0.0,<2.4.0",
|
|
24
|
+
"torch>=2.0.0,<2.10.0",
|
|
25
|
+
"tqdm>=4.61.0,<4.68.0",
|
|
26
|
+
"numba>=0.59.0,<0.64.0",
|
|
27
|
+
"gemmi>=0.5.0,<0.8.0",
|
|
28
|
+
"scipy>=1.10.0,<1.18.0",
|
|
29
|
+
"matplotlib>=3.7.0,<3.11.0",
|
|
30
|
+
"reciprocalspaceship>=0.9.18,<1.1.0",
|
|
31
|
+
"pyarrow>=12.0.0,<23.0.0",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[project.scripts]
|
|
35
|
+
torchref-refine = "torchref.cli.refine:main"
|
|
36
|
+
torchref-refine-screened = "torchref.cli.refine_screened:main"
|
|
37
|
+
torchref-refine-everything = "torchref.cli.refine_everything:main"
|
|
38
|
+
torchref-refine-random-weights = "torchref.cli.refine_everything_random_weights:main"
|
|
39
|
+
torchref-refine-policy = 'torchref.cli.refine_everything_policy:main'
|
|
40
|
+
torchref-refine-static = 'torchref.cli.refine_everything_static:main'
|
|
41
|
+
torchref-refine-hyperparameters = 'torchref.cli.refine_everything_hyperparameters:main'
|
|
42
|
+
torchref-download-data = 'torchref.cli.download_data:main'
|
|
43
|
+
|
|
44
|
+
[project.optional-dependencies]
|
|
45
|
+
|
|
46
|
+
dev = [
|
|
47
|
+
"pytest>=6.0.0",
|
|
48
|
+
"pytest-cov>=2.12.0",
|
|
49
|
+
"black>=21.5b2",
|
|
50
|
+
"isort>=5.9.0",
|
|
51
|
+
"flake8>=3.9.0",
|
|
52
|
+
]
|
|
53
|
+
|
|
54
|
+
# Optional CCTBX/iotbx functionality (requires conda, not pip)
|
|
55
|
+
# Install with: conda install -c conda-forge cctbx-base
|
|
56
|
+
# The iotbx module is used in torchref/math_functions/CCTBX_related.py
|
|
57
|
+
|
|
58
|
+
docs = [
|
|
59
|
+
"sphinx>=4.0.0",
|
|
60
|
+
"sphinx-rtd-theme>=1.0.0",
|
|
61
|
+
"numpydoc>=1.1.0",
|
|
62
|
+
]
|
|
63
|
+
|
|
64
|
+
[tool.setuptools.packages.find]
# Ship torchref together with all of its subpackages. torchref/__init__.py
# imports torchref.io, torchref.model, torchref.refinement, torchref.restraints
# and torchref.scaling, and the console scripts live in torchref.cli, so listing
# only the top-level "torchref" package produces a distribution that cannot be
# imported.
include = ["torchref", "torchref.*"]

# Python modules are collected by package discovery above. If non-Python files
# (e.g. under torchref/data) must be shipped, list their glob patterns in a
# [tool.setuptools.package-data] table.
|
|
72
|
+
|
|
73
|
+
[tool.black]
|
|
74
|
+
line-length = 88
|
|
75
|
+
target-version = ["py39"]
|
|
76
|
+
|
|
77
|
+
[tool.isort]
|
|
78
|
+
profile = "black"
|
|
79
|
+
line_length = 88
|
|
80
|
+
|
|
81
|
+
[tool.pytest.ini_options]
|
|
82
|
+
testpaths = ["tests"]
|
|
83
|
+
python_files = "test_*.py"
|
|
84
|
+
|
|
85
|
+
[tool.ruff.lint]
# E402 is required because torchref/__init__.py configures threading before
# importing torch; the remaining codes are project-wide style choices.
ignore = ["E402", "E722", "E741", "F841"]
|
torchref-0.3.0/setup.cfg
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
"""
|
|
2
|
+
TorchRef - A PyTorch-based crystallographic refinement library.
|
|
3
|
+
|
|
4
|
+
TorchRef provides GPU-accelerated crystallographic structure refinement
|
|
5
|
+
using PyTorch's automatic differentiation and nn.Module architecture.
|
|
6
|
+
|
|
7
|
+
Key Features
|
|
8
|
+
------------
|
|
9
|
+
- Native PyTorch integration with nn.Module architecture
|
|
10
|
+
- Automatic differentiation for custom target functions
|
|
11
|
+
- GPU acceleration for structure factor calculations
|
|
12
|
+
- Modular design for easy extension
|
|
13
|
+
|
|
14
|
+
Quick Start
|
|
15
|
+
-----------
|
|
16
|
+
::
|
|
17
|
+
|
|
18
|
+
from torchref import Refinement, ReflectionData, Model
|
|
19
|
+
|
|
20
|
+
# Load data and model
|
|
21
|
+
data = ReflectionData().load_mtz('data.mtz')
|
|
22
|
+
model = Model().load_pdb('structure.pdb')
|
|
23
|
+
|
|
24
|
+
# Run refinement
|
|
25
|
+
refinement = Refinement(data=data, model=model, device='cuda')
|
|
26
|
+
refinement.run_refinement(macro_cycles=10)
|
|
27
|
+
|
|
28
|
+
Modules
|
|
29
|
+
-------
|
|
30
|
+
io
|
|
31
|
+
File I/O for MTZ, PDB, CIF formats.
|
|
32
|
+
model
|
|
33
|
+
Atomic structure models (coordinates, B-factors, occupancies).
|
|
34
|
+
refinement
|
|
35
|
+
Core refinement framework with targets and weighting schemes.
|
|
36
|
+
restraints
|
|
37
|
+
Geometry restraints (bonds, angles, torsions, planes).
|
|
38
|
+
scaling
|
|
39
|
+
Structure factor scaling and bulk solvent models.
|
|
40
|
+
symmetry
|
|
41
|
+
Crystallographic symmetry operations.
|
|
42
|
+
alignment
|
|
43
|
+
Patterson-based structure alignment.
|
|
44
|
+
math_functions
|
|
45
|
+
Mathematical utilities for crystallography.
|
|
46
|
+
utils
|
|
47
|
+
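    General utilities and debugging tools.

Configuration
-------------
Importing torchref has side effects. It takes the thread count from the
``TORCHREF_NUM_THREADS`` environment variable or, if that is unset,
auto-detects one (honouring ``SLURM_CPUS_PER_TASK``) and writes it back to
``TORCHREF_NUM_THREADS``. The count is exported to ``OMP_NUM_THREADS``,
``MKL_NUM_THREADS`` and ``OPENBLAS_NUM_THREADS``, passed to
``torch.set_num_threads``, and reported with a ``UserWarning``.

Default dtypes are read from ``TORCHREF_DTYPE_FLOAT``, ``TORCHREF_DTYPE_INT``
and ``TORCHREF_DTYPE_COMPLEX`` and can be changed at runtime through
``torchref.dtypes``.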
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
# Package version; keep in sync with ``version`` in pyproject.toml.
__version__ = "0.3.0"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
import os
|
|
54
|
+
import warnings
|
|
55
|
+
from pathlib import Path
|
|
56
|
+
|
|
57
|
+
from torchref._bootstrap import configure_threading, detect_available_cpus
|
|
58
|
+
|
|
59
|
+
# Configure threading before importing torch: configure_threading() exports
# OMP_/MKL_/OPENBLAS_NUM_THREADS and must run before torch is loaded for those
# settings to apply (see torchref._bootstrap).
_env_threads = os.environ.get("TORCHREF_NUM_THREADS")
if _env_threads is not None:
    # Validate here so a bad value produces an error that names the variable,
    # instead of a bare int() failure or a later torch.set_num_threads error.
    try:
        N_CPUS = int(_env_threads)
    except ValueError:
        raise ValueError(
            f"TORCHREF_NUM_THREADS must be a positive integer, got {_env_threads!r}."
        ) from None
    if N_CPUS < 1:
        raise ValueError(
            f"TORCHREF_NUM_THREADS must be a positive integer, got {_env_threads!r}."
        )
    warnings.warn(
        f"TorchRef using user-specified {N_CPUS} threads from TORCHREF_NUM_THREADS.",
        stacklevel=2,
    )
else:
    N_CPUS = detect_available_cpus()
    # Record the auto-detected value so later readers of the variable see it.
    os.environ["TORCHREF_NUM_THREADS"] = str(N_CPUS)
    warnings.warn(
        f"TorchRef auto-configured {N_CPUS} threads. Set TORCHREF_NUM_THREADS to override.",
        stacklevel=2,
    )
del _env_threads

configure_threading(N_CPUS)

import torch

torch.set_num_threads(N_CPUS)
|
|
79
|
+
|
|
80
|
+
# Dtype configuration (torchref.config imports torch itself, so keep this after the threading setup above)
|
|
81
|
+
from torchref.config import dtypes
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Location of this __init__.py; only used to derive the constants below.
_INIT_FILE = Path(__file__)

# Absolute, resolved path of the torchref package directory, for referencing
# files that live inside the package.
PATH_TORCHREF = _INIT_FILE.parent.resolve()

# Absolute, resolved path of the directory containing the package. In a source
# checkout this is the repository root; for an installed package it is the
# installation directory (typically site-packages).
ROOT_TORCHREF = _INIT_FILE.parent.parent.resolve()

# Package-internal data directory; its existence is not checked here.
PATH_TORCHREF_DATA = PATH_TORCHREF / "data"

del _INIT_FILE
|
|
91
|
+
|
|
92
|
+
# =============================================================================
|
|
93
|
+
# Convenience imports for common classes
|
|
94
|
+
# =============================================================================
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# Data I/O
|
|
98
|
+
from torchref.io import DatasetCollection, ReflectionData
|
|
99
|
+
|
|
100
|
+
# Model
|
|
101
|
+
from torchref.model import Model, ModelFT
|
|
102
|
+
|
|
103
|
+
# Refinement
|
|
104
|
+
from torchref.refinement import LBFGSRefinement, Refinement
|
|
105
|
+
|
|
106
|
+
# Restraints
|
|
107
|
+
from torchref.restraints import Restraints
|
|
108
|
+
|
|
109
|
+
# Scaling
|
|
110
|
+
from torchref.scaling import Scaler, SolventModel
|
|
111
|
+
|
|
112
|
+
# Public API exported by ``from torchref import *``.
__all__ = [
    # Version and paths
    "__version__",
    "ROOT_TORCHREF",
    "PATH_TORCHREF",
    "PATH_TORCHREF_DATA",
    "N_CPUS",
    # Dtype configuration
    "dtypes",
    # Data I/O
    "ReflectionData",
    "DatasetCollection",
    # Model
    "Model",
    "ModelFT",
    # Refinement
    "Refinement",
    "LBFGSRefinement",
    # Restraints
    "Restraints",
    # Scaling
    "Scaler",
    "SolventModel",
]
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""
|
|
2
|
+
|
|
3
|
+
Bootstrap module to configure threading before importing heavy libraries.
|
|
4
|
+
|
|
5
|
+
Is imported automatically when torchref is imported.
|
|
6
|
+
If you are not on a slurm node and want to customize threading,
|
|
7
|
+
call `configure_threading()` before importing torchref.
|
|
8
|
+
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import os
from typing import Optional
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def detect_available_cpus(max_if_not_slurm=4) -> int:
    """Detect how many CPUs TorchRef should use.

    The sources are consulted in this order:

    1. ``SLURM_CPUS_PER_TASK`` -- if set to a non-empty value it is returned
       as-is (no cap), since it describes the job's actual allocation.
    2. The size of this process's CPU-affinity mask
       (``os.sched_getaffinity``), where the platform provides it.
    3. ``os.cpu_count()``, or 1 if that cannot be determined.

    Results from 2. and 3. are capped at ``max_if_not_slurm``. cgroup CPU
    quotas are not inspected.

    Parameters
    ----------
    max_if_not_slurm : int or None, optional
        Upper bound applied when the count does not come from SLURM
        (default 4). ``None`` disables the cap.

    Returns
    -------
    int
        Number of CPUs to use.

    Raises
    ------
    ValueError
        If ``SLURM_CPUS_PER_TASK`` is set but is not an integer.
    """
    # A SLURM allocation is authoritative on HPC nodes, so it is not capped.
    slurm_cpus = os.environ.get("SLURM_CPUS_PER_TASK")
    if slurm_cpus:
        return int(slurm_cpus)

    # The affinity mask reflects CPU restrictions placed on this process; the
    # call is missing on some platforms (AttributeError) and may fail (OSError).
    try:
        available = len(os.sched_getaffinity(0))
    except (AttributeError, OSError):
        available = os.cpu_count() or 1

    if max_if_not_slurm is None:
        return available
    return min(available, max_if_not_slurm)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def configure_threading(num_threads: Optional[int] = None, pin_threads=False) -> int:
    """Export thread-count environment variables for the native math libraries.

    Sets ``OMP_NUM_THREADS``, ``MKL_NUM_THREADS`` and ``OPENBLAS_NUM_THREADS``.
    Call this BEFORE importing torch/numpy so the settings take effect.

    Parameters
    ----------
    num_threads : int, optional
        Number of threads to configure. If ``None``, the value from
        ``detect_available_cpus()`` is used.
    pin_threads : bool, optional
        If True, additionally set ``OMP_PROC_BIND=TRUE`` and
        ``OMP_PLACES=cores`` to request OpenMP thread pinning.

    Returns
    -------
    int
        The number of threads that was configured.

    Raises
    ------
    ValueError
        If ``num_threads`` is smaller than 1 (not a valid thread count).
    """
    if num_threads is None:
        num_threads = detect_available_cpus()

    if num_threads < 1:
        raise ValueError(
            f"num_threads must be a positive integer, got {num_threads}."
        )

    n = str(num_threads)
    for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"):
        os.environ[var] = n

    if pin_threads:
        # Optional: enable thread pinning
        os.environ["OMP_PROC_BIND"] = "TRUE"
        os.environ["OMP_PLACES"] = "cores"

    return num_threads
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Centralized configuration for TorchRef.
|
|
3
|
+
|
|
4
|
+
Default dtypes can be set via environment variables at import time:
|
|
5
|
+
- TORCHREF_DTYPE_FLOAT: float32 (default) or float64
|
|
6
|
+
- TORCHREF_DTYPE_INT: int32 (default) or int64
|
|
7
|
+
- TORCHREF_DTYPE_COMPLEX: complex64 (default) or complex128
|
|
8
|
+
|
|
9
|
+
Users can also change dtypes at runtime via attribute assignment:
|
|
10
|
+
import torchref
|
|
11
|
+
torchref.dtypes.float = torch.float64
|
|
12
|
+
torchref.dtypes.int = torch.int64
|
|
13
|
+
torchref.dtypes.complex = torch.complex128
|
|
14
|
+
|
|
15
|
+
Or read current values:
|
|
16
|
+
torchref.dtypes.float # torch.float32
|
|
17
|
+
torchref.dtypes.int # torch.int32
|
|
18
|
+
torchref.dtypes.complex # torch.complex64
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
import os
|
|
22
|
+
|
|
23
|
+
import torch
|
|
24
|
+
|
|
25
|
+
# Lower-case dtype names accepted by the TORCHREF_DTYPE_* environment
# variables and by the DtypeConfig setters, mapped to torch dtypes.
_FLOAT_DTYPE_MAP = {
    "float32": torch.float32,
    "float64": torch.float64,
}

_INT_DTYPE_MAP = {
    "int32": torch.int32,
    "int64": torch.int64,
}

_COMPLEX_DTYPE_MAP = {
    "complex64": torch.complex64,
    "complex128": torch.complex128,
}


def _dtype_from_env(var_name, default, mapping):
    """Resolve a dtype from the environment variable ``var_name``.

    The value is stripped and lower-cased before lookup, so ``" Float64 "``
    is accepted. ``default`` is used when the variable is unset.

    Raises
    ------
    ValueError
        If the value is not a key of ``mapping``.
    """
    name = os.environ.get(var_name, default).strip().lower()
    if name not in mapping:
        raise ValueError(
            f"Invalid {var_name}: {name}. "
            f"Valid values: {list(mapping.keys())}"
        )
    return mapping[name]


def _coerce_dtype(value, mapping, kind):
    """Validate a dtype given as a ``torch.dtype`` or a name like ``"float64"``.

    Names are stripped and lower-cased, mirroring the environment variables.
    Returns the matching ``torch.dtype``.

    Raises
    ------
    ValueError
        If ``value`` is not one of the dtypes in ``mapping``.
    """
    if isinstance(value, str):
        # Unknown names fall through unchanged and are rejected below.
        value = mapping.get(value.strip().lower(), value)
    if value not in mapping.values():
        allowed = " or ".join(f"torch.{name}" for name in mapping)
        raise ValueError(f"Invalid {kind} dtype: {value}. Use {allowed}.")
    return value


class DtypeConfig:
    """
    Dtype configuration with property-based access.

    Initial values come from the ``TORCHREF_DTYPE_FLOAT``,
    ``TORCHREF_DTYPE_INT`` and ``TORCHREF_DTYPE_COMPLEX`` environment
    variables (defaults: float32, int32, complex64).

    Access dtypes as attributes:
        dtypes.float    # get current float dtype
        dtypes.int      # get current int dtype
        dtypes.complex  # get current complex dtype

    Set dtypes via assignment, using a torch dtype or its name:
        dtypes.float = torch.float64
        dtypes.int = "int64"
        dtypes.complex = torch.complex128

    Raises
    ------
    ValueError
        From ``__init__`` if an environment variable holds an unsupported
        name, and from the setters when given an unsupported dtype.
    """

    def __init__(self):
        # Environment variables are validated in float, int, complex order.
        self._float = _dtype_from_env(
            "TORCHREF_DTYPE_FLOAT", "float32", _FLOAT_DTYPE_MAP
        )
        self._int = _dtype_from_env("TORCHREF_DTYPE_INT", "int32", _INT_DTYPE_MAP)
        self._complex = _dtype_from_env(
            "TORCHREF_DTYPE_COMPLEX", "complex64", _COMPLEX_DTYPE_MAP
        )

    @property
    def float(self) -> torch.dtype:
        """Get the current default float dtype."""
        return self._float

    @float.setter
    def float(self, dtype: "torch.dtype | str") -> None:
        """Set the default float dtype (torch.float32/float64 or its name)."""
        self._float = _coerce_dtype(dtype, _FLOAT_DTYPE_MAP, "float")

    @property
    def int(self) -> torch.dtype:
        """Get the current default int dtype."""
        return self._int

    @int.setter
    def int(self, dtype: "torch.dtype | str") -> None:
        """Set the default int dtype (torch.int32/int64 or its name)."""
        self._int = _coerce_dtype(dtype, _INT_DTYPE_MAP, "int")

    @property
    def complex(self) -> torch.dtype:
        """Get the current default complex dtype."""
        return self._complex

    @complex.setter
    def complex(self, dtype: "torch.dtype | str") -> None:
        """Set the default complex dtype (torch.complex64/complex128 or its name)."""
        self._complex = _coerce_dtype(dtype, _COMPLEX_DTYPE_MAP, "complex")

    def __repr__(self) -> str:
        return f"DtypeConfig(float={self._float}, int={self._int}, complex={self._complex})"


# Global singleton instance
dtypes = DtypeConfig()
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# Function-style accessors for the global ``dtypes`` singleton. Each one reads
# the singleton at call time, so runtime reassignments are always reflected.
def get_float_dtype() -> torch.dtype:
    """Return the float dtype currently configured on ``dtypes``."""
    current = dtypes.float
    return current
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def get_int_dtype() -> torch.dtype:
    """Return the int dtype currently configured on ``dtypes``."""
    current = dtypes.int
    return current
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def get_complex_dtype() -> torch.dtype:
    """Return the complex dtype currently configured on ``dtypes``."""
    current = dtypes.complex
    return current
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torchref
|
|
3
|
+
Version: 0.3.0
|
|
4
|
+
Summary: Tools for multicopy refinement of crystallographic models
|
|
5
|
+
Author: HansPeterSeidel
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Requires-Python: >=3.9
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Requires-Dist: numpy<2.4.0,>=1.24.0
|
|
11
|
+
Requires-Dist: pandas<2.4.0,>=2.0.0
|
|
12
|
+
Requires-Dist: torch<2.10.0,>=2.0.0
|
|
13
|
+
Requires-Dist: tqdm<4.68.0,>=4.61.0
|
|
14
|
+
Requires-Dist: numba<0.64.0,>=0.59.0
|
|
15
|
+
Requires-Dist: gemmi<0.8.0,>=0.5.0
|
|
16
|
+
Requires-Dist: scipy<1.18.0,>=1.10.0
|
|
17
|
+
Requires-Dist: matplotlib<3.11.0,>=3.7.0
|
|
18
|
+
Requires-Dist: reciprocalspaceship<1.1.0,>=0.9.18
|
|
19
|
+
Requires-Dist: pyarrow<23.0.0,>=12.0.0
|
|
20
|
+
Provides-Extra: dev
|
|
21
|
+
Requires-Dist: pytest>=6.0.0; extra == "dev"
|
|
22
|
+
Requires-Dist: pytest-cov>=2.12.0; extra == "dev"
|
|
23
|
+
Requires-Dist: black>=21.5b2; extra == "dev"
|
|
24
|
+
Requires-Dist: isort>=5.9.0; extra == "dev"
|
|
25
|
+
Requires-Dist: flake8>=3.9.0; extra == "dev"
|
|
26
|
+
Provides-Extra: docs
|
|
27
|
+
Requires-Dist: sphinx>=4.0.0; extra == "docs"
|
|
28
|
+
Requires-Dist: sphinx-rtd-theme>=1.0.0; extra == "docs"
|
|
29
|
+
Requires-Dist: numpydoc>=1.1.0; extra == "docs"
|
|
30
|
+
Dynamic: license-file
|
|
31
|
+
|
|
32
|
+
# TorchRef
|
|
33
|
+
|
|
34
|
+
**A PyTorch-based crystallographic refinement library**
|
|
35
|
+
|
|
36
|
+
[](https://www.python.org/downloads/)
|
|
37
|
+
[](https://pytorch.org/)
|
|
38
|
+
[](https://opensource.org/licenses/MIT)
|
|
39
|
+
|
|
40
|
+
TorchRef is a crystallographic refinement package built entirely on PyTorch. By leveraging PyTorch's automatic differentiation and GPU acceleration, TorchRef enables seamless integration with machine learning workflows and provides a flexible, extensible framework for crystallographic structure refinement.
|
|
41
|
+
|
|
42
|
+
## Key Features
|
|
43
|
+
|
|
44
|
+
- **Native PyTorch Integration**: Built on PyTorch's `nn.Module` architecture, TorchRef integrates naturally with the PyTorch ecosystem, including machine learning models, optimizers, and GPU acceleration.
|
|
45
|
+
|
|
46
|
+
- **Automatic Differentiation**: Dynamic computational graphs eliminate the need for manually implemented gradient calculations. Define new refinement targets directly—PyTorch handles the derivatives automatically.
|
|
47
|
+
|
|
48
|
+
- **Modular Architecture**: Following PyTorch's module pattern, components are easily composable and extensible. Add custom targets, restraints, or optimizers without modifying core code.
|
|
49
|
+
|
|
50
|
+
- **GPU Acceleration**: Leverage CUDA for structure factor calculations, scaling, and optimization—achieving significant speedups for large structures.
|
|
51
|
+
|
|
52
|
+
- **FFT-based Structure Factors**: Efficient structure factor calculation using Fast Fourier Transform (FFT) methods, enabling rapid F_calc computation even for large unit cells.
|
|
53
|
+
|
|
54
|
+
- **State Management**: Full `state_dict` support enables saving and loading complete refinement states, including model parameters, scaler settings, and restraints.
|
|
55
|
+
|
|
56
|
+
## Installation
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
# Clone the repository
|
|
60
|
+
git clone https://github.com/HatPdotS/TorchRef.git
|
|
61
|
+
cd torchref
|
|
62
|
+
|
|
63
|
+
# Install with pip
|
|
64
|
+
pip install -e .
|
|
65
|
+
|
|
66
|
+
# Or install with development dependencies
|
|
67
|
+
pip install -e ".[dev]"
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
### Dependencies
|
|
71
|
+
|
|
72
|
+
- Python ≥ 3.8
|
|
73
|
+
- PyTorch ≥ 1.9
|
|
74
|
+
- NumPy ≥ 1.20
|
|
75
|
+
- Gemmi ≥ 0.5
|
|
76
|
+
- reciprocalspaceship ≥ 0.9
|
|
77
|
+
- SciPy ≥ 1.7
|
|
78
|
+
|
|
79
|
+
## Getting Started
|
|
80
|
+
|
|
81
|
+
For demonstrations and usage examples, see the example notebooks in [`example_notebooks/`](example_notebooks/):
|
|
82
|
+
|
|
83
|
+
- [`basic_usage.ipynb`](example_notebooks/basic_usage.ipynb) - Getting started tutorial
|
|
84
|
+
- [`code_examples.ipynb`](example_notebooks/code_examples.ipynb) - Code examples and patterns
|
|
85
|
+
- [`target_exploration.ipynb`](example_notebooks/target_exploration.ipynb) - Exploring refinement targets
|
|
86
|
+
|
|
87
|
+
## Testing
|
|
88
|
+
|
|
89
|
+
```bash
|
|
90
|
+
# Run all tests
|
|
91
|
+
pytest tests/
|
|
92
|
+
|
|
93
|
+
# Run with coverage
|
|
94
|
+
pytest tests/ --cov=torchref
|
|
95
|
+
|
|
96
|
+
# Run specific test categories
|
|
97
|
+
pytest tests/unit/ # Fast unit tests
|
|
98
|
+
pytest tests/integration/ # Integration tests
|
|
99
|
+
pytest tests/functional/ # Full workflow tests
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
## Contributing
|
|
103
|
+
|
|
104
|
+
Contributions are welcome! Please follow these guidelines:
|
|
105
|
+
|
|
106
|
+
1. Follow the [NumPy docstring style](https://numpydoc.readthedocs.io/en/latest/format.html)
|
|
107
|
+
2. Add tests for new functionality
|
|
108
|
+
3. Ensure all tests pass before submitting
|
|
109
|
+
|
|
110
|
+
## License
|
|
111
|
+
|
|
112
|
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
torchref/__init__.py
|
|
5
|
+
torchref/_bootstrap.py
|
|
6
|
+
torchref/config.py
|
|
7
|
+
torchref.egg-info/PKG-INFO
|
|
8
|
+
torchref.egg-info/SOURCES.txt
|
|
9
|
+
torchref.egg-info/dependency_links.txt
|
|
10
|
+
torchref.egg-info/entry_points.txt
|
|
11
|
+
torchref.egg-info/requires.txt
|
|
12
|
+
torchref.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
[console_scripts]
|
|
2
|
+
torchref-download-data = torchref.cli.download_data:main
|
|
3
|
+
torchref-refine = torchref.cli.refine:main
|
|
4
|
+
torchref-refine-everything = torchref.cli.refine_everything:main
|
|
5
|
+
torchref-refine-hyperparameters = torchref.cli.refine_everything_hyperparameters:main
|
|
6
|
+
torchref-refine-policy = torchref.cli.refine_everything_policy:main
|
|
7
|
+
torchref-refine-random-weights = torchref.cli.refine_everything_random_weights:main
|
|
8
|
+
torchref-refine-screened = torchref.cli.refine_screened:main
|
|
9
|
+
torchref-refine-static = torchref.cli.refine_everything_static:main
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
numpy<2.4.0,>=1.24.0
|
|
2
|
+
pandas<2.4.0,>=2.0.0
|
|
3
|
+
torch<2.10.0,>=2.0.0
|
|
4
|
+
tqdm<4.68.0,>=4.61.0
|
|
5
|
+
numba<0.64.0,>=0.59.0
|
|
6
|
+
gemmi<0.8.0,>=0.5.0
|
|
7
|
+
scipy<1.18.0,>=1.10.0
|
|
8
|
+
matplotlib<3.11.0,>=3.7.0
|
|
9
|
+
reciprocalspaceship<1.1.0,>=0.9.18
|
|
10
|
+
pyarrow<23.0.0,>=12.0.0
|
|
11
|
+
|
|
12
|
+
[dev]
|
|
13
|
+
pytest>=6.0.0
|
|
14
|
+
pytest-cov>=2.12.0
|
|
15
|
+
black>=21.5b2
|
|
16
|
+
isort>=5.9.0
|
|
17
|
+
flake8>=3.9.0
|
|
18
|
+
|
|
19
|
+
[docs]
|
|
20
|
+
sphinx>=4.0.0
|
|
21
|
+
sphinx-rtd-theme>=1.0.0
|
|
22
|
+
numpydoc>=1.1.0
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
torchref
|