torchref-0.3.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchref-0.3.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Hans Peter Seidel
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
torchref-0.3.0/PKG-INFO ADDED
@@ -0,0 +1,116 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchref
3
+ Version: 0.3.0
4
+ Summary: Tools for multicopy refinement of crystallographic models
5
+ Author: HansPeterSeidel
6
+ License-Expression: MIT
7
+ Requires-Python: >=3.9
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: numpy<2.4.0,>=1.24.0
11
+ Requires-Dist: pandas<2.4.0,>=2.0.0
12
+ Requires-Dist: torch<2.10.0,>=2.0.0
13
+ Requires-Dist: tqdm<4.68.0,>=4.61.0
14
+ Requires-Dist: numba<0.64.0,>=0.59.0
15
+ Requires-Dist: gemmi<0.8.0,>=0.5.0
16
+ Requires-Dist: scipy<1.18.0,>=1.10.0
17
+ Requires-Dist: matplotlib<3.11.0,>=3.7.0
18
+ Requires-Dist: reciprocalspaceship<1.1.0,>=0.9.18
19
+ Requires-Dist: pyarrow<23.0.0,>=12.0.0
20
+ Provides-Extra: dev
21
+ Requires-Dist: pytest>=6.0.0; extra == "dev"
22
+ Requires-Dist: pytest-cov>=2.12.0; extra == "dev"
23
+ Requires-Dist: black>=21.5b2; extra == "dev"
24
+ Requires-Dist: isort>=5.9.0; extra == "dev"
25
+ Requires-Dist: flake8>=3.9.0; extra == "dev"
26
+ Provides-Extra: docs
27
+ Requires-Dist: sphinx>=4.0.0; extra == "docs"
28
+ Requires-Dist: sphinx-rtd-theme>=1.0.0; extra == "docs"
29
+ Requires-Dist: numpydoc>=1.1.0; extra == "docs"
30
+ Dynamic: license-file
31
+
32
+ # TorchRef
33
+
34
+ **A PyTorch-based crystallographic refinement library**
35
+
36
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
37
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-ee4c2c.svg)](https://pytorch.org/)
38
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
39
+
40
+ TorchRef is a crystallographic refinement package built on PyTorch. Because it relies on PyTorch's automatic differentiation and GPU support, TorchRef integrates directly with machine learning workflows and provides a flexible, extensible framework for crystallographic structure refinement.
41
+
42
+ ## Key Features
43
+
44
+ - **Native PyTorch Integration**: Built on PyTorch's `nn.Module` architecture, TorchRef integrates naturally with the PyTorch ecosystem, including machine learning models, optimizers, and GPU acceleration.
45
+
46
+ - **Automatic Differentiation**: Dynamic computational graphs eliminate the need for manually implemented gradient calculations. Define new refinement targets directly—PyTorch handles the derivatives automatically.
47
+
48
+ - **Modular Architecture**: Following PyTorch's module pattern, components are easily composable and extensible. Add custom targets, restraints, or optimizers without modifying core code.
49
+
50
+ - **GPU Acceleration**: Run structure factor calculations, scaling, and optimization on CUDA devices, which can substantially speed up refinement of large structures.
51
+
52
+ - **FFT-based Structure Factors**: Efficient structure factor calculation using Fast Fourier Transform (FFT) methods, enabling rapid F_calc computation even for large unit cells.
53
+
54
+ - **State Management**: Full `state_dict` support enables saving and loading complete refinement states, including model parameters, scaler settings, and restraints.
55
+
56
+ ## Installation
57
+
58
+ ```bash
59
+ # Clone the repository
60
+ git clone https://github.com/HatPdotS/TorchRef.git
61
+ cd TorchRef
62
+
63
+ # Install with pip
64
+ pip install -e .
65
+
66
+ # Or install with development dependencies
67
+ pip install -e ".[dev]"
68
+ ```
69
+
70
+ ### Dependencies
71
+
72
+ - Python ≥ 3.9
73
+ - PyTorch ≥ 2.0
74
+ - NumPy ≥ 1.24
75
+ - Gemmi ≥ 0.5
76
+ - reciprocalspaceship ≥ 0.9.18
77
+ - SciPy ≥ 1.10
78
+
79
+ ## Getting Started
80
+
81
+ For demonstrations and usage examples, see the example notebooks in [`example_notebooks/`](example_notebooks/):
82
+
83
+ - [`basic_usage.ipynb`](example_notebooks/basic_usage.ipynb) - Getting started tutorial
84
+ - [`code_examples.ipynb`](example_notebooks/code_examples.ipynb) - Code examples and patterns
85
+ - [`target_exploration.ipynb`](example_notebooks/target_exploration.ipynb) - Exploring refinement targets
86
+
87
+ ## Testing
88
+
89
+ ```bash
90
+ # Run all tests
91
+ pytest tests/
92
+
93
+ # Run with coverage
94
+ pytest tests/ --cov=torchref
95
+
96
+ # Run specific test categories
97
+ pytest tests/unit/ # Fast unit tests
98
+ pytest tests/integration/ # Integration tests
99
+ pytest tests/functional/ # Full workflow tests
100
+ ```
101
+
102
+ ## Contributing
103
+
104
+ Contributions are welcome! Please follow these guidelines:
105
+
106
+ 1. Follow the [NumPy docstring style](https://numpydoc.readthedocs.io/en/latest/format.html)
107
+ 2. Add tests for new functionality
108
+ 3. Ensure all tests pass before submitting
109
+
110
+ ## License
111
+
112
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
113
+
114
+
115
+
116
+
torchref-0.3.0/README.md ADDED
@@ -0,0 +1,85 @@
1
+ # TorchRef
2
+
3
+ **A PyTorch-based crystallographic refinement library**
4
+
5
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
6
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-ee4c2c.svg)](https://pytorch.org/)
7
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
8
+
9
+ TorchRef is a crystallographic refinement package built on PyTorch. Because it relies on PyTorch's automatic differentiation and GPU support, TorchRef integrates directly with machine learning workflows and provides a flexible, extensible framework for crystallographic structure refinement.
10
+
11
+ ## Key Features
12
+
13
+ - **Native PyTorch Integration**: Built on PyTorch's `nn.Module` architecture, TorchRef integrates naturally with the PyTorch ecosystem, including machine learning models, optimizers, and GPU acceleration.
14
+
15
+ - **Automatic Differentiation**: Dynamic computational graphs eliminate the need for manually implemented gradient calculations. Define new refinement targets directly—PyTorch handles the derivatives automatically.
16
+
17
+ - **Modular Architecture**: Following PyTorch's module pattern, components are easily composable and extensible. Add custom targets, restraints, or optimizers without modifying core code.
18
+
19
+ - **GPU Acceleration**: Run structure factor calculations, scaling, and optimization on CUDA devices, which can substantially speed up refinement of large structures.
20
+
21
+ - **FFT-based Structure Factors**: Efficient structure factor calculation using Fast Fourier Transform (FFT) methods, enabling rapid F_calc computation even for large unit cells.
22
+
23
+ - **State Management**: Full `state_dict` support enables saving and loading complete refinement states, including model parameters, scaler settings, and restraints.
24
+
25
+ ## Installation
26
+
27
+ ```bash
28
+ # Clone the repository
29
+ git clone https://github.com/HatPdotS/TorchRef.git
30
+ cd TorchRef
31
+
32
+ # Install with pip
33
+ pip install -e .
34
+
35
+ # Or install with development dependencies
36
+ pip install -e ".[dev]"
37
+ ```
38
+
39
+ ### Dependencies
40
+
41
+ - Python ≥ 3.9
42
+ - PyTorch ≥ 2.0
43
+ - NumPy ≥ 1.24
44
+ - Gemmi ≥ 0.5
45
+ - reciprocalspaceship ≥ 0.9.18
46
+ - SciPy ≥ 1.10
47
+
48
+ ## Getting Started
49
+
50
+ For demonstrations and usage examples, see the example notebooks in [`example_notebooks/`](example_notebooks/):
51
+
52
+ - [`basic_usage.ipynb`](example_notebooks/basic_usage.ipynb) - Getting started tutorial
53
+ - [`code_examples.ipynb`](example_notebooks/code_examples.ipynb) - Code examples and patterns
54
+ - [`target_exploration.ipynb`](example_notebooks/target_exploration.ipynb) - Exploring refinement targets
55
+
56
+ ## Testing
57
+
58
+ ```bash
59
+ # Run all tests
60
+ pytest tests/
61
+
62
+ # Run with coverage
63
+ pytest tests/ --cov=torchref
64
+
65
+ # Run specific test categories
66
+ pytest tests/unit/ # Fast unit tests
67
+ pytest tests/integration/ # Integration tests
68
+ pytest tests/functional/ # Full workflow tests
69
+ ```
70
+
71
+ ## Contributing
72
+
73
+ Contributions are welcome! Please follow these guidelines:
74
+
75
+ 1. Follow the [NumPy docstring style](https://numpydoc.readthedocs.io/en/latest/format.html)
76
+ 2. Add tests for new functionality
77
+ 3. Ensure all tests pass before submitting
78
+
79
+ ## License
80
+
81
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
82
+
83
+
84
+
85
+
torchref-0.3.0/pyproject.toml ADDED
@@ -0,0 +1,86 @@
1
+ [build-system]
2
+ requires = ["setuptools>=42", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "torchref"
7
+ version = "0.3.0"
8
+ description = "Tools for multicopy refinement of crystallographic models"
9
+ readme = "README.md"
10
+ requires-python = ">=3.9"
11
+ license = "MIT"
12
+ authors = [
13
+ {name = "HansPeterSeidel"},
14
+ ]
15
+ # Version requirements validated via tox compatibility testing (2026-01-13)
16
+ # Tested combinations:
17
+ # - Python 3.9: numpy 1.24-1.26, pandas 2.0.x, torch 2.0-2.4, numba 0.59-0.60, scipy 1.10-1.13
18
+ # - Python 3.11: numpy 1.24-2.3, pandas 2.0-2.3, torch 2.0-2.9, numba 0.59-0.63, scipy 1.10-1.17
19
+ # - Python 3.12: numpy 2.0-2.3, pandas 2.2-2.3, torch 2.4-2.9, numba 0.61-0.63, scipy 1.13-1.17
20
+ # Upper bounds set to next minor version above tested maximum
21
+ dependencies = [
22
+ "numpy>=1.24.0,<2.4.0",
23
+ "pandas>=2.0.0,<2.4.0",
24
+ "torch>=2.0.0,<2.10.0",
25
+ "tqdm>=4.61.0,<4.68.0",
26
+ "numba>=0.59.0,<0.64.0",
27
+ "gemmi>=0.5.0,<0.8.0",
28
+ "scipy>=1.10.0,<1.18.0",
29
+ "matplotlib>=3.7.0,<3.11.0",
30
+ "reciprocalspaceship>=0.9.18,<1.1.0",
31
+ "pyarrow>=12.0.0,<23.0.0",
32
+ ]
33
+
34
+ [project.scripts]
35
+ torchref-refine = "torchref.cli.refine:main"
36
+ torchref-refine-screened = "torchref.cli.refine_screened:main"
37
+ torchref-refine-everything = "torchref.cli.refine_everything:main"
38
+ torchref-refine-random-weights = "torchref.cli.refine_everything_random_weights:main"
39
+ torchref-refine-policy = 'torchref.cli.refine_everything_policy:main'
40
+ torchref-refine-static = 'torchref.cli.refine_everything_static:main'
41
+ torchref-refine-hyperparameters = 'torchref.cli.refine_everything_hyperparameters:main'
42
+ torchref-download-data = 'torchref.cli.download_data:main'
43
+
44
+ [project.optional-dependencies]
45
+
46
+ dev = [
47
+ "pytest>=6.0.0",
48
+ "pytest-cov>=2.12.0",
49
+ "black>=21.5b2",
50
+ "isort>=5.9.0",
51
+ "flake8>=3.9.0",
52
+ ]
53
+
54
+ # Optional CCTBX/iotbx functionality (requires conda, not pip)
55
+ # Install with: conda install -c conda-forge cctbx-base
56
+ # The iotbx module is used in torchref/math_functions/CCTBX_related.py
57
+
58
+ docs = [
59
+ "sphinx>=4.0.0",
60
+ "sphinx-rtd-theme>=1.0.0",
61
+ "numpydoc>=1.1.0",
62
+ ]
63
+
64
+ [tool.setuptools]
65
+ packages = ["torchref"]
66
+
67
+ [tool.setuptools.package-data]
68
+ torchref = ["*.py"]
69
+
70
+ [tool.setuptools.exclude-package-data]
71
+ torchref = ["*.py"]
72
+
73
+ [tool.black]
74
+ line-length = 88
75
+ target-version = ["py39"]
76
+
77
+ [tool.isort]
78
+ profile = "black"
79
+ line_length = 88
80
+
81
+ [tool.pytest.ini_options]
82
+ testpaths = ["tests"]
83
+ python_files = "test_*.py"
84
+
85
+ [tool.ruff.lint]
86
+ ignore = ["F841", "E741", "E402", "E722"]
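The pyproject comment states that each upper bound is the next minor version above the highest tested release. A minimal sketch of that rule follows; `next_minor_upper_bound` is a hypothetical helper, not part of torchref, and the inputs are the tested maxima from the Python 3.11 row of the comment.

```python
def next_minor_upper_bound(tested_max: str) -> str:
    """Return an exclusive bound one minor version above `tested_max`.

    Hypothetical helper (not part of torchref): keeps the major version and
    increments the minor version as an integer.
    """
    major, minor = (int(part) for part in tested_max.split(".")[:2])
    return f"<{major}.{minor + 1}.0"


# Tested maxima taken from the Python 3.11 row of the pyproject comment.
tested_maxima = {
    "numpy": "2.3",
    "pandas": "2.3",
    "torch": "2.9",
    "numba": "0.63",
    "scipy": "1.17",
}

for name, version in tested_maxima.items():
    print(name + next_minor_upper_bound(version))
```

Treating the minor component as an integer rather than a decimal fraction is what yields `<2.10.0` for torch instead of `<3.0`.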
torchref-0.3.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
torchref-0.3.0/torchref/__init__.py ADDED
@@ -0,0 +1,134 @@
1
+ """
2
+ TorchRef - A PyTorch-based crystallographic refinement library.
3
+
4
+ TorchRef provides GPU-accelerated crystallographic structure refinement
5
+ using PyTorch's automatic differentiation and nn.Module architecture.
6
+
7
+ Key Features
8
+ ------------
9
+ - Native PyTorch integration with nn.Module architecture
10
+ - Automatic differentiation for custom target functions
11
+ - GPU acceleration for structure factor calculations
12
+ - Modular design for easy extension
13
+
14
+ Quick Start
15
+ -----------
16
+ ::
17
+
18
+ from torchref import Refinement, ReflectionData, Model
19
+
20
+ # Load data and model
21
+ data = ReflectionData().load_mtz('data.mtz')
22
+ model = Model().load_pdb('structure.pdb')
23
+
24
+ # Run refinement
25
+ refinement = Refinement(data=data, model=model, device='cuda')
26
+ refinement.run_refinement(macro_cycles=10)
27
+
28
+ Modules
29
+ -------
30
+ io
31
+ File I/O for MTZ, PDB, CIF formats.
32
+ model
33
+ Atomic structure models (coordinates, B-factors, occupancies).
34
+ refinement
35
+ Core refinement framework with targets and weighting schemes.
36
+ restraints
37
+ Geometry restraints (bonds, angles, torsions, planes).
38
+ scaling
39
+ Structure factor scaling and bulk solvent models.
40
+ symmetry
41
+ Crystallographic symmetry operations.
42
+ alignment
43
+ Patterson-based structure alignment.
44
+ math_functions
45
+ Mathematical utilities for crystallography.
46
+ utils
47
+ General utilities and debugging tools.
48
+ """
49
+
50
+ __version__ = "0.3.0"
51
+
52
+
53
+ import os
54
+ import warnings
55
+ from pathlib import Path
56
+
57
+ from torchref._bootstrap import configure_threading, detect_available_cpus
58
+
59
+ # Configure threading before importing torch
60
+ if "TORCHREF_NUM_THREADS" in os.environ:
61
+ N_CPUS = int(os.environ["TORCHREF_NUM_THREADS"])
62
+ warnings.warn(
63
+ f"TorchRef using user-specified {N_CPUS} threads from TORCHREF_NUM_THREADS.",
64
+ stacklevel=2,
65
+ )
66
+ else:
67
+ N_CPUS = detect_available_cpus()
68
+ os.environ["TORCHREF_NUM_THREADS"] = str(N_CPUS)
69
+ warnings.warn(
70
+ f"TorchRef auto-configured {N_CPUS} threads. Set TORCHREF_NUM_THREADS to override.",
71
+ stacklevel=2,
72
+ )
73
+
74
+ configure_threading(N_CPUS)
75
+
76
+ import torch
77
+
78
+ torch.set_num_threads(N_CPUS)
79
+
80
+ # Dtype configuration (must be imported after torch)
81
+ from torchref.config import dtypes
82
+
83
+
84
+ # Project root path for referencing package files
85
+ ROOT_TORCHREF = Path(__file__).parent.parent.resolve()
86
+
87
+ # Package path for referencing internal files
88
+ PATH_TORCHREF = Path(__file__).parent.resolve()
89
+
90
+ PATH_TORCHREF_DATA = PATH_TORCHREF / "data"
91
+
92
+ # =============================================================================
93
+ # Convenience imports for common classes
94
+ # =============================================================================
95
+
96
+
97
+ # Data I/O
98
+ from torchref.io import DatasetCollection, ReflectionData
99
+
100
+ # Model
101
+ from torchref.model import Model, ModelFT
102
+
103
+ # Refinement
104
+ from torchref.refinement import LBFGSRefinement, Refinement
105
+
106
+ # Restraints
107
+ from torchref.restraints import Restraints
108
+
109
+ # Scaling
110
+ from torchref.scaling import Scaler, SolventModel
111
+
112
+ __all__ = [
113
+ # Version and paths
114
+ "__version__",
115
+ "ROOT_TORCHREF",
116
+ "PATH_TORCHREF",
117
+ "N_CPUS",
118
+ # Dtype configuration
119
+ "dtypes",
120
+ # Data I/O
121
+ "ReflectionData",
122
+ "DatasetCollection",
123
+ # Model
124
+ "Model",
125
+ "ModelFT",
126
+ # Refinement
127
+ "Refinement",
128
+ "LBFGSRefinement",
129
+ # Restraints
130
+ "Restraints",
131
+ # Scaling
132
+ "Scaler",
133
+ "SolventModel",
134
+ ]
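At import time, `torchref/__init__.py` gives an explicit `TORCHREF_NUM_THREADS` precedence over auto-detection, writes the auto-detected value back to the environment, and then passes the count to `configure_threading()` and `torch.set_num_threads()`. The sketch below reproduces only the environment-variable part of that precedence. `resolve_thread_count` is a hypothetical helper, and plain dicts stand in for `os.environ` so the logic can be checked without importing torch.

```python
from typing import Callable, MutableMapping

# Variables that configure_threading() sets to the chosen thread count.
_POOL_VARS = ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS")


def resolve_thread_count(
    environ: MutableMapping[str, str], detect: Callable[[], int]
) -> int:
    """Sketch of the thread-count precedence in torchref/__init__.py."""
    if "TORCHREF_NUM_THREADS" in environ:
        # An explicit user setting always wins over auto-detection.
        n_threads = int(environ["TORCHREF_NUM_THREADS"])
    else:
        # Otherwise detect, then record the choice in the environment.
        n_threads = detect()
        environ["TORCHREF_NUM_THREADS"] = str(n_threads)
    # Mirror configure_threading(): export the count to OpenMP/BLAS pools.
    for var in _POOL_VARS:
        environ[var] = str(n_threads)
    return n_threads


auto_env: dict = {}
print(resolve_thread_count(auto_env, detect=lambda: 4), auto_env)

user_env = {"TORCHREF_NUM_THREADS": "16"}
print(resolve_thread_count(user_env, detect=lambda: 4), user_env["MKL_NUM_THREADS"])
```

In a real session, the equivalent override is to export `TORCHREF_NUM_THREADS` in the shell, or to set it in `os.environ`, before the first `import torchref`.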
torchref-0.3.0/torchref/_bootstrap.py ADDED
@@ -0,0 +1,49 @@
1
+ """
2
+
3
+ Bootstrap module to configure threading before importing heavy libraries.
4
+
5
+ It is imported automatically when torchref is imported.
6
+ To override the auto-detected thread count (e.g. outside SLURM), set
7
+ the TORCHREF_NUM_THREADS environment variable before importing torchref.
8
+
9
+ """
10
+
11
+ import os
12
+
13
+
14
+ def detect_available_cpus(max_if_not_slurm: int = 4) -> int:
15
+ """Detect usable CPUs via SLURM, then CPU affinity, then os.cpu_count()."""
16
+
17
+ # 1. Check SLURM first (most reliable on HPC)
18
+ slurm_cpus = os.environ.get("SLURM_CPUS_PER_TASK")
19
+ if slurm_cpus:
20
+ return int(slurm_cpus)
21
+
22
+ # 2. Check CPU affinity, capped at max_if_not_slurm
23
+ try:
24
+ return min(len(os.sched_getaffinity(0)), max_if_not_slurm)
25
+ except (AttributeError, OSError):
26
+ pass
27
+
28
+ # 3. Fall back to os.cpu_count(), with the same cap
29
+ return min(os.cpu_count() or 1, max_if_not_slurm)
30
+
31
+
32
+ def configure_threading(num_threads: int = None, pin_threads=False) -> int:
33
+ """Configure all threading libraries. Call BEFORE importing torch/numpy."""
34
+
35
+ if num_threads is None:
36
+ num_threads = detect_available_cpus()
37
+
38
+ n = str(num_threads)
39
+
40
+ os.environ["OMP_NUM_THREADS"] = n
41
+ os.environ["MKL_NUM_THREADS"] = n
42
+ os.environ["OPENBLAS_NUM_THREADS"] = n
43
+
44
+ if pin_threads:
45
+ # Optional: enable thread pinning
46
+ os.environ["OMP_PROC_BIND"] = "TRUE"
47
+ os.environ["OMP_PLACES"] = "cores"
48
+
49
+ return num_threads
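The fallback order in `detect_available_cpus` (SLURM allocation, then CPU affinity, then `os.cpu_count()`) can be exercised without a cluster if its inputs are injectable. `detect_cpus_sketch` is a hypothetical variant, not part of torchref. Passing the environment, the affinity query, and the CPU-count function as parameters is an assumption made purely for testability, and the cap applies only to the two non-SLURM branches.

```python
import os
from typing import Callable, Mapping, Optional, Set

AffinityFn = Callable[[int], Set[int]]


def detect_cpus_sketch(
    environ: Mapping[str, str] = os.environ,
    affinity: Optional[AffinityFn] = getattr(os, "sched_getaffinity", None),
    cpu_count: Callable[[], Optional[int]] = os.cpu_count,
    max_if_not_slurm: int = 4,
) -> int:
    """Hypothetical, injectable version of the SLURM -> affinity -> cpu_count chain."""
    # 1. A SLURM allocation is used as-is (no cap).
    slurm_cpus = environ.get("SLURM_CPUS_PER_TASK")
    if slurm_cpus:
        return int(slurm_cpus)
    # 2. CPU affinity (not available on every platform), capped.
    if affinity is not None:
        try:
            return min(len(affinity(0)), max_if_not_slurm)
        except OSError:
            pass
    # 3. Logical CPU count, capped; os.cpu_count() may return None.
    return min(cpu_count() or 1, max_if_not_slurm)


print(detect_cpus_sketch(environ={}, max_if_not_slurm=8))
```

On a workstation, a larger `max_if_not_slurm` raises the ceiling. Under SLURM, the allocation in `SLURM_CPUS_PER_TASK` is returned unchanged.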
torchref-0.3.0/torchref/config.py ADDED
@@ -0,0 +1,147 @@
1
+ """
2
+ Centralized configuration for TorchRef.
3
+
4
+ Default dtypes can be set via environment variables at import time:
5
+ - TORCHREF_DTYPE_FLOAT: float32 (default) or float64
6
+ - TORCHREF_DTYPE_INT: int32 (default) or int64
7
+ - TORCHREF_DTYPE_COMPLEX: complex64 (default) or complex128
8
+
9
+ Users can also change dtypes at runtime via attribute assignment:
10
+ import torchref
11
+ torchref.dtypes.float = torch.float64
12
+ torchref.dtypes.int = torch.int64
13
+ torchref.dtypes.complex = torch.complex128
14
+
15
+ Or read current values:
16
+ torchref.dtypes.float # torch.float32
17
+ torchref.dtypes.int # torch.int32
18
+ torchref.dtypes.complex # torch.complex64
19
+ """
20
+
21
+ import os
22
+
23
+ import torch
24
+
25
+ # Map strings to torch dtypes
26
+ _FLOAT_DTYPE_MAP = {
27
+ "float32": torch.float32,
28
+ "float64": torch.float64,
29
+ }
30
+
31
+ _INT_DTYPE_MAP = {
32
+ "int32": torch.int32,
33
+ "int64": torch.int64,
34
+ }
35
+
36
+ _COMPLEX_DTYPE_MAP = {
37
+ "complex64": torch.complex64,
38
+ "complex128": torch.complex128,
39
+ }
40
+
41
+
42
+ class DtypeConfig:
43
+ """
44
+ Dtype configuration with property-based access.
45
+
46
+ Access dtypes as attributes:
47
+ dtypes.float # get current float dtype
48
+ dtypes.int # get current int dtype
49
+ dtypes.complex # get current complex dtype
50
+
51
+ Set dtypes via assignment:
52
+ dtypes.float = torch.float64
53
+ dtypes.int = torch.int64
54
+ dtypes.complex = torch.complex128
55
+ """
56
+
57
+ def __init__(self):
58
+ # Parse environment variables with defaults
59
+ float_str = os.environ.get("TORCHREF_DTYPE_FLOAT", "float32").lower()
60
+ int_str = os.environ.get("TORCHREF_DTYPE_INT", "int32").lower()
61
+ complex_str = os.environ.get("TORCHREF_DTYPE_COMPLEX", "complex64").lower()
62
+
63
+ # Validate and set
64
+ if float_str not in _FLOAT_DTYPE_MAP:
65
+ raise ValueError(
66
+ f"Invalid TORCHREF_DTYPE_FLOAT: {float_str}. "
67
+ f"Valid values: {list(_FLOAT_DTYPE_MAP.keys())}"
68
+ )
69
+ if int_str not in _INT_DTYPE_MAP:
70
+ raise ValueError(
71
+ f"Invalid TORCHREF_DTYPE_INT: {int_str}. "
72
+ f"Valid values: {list(_INT_DTYPE_MAP.keys())}"
73
+ )
74
+ if complex_str not in _COMPLEX_DTYPE_MAP:
75
+ raise ValueError(
76
+ f"Invalid TORCHREF_DTYPE_COMPLEX: {complex_str}. "
77
+ f"Valid values: {list(_COMPLEX_DTYPE_MAP.keys())}"
78
+ )
79
+
80
+ self._float = _FLOAT_DTYPE_MAP[float_str]
81
+ self._int = _INT_DTYPE_MAP[int_str]
82
+ self._complex = _COMPLEX_DTYPE_MAP[complex_str]
83
+
84
+ @property
85
+ def float(self) -> torch.dtype:
86
+ """Get the current default float dtype."""
87
+ return self._float
88
+
89
+ @float.setter
90
+ def float(self, dtype: torch.dtype) -> None:
91
+ """Set the default float dtype for all future operations."""
92
+ if dtype not in (torch.float32, torch.float64):
93
+ raise ValueError(
94
+ f"Invalid float dtype: {dtype}. Use torch.float32 or torch.float64."
95
+ )
96
+ self._float = dtype
97
+
98
+ @property
99
+ def int(self) -> torch.dtype:
100
+ """Get the current default int dtype."""
101
+ return self._int
102
+
103
+ @int.setter
104
+ def int(self, dtype: torch.dtype) -> None:
105
+ """Set the default int dtype for all future operations."""
106
+ if dtype not in (torch.int32, torch.int64):
107
+ raise ValueError(
108
+ f"Invalid int dtype: {dtype}. Use torch.int32 or torch.int64."
109
+ )
110
+ self._int = dtype
111
+
112
+ @property
113
+ def complex(self) -> torch.dtype:
114
+ """Get the current default complex dtype."""
115
+ return self._complex
116
+
117
+ @complex.setter
118
+ def complex(self, dtype: torch.dtype) -> None:
119
+ """Set the default complex dtype for all future operations."""
120
+ if dtype not in (torch.complex64, torch.complex128):
121
+ raise ValueError(
122
+ f"Invalid complex dtype: {dtype}. Use torch.complex64 or torch.complex128."
123
+ )
124
+ self._complex = dtype
125
+
126
+ def __repr__(self) -> str:
127
+ return f"DtypeConfig(float={self._float}, int={self._int}, complex={self._complex})"
128
+
129
+
130
+ # Global singleton instance
131
+ dtypes = DtypeConfig()
132
+
133
+
134
+ # Convenience functions for internal use (avoid repeated attribute lookups)
135
+ def get_float_dtype() -> torch.dtype:
136
+ """Get the current default float dtype."""
137
+ return dtypes.float
138
+
139
+
140
+ def get_int_dtype() -> torch.dtype:
141
+ """Get the current default int dtype."""
142
+ return dtypes.int
143
+
144
+
145
+ def get_complex_dtype() -> torch.dtype:
146
+ """Get the current default complex dtype."""
147
+ return dtypes.complex
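`dtypes = DtypeConfig()` runs when `torchref.config` is first imported, so the `TORCHREF_DTYPE_*` variables are read only at that moment. After import, changes go through attribute assignment such as `torchref.dtypes.float = torch.float64`. The stdlib-only sketch below mirrors the import-time resolution. `resolve_dtype_names` is a hypothetical helper, and it returns dtype names rather than `torch.dtype` objects so it runs without PyTorch.

```python
import os
from typing import Dict, Mapping, Tuple

# Variable -> (default, valid names), mirroring the maps in config.py.
# Strings stand in for torch dtypes so the sketch runs without PyTorch.
_SETTINGS: Dict[str, Tuple[str, Tuple[str, ...]]] = {
    "TORCHREF_DTYPE_FLOAT": ("float32", ("float32", "float64")),
    "TORCHREF_DTYPE_INT": ("int32", ("int32", "int64")),
    "TORCHREF_DTYPE_COMPLEX": ("complex64", ("complex64", "complex128")),
}


def resolve_dtype_names(environ: Mapping[str, str] = os.environ) -> Dict[str, str]:
    """Read and validate the dtype variables once, as DtypeConfig.__init__ does."""
    resolved = {}
    for var, (default, valid) in _SETTINGS.items():
        # Values are lower-cased first, so "Float64" is accepted like config.py.
        value = environ.get(var, default).lower()
        if value not in valid:
            raise ValueError(f"Invalid {var}: {value}. Valid values: {list(valid)}")
        resolved[var] = value
    return resolved


print(resolve_dtype_names({"TORCHREF_DTYPE_FLOAT": "float64"}))
```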
torchref-0.3.0/torchref.egg-info/PKG-INFO ADDED
@@ -0,0 +1,116 @@
1
+ Metadata-Version: 2.4
2
+ Name: torchref
3
+ Version: 0.3.0
4
+ Summary: Tools for multicopy refinement of crystallographic models
5
+ Author: HansPeterSeidel
6
+ License-Expression: MIT
7
+ Requires-Python: >=3.9
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: numpy<2.4.0,>=1.24.0
11
+ Requires-Dist: pandas<2.4.0,>=2.0.0
12
+ Requires-Dist: torch<2.10.0,>=2.0.0
13
+ Requires-Dist: tqdm<4.68.0,>=4.61.0
14
+ Requires-Dist: numba<0.64.0,>=0.59.0
15
+ Requires-Dist: gemmi<0.8.0,>=0.5.0
16
+ Requires-Dist: scipy<1.18.0,>=1.10.0
17
+ Requires-Dist: matplotlib<3.11.0,>=3.7.0
18
+ Requires-Dist: reciprocalspaceship<1.1.0,>=0.9.18
19
+ Requires-Dist: pyarrow<23.0.0,>=12.0.0
20
+ Provides-Extra: dev
21
+ Requires-Dist: pytest>=6.0.0; extra == "dev"
22
+ Requires-Dist: pytest-cov>=2.12.0; extra == "dev"
23
+ Requires-Dist: black>=21.5b2; extra == "dev"
24
+ Requires-Dist: isort>=5.9.0; extra == "dev"
25
+ Requires-Dist: flake8>=3.9.0; extra == "dev"
26
+ Provides-Extra: docs
27
+ Requires-Dist: sphinx>=4.0.0; extra == "docs"
28
+ Requires-Dist: sphinx-rtd-theme>=1.0.0; extra == "docs"
29
+ Requires-Dist: numpydoc>=1.1.0; extra == "docs"
30
+ Dynamic: license-file
31
+
32
+ # TorchRef
33
+
34
+ **A PyTorch-based crystallographic refinement library**
35
+
36
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
37
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-ee4c2c.svg)](https://pytorch.org/)
38
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
39
+
40
+ TorchRef is a crystallographic refinement package built on PyTorch. Because it relies on PyTorch's automatic differentiation and GPU support, TorchRef integrates directly with machine learning workflows and provides a flexible, extensible framework for crystallographic structure refinement.
41
+
42
+ ## Key Features
43
+
44
+ - **Native PyTorch Integration**: Built on PyTorch's `nn.Module` architecture, TorchRef integrates naturally with the PyTorch ecosystem, including machine learning models, optimizers, and GPU acceleration.
45
+
46
+ - **Automatic Differentiation**: Dynamic computational graphs eliminate the need for manually implemented gradient calculations. Define new refinement targets directly—PyTorch handles the derivatives automatically.
47
+
48
+ - **Modular Architecture**: Following PyTorch's module pattern, components are easily composable and extensible. Add custom targets, restraints, or optimizers without modifying core code.
49
+
50
+ - **GPU Acceleration**: Run structure factor calculations, scaling, and optimization on CUDA devices, which can substantially speed up refinement of large structures.
51
+
52
+ - **FFT-based Structure Factors**: Efficient structure factor calculation using Fast Fourier Transform (FFT) methods, enabling rapid F_calc computation even for large unit cells.
53
+
54
+ - **State Management**: Full `state_dict` support enables saving and loading complete refinement states, including model parameters, scaler settings, and restraints.
55
+
56
+ ## Installation
57
+
58
+ ```bash
59
+ # Clone the repository
60
+ git clone https://github.com/HatPdotS/TorchRef.git
61
+ cd TorchRef
62
+
63
+ # Install with pip
64
+ pip install -e .
65
+
66
+ # Or install with development dependencies
67
+ pip install -e ".[dev]"
68
+ ```
69
+
70
+ ### Dependencies
71
+
72
+ - Python ≥ 3.9
73
+ - PyTorch ≥ 2.0
74
+ - NumPy ≥ 1.24
75
+ - Gemmi ≥ 0.5
76
+ - reciprocalspaceship ≥ 0.9.18
77
+ - SciPy ≥ 1.10
78
+
79
+ ## Getting Started
80
+
81
+ For demonstrations and usage examples, see the example notebooks in [`example_notebooks/`](example_notebooks/):
82
+
83
+ - [`basic_usage.ipynb`](example_notebooks/basic_usage.ipynb) - Getting started tutorial
84
+ - [`code_examples.ipynb`](example_notebooks/code_examples.ipynb) - Code examples and patterns
85
+ - [`target_exploration.ipynb`](example_notebooks/target_exploration.ipynb) - Exploring refinement targets
86
+
87
+ ## Testing
88
+
89
+ ```bash
90
+ # Run all tests
91
+ pytest tests/
92
+
93
+ # Run with coverage
94
+ pytest tests/ --cov=torchref
95
+
96
+ # Run specific test categories
97
+ pytest tests/unit/ # Fast unit tests
98
+ pytest tests/integration/ # Integration tests
99
+ pytest tests/functional/ # Full workflow tests
100
+ ```
101
+
102
+ ## Contributing
103
+
104
+ Contributions are welcome! Please follow these guidelines:
105
+
106
+ 1. Follow the [NumPy docstring style](https://numpydoc.readthedocs.io/en/latest/format.html)
107
+ 2. Add tests for new functionality
108
+ 3. Ensure all tests pass before submitting
109
+
110
+ ## License
111
+
112
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
113
+
114
+
115
+
116
+
torchref-0.3.0/torchref.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,12 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ torchref/__init__.py
5
+ torchref/_bootstrap.py
6
+ torchref/config.py
7
+ torchref.egg-info/PKG-INFO
8
+ torchref.egg-info/SOURCES.txt
9
+ torchref.egg-info/dependency_links.txt
10
+ torchref.egg-info/entry_points.txt
11
+ torchref.egg-info/requires.txt
12
+ torchref.egg-info/top_level.txt
torchref-0.3.0/torchref.egg-info/entry_points.txt ADDED
@@ -0,0 +1,9 @@
1
+ [console_scripts]
2
+ torchref-download-data = torchref.cli.download_data:main
3
+ torchref-refine = torchref.cli.refine:main
4
+ torchref-refine-everything = torchref.cli.refine_everything:main
5
+ torchref-refine-hyperparameters = torchref.cli.refine_everything_hyperparameters:main
6
+ torchref-refine-policy = torchref.cli.refine_everything_policy:main
7
+ torchref-refine-random-weights = torchref.cli.refine_everything_random_weights:main
8
+ torchref-refine-screened = torchref.cli.refine_screened:main
9
+ torchref-refine-static = torchref.cli.refine_everything_static:main
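Each line in the `[console_scripts]` listing above has the form `name = module:function`. For example, installing the package creates a `torchref-refine` command that imports `torchref.cli.refine` and calls its `main()`. A minimal stdlib sketch that reads this INI-style format follows; `parse_console_scripts` is a hypothetical helper, and the sample text is two lines copied from the listing.

```python
import configparser

# Two entries copied from the console_scripts listing above.
ENTRY_POINTS_TXT = """\
[console_scripts]
torchref-refine = torchref.cli.refine:main
torchref-download-data = torchref.cli.download_data:main
"""


def parse_console_scripts(text: str) -> dict:
    """Map each console-script name to its (module, function) target."""
    # Only "=" separates name from target; the ":" belongs to the target.
    parser = configparser.ConfigParser(delimiters=("=",))
    parser.optionxform = str  # keep script names exactly as written
    parser.read_string(text)
    scripts = {}
    for name, target in parser.items("console_scripts"):
        module, _, func = target.partition(":")
        scripts[name] = (module.strip(), func.strip())
    return scripts


print(parse_console_scripts(ENTRY_POINTS_TXT)["torchref-refine"])
```

For installed distributions, the standard library's `importlib.metadata.entry_points()` reads the same file, so hand parsing is only needed when inspecting an unpacked sdist.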
torchref-0.3.0/torchref.egg-info/requires.txt ADDED
@@ -0,0 +1,22 @@
1
+ numpy<2.4.0,>=1.24.0
2
+ pandas<2.4.0,>=2.0.0
3
+ torch<2.10.0,>=2.0.0
4
+ tqdm<4.68.0,>=4.61.0
5
+ numba<0.64.0,>=0.59.0
6
+ gemmi<0.8.0,>=0.5.0
7
+ scipy<1.18.0,>=1.10.0
8
+ matplotlib<3.11.0,>=3.7.0
9
+ reciprocalspaceship<1.1.0,>=0.9.18
10
+ pyarrow<23.0.0,>=12.0.0
11
+
12
+ [dev]
13
+ pytest>=6.0.0
14
+ pytest-cov>=2.12.0
15
+ black>=21.5b2
16
+ isort>=5.9.0
17
+ flake8>=3.9.0
18
+
19
+ [docs]
20
+ sphinx>=4.0.0
21
+ sphinx-rtd-theme>=1.0.0
22
+ numpydoc>=1.1.0
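In the requirements listing above, lines before the first section header are always installed. The `[dev]` and `[docs]` sections are pulled in only when the extra is requested, for example with `pip install -e ".[dev]"` as in the README. The sketch below groups the lines by extra. `parse_requires` is a hypothetical helper, the sample is an abridged copy of the listing, and it does not handle the `extra:marker` section form that setuptools may emit for conditional requirements.

```python
from typing import Dict, List

# Abridged copy of the requirements listing above.
REQUIRES_TXT = """\
numpy<2.4.0,>=1.24.0
torch<2.10.0,>=2.0.0

[dev]
pytest>=6.0.0
black>=21.5b2

[docs]
sphinx>=4.0.0
"""


def parse_requires(text: str) -> Dict[str, List[str]]:
    """Group requirement lines by extra; the "" key holds unconditional ones."""
    groups: Dict[str, List[str]] = {"": []}
    current = ""
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue  # blank lines only separate sections
        if line.startswith("[") and line.endswith("]"):
            current = line[1:-1]  # e.g. "dev" or "docs"
            groups.setdefault(current, [])
        else:
            groups[current].append(line)
    return groups


print(parse_requires(REQUIRES_TXT))
```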
torchref-0.3.0/torchref.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ torchref