chempleter 0.1.0b1__tar.gz

@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Davis Thomas Daniel
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,157 @@
1
+ Metadata-Version: 2.4
2
+ Name: chempleter
3
+ Version: 0.1.0b1
4
+ Summary: A lightweight generative model that extends SMILES fragments into syntactically valid molecules
5
+ License-Expression: MIT
6
+ License-File: LICENSE
7
+ Classifier: License :: OSI Approved :: MIT License
8
+ Requires-Dist: nicegui>=3.4.1
9
+ Requires-Dist: pandas>=2.3.3
10
+ Requires-Dist: pathlib>=1.0.1
11
+ Requires-Dist: rdkit>=2025.9.3
12
+ Requires-Dist: selfies>=2.2.0
13
+ Requires-Dist: torch>=2.9.1
14
+ Requires-Dist: torch>=2.9.1 ; extra == 'cpu'
15
+ Requires-Dist: torch>=2.9.1 ; extra == 'gpu128'
16
+ Requires-Python: >=3.13
17
+ Provides-Extra: cpu
18
+ Provides-Extra: gpu128
19
+ Description-Content-Type: text/markdown
20
+
21
+ # Chempleter
22
+
23
+ Chempleter is a lightweight generative model which utilises a simple Gated Recurrent Unit (GRU) to predict syntactically valid extensions of a provided molecular fragment.
24
+ It accepts SMILES notation as input and enforces chemical syntax validity using SELFIES for the generated molecules.
25
+
26
+ <div align="center">
27
+ <img src="https://raw.githubusercontent.com/davistdaniel/chempleter/main/screenshots/demo.gif" alt="Demo Gif" width="400">
28
+ </div>
29
+
30
+
31
+ * Why was Chempleter made?
32
+ * Mainly for me to get into PyTorch. Also, I find it fun to generate random, possibly unsynthesisable molecules from a starting structure.
33
+
34
+ * What can Chempleter do?
35
+
36
+ * Currently, Chempleter accepts an initial molecule/molecular fragment in SMILES format and generates a larger molecule with that initial structure included, while respecting chemical syntax.
37
+
38
+ * It can be used to generate a wide range of structural analogs which share the same core structure (by changing the sampling temperature) or to decorate a core scaffold iteratively (by increasing generated token lengths).
39
+
40
+ * In the future, it might be adapted to predict structures with a specific chemical property using a regressor to rank predictions and transition towards more "goal-directed" predictions.
41
+
42
+
43
+ ## Prerequisites
44
+ * Python ">=3.13"
45
+ * See [pyproject.toml](pyproject.toml) for dependencies.
46
+ * [uv](https://docs.astral.sh/uv/) (optional but recommended)
47
+
48
+ ## Get started
49
+
50
+
51
+ You can install chempleter using any one of the following ways:
52
+
53
+ - #### Install from PyPI
54
+
55
+ ``python -m pip install chempleter``
56
+
57
+ By default, the CPU version of PyTorch will be installed. Alternatively, you can install a PyTorch version compatible with your CUDA version by following the [PyTorch documentation](https://pytorch.org/get-started/locally/).
58
+
59
+ - #### Install using uv
60
+
61
+ 1. Clone this repo
62
+
63
+ ``git clone https://github.com/davistdaniel/chempleter.git``
64
+
65
+ 2. Inside the project directory, execute in a terminal:
66
+
67
+ ``uv sync``
68
+
69
+ By default, the CPU version of PyTorch will be installed. To use a GPU as the accelerator with CUDA 12.8, run:
70
+
71
+ ``uv sync --extra gpu128``
72
+
73
+ Alternatively, you can install a PyTorch version compatible with your CUDA version by following the [PyTorch documentation](https://pytorch.org/get-started/locally/).
74
+
75
+
76
+
77
+
78
+ ### Usage
79
+
80
+ #### GUI
81
+ * To start the Chempleter GUI:
82
+
83
+ ``chempleter-gui``
84
+
85
+ or
86
+
87
+ ``uv run src/chempleter/gui.py``
88
+
89
+
90
+ * Type in the SMILES notation for the starting structure or leave it empty to generate random molecules. Click the ``GENERATE`` button to generate a molecule.
91
+ * Options:
92
+ * Temperature : Higher values produce more unusual molecules, while lower values produce more common structures.
93
+ * Sampling : `Most probable` selects the molecule with the highest likelihood for the given starting structure, producing the same result on repeated generations. `Random` generates a new molecule each time, while still including the input structure.
94
+
95
+
96
+ #### As a python library
97
+
98
+ * To use Chempleter as a python library:
99
+
100
+ ```python
101
+ from chempleter.inference import extend
102
+ generated_mol, generated_smiles, generated_selfies = extend(smiles="c1ccccc1")
103
+ print(generated_smiles)
104
+ # >> C1=CC=CC=C1C2=CC=C(CN3C=NC4=CC=CC=C4C3=O)O2
105
+ ```
106
+
107
+ To draw the generated molecule :
108
+
109
+ ```python
110
+ from rdkit.Chem import Draw
111
+ Draw.MolToImage(generated_mol)
112
+ ```
113
+ * For details on available parameters, refer to the ``extend`` (``chempleter.inference`` module) function’s docstring.
114
+
115
+ ### Current model performance
116
+
117
+ Performance metrics were evaluated across 500 independent generations using a model checkpoint trained for 80 epochs with a batch size of 64.
118
+
119
+ | Metric | Value | Description |
120
+ |------------|-------|--------------------------------------------------------------------------------------------------------------|
121
+ | Validity | 1.0 | Proportion of generated SMILES which respect chemical syntax; tested using the selfies decoder and the RDKit parser. |
122
+ | Uniqueness | 0.96 | Proportion of generated SMILES which were unique |
123
+ | Novelty | 0.85 | Proportion of generated SMILES which were not present in the training dataset |
124
+
125
+
126
+ ### Project structure
127
+ * src/chempleter: Contains python modules relating to different functions.
128
+ * src/chempleter/processor.py: Contains functions for processing csv files containing SMILES data and generating training-related files.
129
+ * src/chempleter/dataset.py: ChempleterDataset class
130
+ * src/chempleter/model.py: ChempleterModel class
131
+ * src/chempleter/inference.py: Contains functions for inference
132
+ * src/chempleter/train.py: Contains functions for training
133
+ * src/chempleter/gui.py: Chempleter GUI built using NiceGUI
134
+ * src/chempleter/data : Contains trained model, vocabulary files
135
+
136
+ # License
137
+
138
+ [MIT](https://github.com/davistdaniel/chempleter/tree/main?tab=MIT-1-ov-file#readme) License
139
+
140
+ Copyright (c) 2025 Davis Thomas Daniel
141
+
142
+ # Contributing
143
+
144
+ Contributions, improvements, feature ideas and bug fixes are always welcome.
145
+
146
+ ## Random Notes
147
+
148
+ * Training data
149
+ * QM9 and ZINC datasets. 379997 molecules were used for training in total.
150
+ * Running without a GPU
151
+ * Chempleter uses a 2-layer GRU, so it should run comfortably on a CPU.
152
+
153
+
154
+
155
+
156
+
157
+
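Note on the performance table above: the README defines uniqueness and novelty as proportions over the generated SMILES, but the diff does not include the evaluation script. The following is a minimal sketch of how such proportions could be computed, assuming plain string comparison on already-canonicalised SMILES. The function names and the toy data are hypothetical. Validity is left out because it needs the selfies decoder and the RDKit parser.

```python
def uniqueness(generated):
    """Fraction of generated strings that are distinct."""
    return len(set(generated)) / len(generated)


def novelty(generated, training):
    """Fraction of generated strings that do not occur in the training set."""
    training_set = set(training)  # set lookup keeps the check O(1) per string
    unseen = sum(1 for smiles in generated if smiles not in training_set)
    return unseen / len(generated)


# Toy data for illustration only (not taken from the package).
generated = ["CCO", "CCN", "CCO", "c1ccccc1"]
training = ["CCO", "CCC"]

print(uniqueness(generated))         # 0.75
print(novelty(generated, training))  # 0.5
```

Some benchmarks compute novelty over the unique, valid molecules only; which variant produced the reported 0.85 is not stated in the source.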
@@ -0,0 +1,137 @@
1
+ # Chempleter
2
+
3
+ Chempleter is a lightweight generative model which utilises a simple Gated Recurrent Unit (GRU) to predict syntactically valid extensions of a provided molecular fragment.
4
+ It accepts SMILES notation as input and enforces chemical syntax validity using SELFIES for the generated molecules.
5
+
6
+ <div align="center">
7
+ <img src="https://raw.githubusercontent.com/davistdaniel/chempleter/main/screenshots/demo.gif" alt="Demo Gif" width="400">
8
+ </div>
9
+
10
+
11
+ * Why was Chempleter made?
12
+ * Mainly for me to get into PyTorch. Also, I find it fun to generate random, possibly unsynthesisable molecules from a starting structure.
13
+
14
+ * What can Chempleter do?
15
+
16
+ * Currently, Chempleter accepts an initial molecule/molecular fragment in SMILES format and generates a larger molecule with that initial structure included, while respecting chemical syntax.
17
+
18
+ * It can be used to generate a wide range of structural analogs which share the same core structure (by changing the sampling temperature) or to decorate a core scaffold iteratively (by increasing generated token lengths).
19
+
20
+ * In the future, it might be adapted to predict structures with a specific chemical property using a regressor to rank predictions and transition towards more "goal-directed" predictions.
21
+
22
+
23
+ ## Prerequisites
24
+ * Python ">=3.13"
25
+ * See [pyproject.toml](pyproject.toml) for dependencies.
26
+ * [uv](https://docs.astral.sh/uv/) (optional but recommended)
27
+
28
+ ## Get started
29
+
30
+
31
+ You can install chempleter using any one of the following ways:
32
+
33
+ - #### Install from PyPI
34
+
35
+ ``python -m pip install chempleter``
36
+
37
+ By default, the CPU version of PyTorch will be installed. Alternatively, you can install a PyTorch version compatible with your CUDA version by following the [PyTorch documentation](https://pytorch.org/get-started/locally/).
38
+
39
+ - #### Install using uv
40
+
41
+ 1. Clone this repo
42
+
43
+ ``git clone https://github.com/davistdaniel/chempleter.git``
44
+
45
+ 2. Inside the project directory, execute in a terminal:
46
+
47
+ ``uv sync``
48
+
49
+ By default, the CPU version of PyTorch will be installed. To use a GPU as the accelerator with CUDA 12.8, run:
50
+
51
+ ``uv sync --extra gpu128``
52
+
53
+ Alternatively, you can install a PyTorch version compatible with your CUDA version by following the [PyTorch documentation](https://pytorch.org/get-started/locally/).
54
+
55
+
56
+
57
+
58
+ ### Usage
59
+
60
+ #### GUI
61
+ * To start the Chempleter GUI:
62
+
63
+ ``chempleter-gui``
64
+
65
+ or
66
+
67
+ ``uv run src/chempleter/gui.py``
68
+
69
+
70
+ * Type in the SMILES notation for the starting structure or leave it empty to generate random molecules. Click the ``GENERATE`` button to generate a molecule.
71
+ * Options:
72
+ * Temperature : Higher values produce more unusual molecules, while lower values produce more common structures.
73
+ * Sampling : `Most probable` selects the molecule with the highest likelihood for the given starting structure, producing the same result on repeated generations. `Random` generates a new molecule each time, while still including the input structure.
74
+
75
+
76
+ #### As a python library
77
+
78
+ * To use Chempleter as a python library:
79
+
80
+ ```python
81
+ from chempleter.inference import extend
82
+ generated_mol, generated_smiles, generated_selfies = extend(smiles="c1ccccc1")
83
+ print(generated_smiles)
84
+ # >> C1=CC=CC=C1C2=CC=C(CN3C=NC4=CC=CC=C4C3=O)O2
85
+ ```
86
+
87
+ To draw the generated molecule :
88
+
89
+ ```python
90
+ from rdkit.Chem import Draw
91
+ Draw.MolToImage(generated_mol)
92
+ ```
93
+ * For details on available parameters, refer to the ``extend`` (``chempleter.inference`` module) function’s docstring.
94
+
95
+ ### Current model performance
96
+
97
+ Performance metrics were evaluated across 500 independent generations using a model checkpoint trained for 80 epochs with a batch size of 64.
98
+
99
+ | Metric | Value | Description |
100
+ |------------|-------|--------------------------------------------------------------------------------------------------------------|
101
+ | Validity | 1.0 | Proportion of generated SMILES which respect chemical syntax; tested using the selfies decoder and the RDKit parser. |
102
+ | Uniqueness | 0.96 | Proportion of generated SMILES which were unique |
103
+ | Novelty | 0.85 | Proportion of generated SMILES which were not present in the training dataset |
104
+
105
+
106
+ ### Project structure
107
+ * src/chempleter: Contains python modules relating to different functions.
108
+ * src/chempleter/processor.py: Contains functions for processing csv files containing SMILES data and generating training-related files.
109
+ * src/chempleter/dataset.py: ChempleterDataset class
110
+ * src/chempleter/model.py: ChempleterModel class
111
+ * src/chempleter/inference.py: Contains functions for inference
112
+ * src/chempleter/train.py: Contains functions for training
113
+ * src/chempleter/gui.py: Chempleter GUI built using NiceGUI
114
+ * src/chempleter/data : Contains trained model, vocabulary files
115
+
116
+ # License
117
+
118
+ [MIT](https://github.com/davistdaniel/chempleter/tree/main?tab=MIT-1-ov-file#readme) License
119
+
120
+ Copyright (c) 2025 Davis Thomas Daniel
121
+
122
+ # Contributing
123
+
124
+ Contributions, improvements, feature ideas and bug fixes are always welcome.
125
+
126
+ ## Random Notes
127
+
128
+ * Training data
129
+ * QM9 and ZINC datasets. 379997 molecules were used for training in total.
130
+ * Running without a GPU
131
+ * Chempleter uses a 2-layer GRU, so it should run comfortably on a CPU.
132
+
133
+
134
+
135
+
136
+
137
+
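Note on the Temperature and Sampling options described in the README above: the inference module is not part of this diff, so the following is only a generic sketch of temperature-scaled sampling versus greedy selection over next-token scores. The names `softmax_with_temperature` and `pick_next_token` are hypothetical and are not Chempleter's API.

```python
import math
import random


def softmax_with_temperature(logits, temperature=1.0):
    """Turn raw scores into probabilities; a higher temperature flattens them."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [score / temperature for score in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(score - peak) for score in scaled]
    total = sum(exps)
    return [value / total for value in exps]


def pick_next_token(logits, temperature=1.0, most_probable=False, rng=random):
    """Return the index of the next token.

    most_probable=True always takes the highest-scoring token (deterministic);
    otherwise the index is drawn from the temperature-scaled distribution.
    """
    if most_probable:
        return max(range(len(logits)), key=lambda i: logits[i])
    probabilities = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probabilities, k=1)[0]


logits = [2.0, 1.0, 0.1]  # toy scores for three candidate tokens
cold = softmax_with_temperature(logits, temperature=0.5)
hot = softmax_with_temperature(logits, temperature=2.0)
print(cold[0] > hot[0])  # True: a low temperature concentrates probability on the top token
```

With a low temperature the top-scoring token dominates, which matches the README's description of more common structures; a high temperature spreads probability over less likely tokens.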
@@ -0,0 +1,81 @@
1
+ [project]
2
+ name = "chempleter"
3
+ version = "0.1.0b1"
4
+ description = "A lightweight generative model that extends SMILES fragments into syntactically valid molecules"
5
+ readme = "README.md"
6
+ requires-python = ">=3.13"
7
+ dependencies = [
8
+ "nicegui>=3.4.1",
9
+ "pandas>=2.3.3",
10
+ "pathlib>=1.0.1",
11
+ "rdkit>=2025.9.3",
12
+ "selfies>=2.2.0",
13
+ "torch>=2.9.1",
14
+ ]
15
+
16
+ classifiers = [
17
+ "License :: OSI Approved :: MIT License"
18
+ ]
19
+ license = "MIT"
20
+ license-files = ["LICENSE"]
21
+
22
+ [project.scripts]
23
+ chempleter-gui = "chempleter.gui:run_chempleter_gui"
24
+
25
+ [project.optional-dependencies]
26
+ cpu = [
27
+ "torch>=2.9.1",
28
+ ]
29
+ gpu128 = [
30
+ "torch>=2.9.1",
31
+ ]
32
+
33
+ [tool.uv.sources]
34
+ torch = [
35
+ { index = "pytorch-cpu" , extra = "cpu"},
36
+ { index = "pytorch-cu128", extra = "gpu128" },
37
+ ]
38
+
39
+ [tool.uv]
40
+ conflicts = [
41
+ [
42
+ { extra = "cpu" },
43
+ { extra = "gpu128" },
44
+ ],
45
+ ]
46
+
47
+ [[tool.uv.index]]
48
+ name = "pytorch-cu128"
49
+ url = "https://download.pytorch.org/whl/cu128"
50
+ explicit = true
51
+
52
+ [[tool.uv.index]]
53
+ name = "pytorch-cpu"
54
+ url = "https://download.pytorch.org/whl/cpu"
55
+ explicit = true
56
+
57
+ [[tool.uv.index]]
58
+ name = "testpypi"
59
+ url = "https://test.pypi.org/simple/"
60
+ publish-url = "https://test.pypi.org/legacy/"
61
+ explicit = true
62
+
63
+ [build-system]
64
+ requires = ["uv_build>=0.8.8,<0.9.0"]
65
+ build-backend = "uv_build"
66
+
67
+ [tool.uv_build]
68
+ src-layout = true
69
+ package-data = {"chempleter" = ["data/*"]}
70
+
71
+ [dependency-groups]
72
+ dev = [
73
+ "ipykernel>=7.1.0",
74
+ "jupyter>=1.1.1",
75
+ "marimo>=0.18.4",
76
+ "pytest>=9.0.2",
77
+ "pytest-cov>=7.0.0",
78
+ "ruff>=0.14.10",
79
+ "tqdm>=4.67.1",
80
+ "twine>=6.2.0",
81
+ ]
@@ -0,0 +1,31 @@
1
+ # chempleter
2
+
3
+ __version__ = "0.1.0b1"
4
+
5
+ from pathlib import Path
6
+ import logging
7
+
8
+ # logging setup
9
+ logging.basicConfig(
10
+     level=logging.INFO,
11
+     format="%(asctime)s - %(levelname)s - %(filename)s - %(message)s",
12
+ )
13
+
14
+
15
+ def start_experiment(experiment_name, working_dir=None):
16
+     """
17
+     Create the working directory for an experiment (if needed) and return its path.
18
+
19
+     :param experiment_name: Name of the experiment; used as the directory name.
20
+     :param working_dir: Parent directory; defaults to the current working directory.
21
+     """
22
+
23
+     if not working_dir:
24
+         working_dir = Path.cwd() / experiment_name
25
+     else:
26
+         working_dir = Path(working_dir) / experiment_name
27
+
28
+     # make dir
29
+     working_dir.mkdir(parents=True, exist_ok=True)
30
+
31
+     return working_dir
@@ -0,0 +1 @@
1
+ ["[PAD]", "[START]", "[END]", "[#Branch1]", "[#Branch2]", "[#C]", "[#N+1]", "[#N]", "[-/Ring1]", "[-/Ring2]", "[-\\Ring1]", "[/Br]", "[/C@@H1]", "[/C@@]", "[/C@H1]", "[/C@]", "[/C]", "[/Cl]", "[/F]", "[/N+1]", "[/N-1]", "[/NH1+1]", "[/NH1-1]", "[/NH1]", "[/NH2+1]", "[/N]", "[/O+1]", "[/O-1]", "[/O]", "[/S-1]", "[/S@]", "[/S]", "[=Branch1]", "[=Branch2]", "[=C]", "[=N+1]", "[=N-1]", "[=NH1+1]", "[=NH2+1]", "[=N]", "[=O+1]", "[=OH1+1]", "[=O]", "[=P@@]", "[=P@]", "[=PH2]", "[=P]", "[=Ring1]", "[=Ring2]", "[=S+1]", "[=S@@]", "[=S@]", "[=SH1+1]", "[=S]", "[Br]", "[Branch1]", "[Branch2]", "[C-1]", "[C@@H1]", "[C@@]", "[C@H1]", "[C@]", "[CH1-1]", "[CH2-1]", "[C]", "[Cl]", "[F]", "[I]", "[N+1]", "[N-1]", "[NH1+1]", "[NH1-1]", "[NH1]", "[NH2+1]", "[NH3+1]", "[N]", "[O-1]", "[O]", "[P+1]", "[P@@H1]", "[P@@]", "[P@]", "[PH1+1]", "[PH1]", "[P]", "[Ring1]", "[Ring2]", "[S+1]", "[S-1]", "[S@@+1]", "[S@@]", "[S@]", "[S]", "[\\Br]", "[\\C@@H1]", "[\\C@H1]", "[\\C]", "[\\Cl]", "[\\F]", "[\\I]", "[\\N+1]", "[\\N-1]", "[\\NH1+1]", "[\\NH1]", "[\\NH2+1]", "[\\N]", "[\\O-1]", "[\\O]", "[\\S-1]", "[\\S@]", "[\\S]"]
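Note on the symbol list above: a token's position in the list is its integer id, so the list works as an index-to-symbol table. The sketch below uses only the first five entries to show how ids map back to symbols and how the inverse mapping could be derived. The names `ITOS_EXCERPT` and `decode` are hypothetical.

```python
# First five entries of the symbol list above; the list position is the token id.
ITOS_EXCERPT = ["[PAD]", "[START]", "[END]", "[#Branch1]", "[#Branch2]"]

# The symbol-to-index mapping is simply the inverse of the list.
stoi = {symbol: index for index, symbol in enumerate(ITOS_EXCERPT)}


def decode(token_ids, itos):
    """Map integer token ids back to their SELFIES symbols."""
    return [itos[token_id] for token_id in token_ids]


print(stoi["[END]"])                      # 2
print(decode([1, 3, 4, 2], ITOS_EXCERPT))  # ['[START]', '[#Branch1]', '[#Branch2]', '[END]']
```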
@@ -0,0 +1 @@
1
+ {"[PAD]": 0, "[START]": 1, "[END]": 2, "[#Branch1]": 3, "[#Branch2]": 4, "[#C]": 5, "[#N+1]": 6, "[#N]": 7, "[-/Ring1]": 8, "[-/Ring2]": 9, "[-\\Ring1]": 10, "[/Br]": 11, "[/C@@H1]": 12, "[/C@@]": 13, "[/C@H1]": 14, "[/C@]": 15, "[/C]": 16, "[/Cl]": 17, "[/F]": 18, "[/N+1]": 19, "[/N-1]": 20, "[/NH1+1]": 21, "[/NH1-1]": 22, "[/NH1]": 23, "[/NH2+1]": 24, "[/N]": 25, "[/O+1]": 26, "[/O-1]": 27, "[/O]": 28, "[/S-1]": 29, "[/S@]": 30, "[/S]": 31, "[=Branch1]": 32, "[=Branch2]": 33, "[=C]": 34, "[=N+1]": 35, "[=N-1]": 36, "[=NH1+1]": 37, "[=NH2+1]": 38, "[=N]": 39, "[=O+1]": 40, "[=OH1+1]": 41, "[=O]": 42, "[=P@@]": 43, "[=P@]": 44, "[=PH2]": 45, "[=P]": 46, "[=Ring1]": 47, "[=Ring2]": 48, "[=S+1]": 49, "[=S@@]": 50, "[=S@]": 51, "[=SH1+1]": 52, "[=S]": 53, "[Br]": 54, "[Branch1]": 55, "[Branch2]": 56, "[C-1]": 57, "[C@@H1]": 58, "[C@@]": 59, "[C@H1]": 60, "[C@]": 61, "[CH1-1]": 62, "[CH2-1]": 63, "[C]": 64, "[Cl]": 65, "[F]": 66, "[I]": 67, "[N+1]": 68, "[N-1]": 69, "[NH1+1]": 70, "[NH1-1]": 71, "[NH1]": 72, "[NH2+1]": 73, "[NH3+1]": 74, "[N]": 75, "[O-1]": 76, "[O]": 77, "[P+1]": 78, "[P@@H1]": 79, "[P@@]": 80, "[P@]": 81, "[PH1+1]": 82, "[PH1]": 83, "[P]": 84, "[Ring1]": 85, "[Ring2]": 86, "[S+1]": 87, "[S-1]": 88, "[S@@+1]": 89, "[S@@]": 90, "[S@]": 91, "[S]": 92, "[\\Br]": 93, "[\\C@@H1]": 94, "[\\C@H1]": 95, "[\\C]": 96, "[\\Cl]": 97, "[\\F]": 98, "[\\I]": 99, "[\\N+1]": 100, "[\\N-1]": 101, "[\\NH1+1]": 102, "[\\NH1]": 103, "[\\NH2+1]": 104, "[\\N]": 105, "[\\O-1]": 106, "[\\O]": 107, "[\\S-1]": 108, "[\\S@]": 109, "[\\S]": 110}
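Note on the symbol-to-index mapping above: the dataset code later in this diff wraps each SELFIES string in `[START]` and `[END]` and looks every symbol up in this mapping. The sketch below illustrates that encoding step with an excerpt of the mapping (values copied from it). It assumes a regular expression as a stand-in for `selfies.split_selfies`, so that it runs without the selfies package; `split_symbols` and `encode` are hypothetical names.

```python
import re

# Excerpt of the mapping above; the integer values are copied from it.
STOI_EXCERPT = {
    "[PAD]": 0, "[START]": 1, "[END]": 2,
    "[=Branch1]": 32, "[=C]": 34, "[C]": 64, "[Ring1]": 85,
}


def split_symbols(selfies_string):
    """Split a SELFIES string into bracketed symbols (stand-in for sf.split_selfies)."""
    return re.findall(r"\[[^\]]*\]", selfies_string)


def encode(selfies_string, stoi):
    """Wrap the symbols in [START]/[END] and map each one to its integer id."""
    symbols = ["[START]"] + split_symbols(selfies_string) + ["[END]"]
    return [stoi[symbol] for symbol in symbols]


benzene = "[C][=C][C][=C][C][=C][Ring1][=Branch1]"  # illustrative SELFIES input
token_ids = encode(benzene, STOI_EXCERPT)
print(token_ids)  # [1, 64, 34, 64, 34, 64, 34, 85, 32, 2]
```

A symbol missing from the mapping raises `KeyError` here, which is also what a plain dictionary lookup would do in the dataset class.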
@@ -0,0 +1,82 @@
1
+ import json
2
+ import torch
3
+ import selfies as sf
4
+ import pandas as pd
5
+ from torch.utils.data import Dataset
6
+ from torch.utils.data import DataLoader
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+
10
+ class ChempleterDataset(Dataset):
11
+     """
12
+     PyTorch Dataset for SELFIES molecular representations.
13
+
14
+     :param selfies_file: Path to CSV file containing SELFIES strings in a "selfies" column.
15
+     :type selfies_file: str
16
+     :param stoi_file: Path to JSON file mapping SELFIES symbols to integer tokens.
17
+     :type stoi_file: str
18
+     :returns: Integer tensor representation of tokenized molecule with dtype=torch.long.
19
+     :rtype: torch.Tensor
20
+     """
21
+
22
+     def __init__(self, selfies_file, stoi_file):
23
+         super().__init__()
24
+         selfies_dataframe = pd.read_csv(selfies_file)
25
+         self.data = selfies_dataframe["selfies"].to_list()
26
+         with open(stoi_file) as f:
27
+             self.selfies_to_integer = json.load(f)
28
+
29
+     def __len__(self):
30
+         return len(self.data)
31
+
32
+     def __getitem__(self, index):
33
+         molecule = self.data[index]
34
+         symbols_molecule = ["[START]"] + list(sf.split_selfies(molecule)) + ["[END]"]
35
+         integer_molecule = [
36
+             self.selfies_to_integer[symbol] for symbol in symbols_molecule
37
+         ]
38
+         return torch.tensor(integer_molecule, dtype=torch.long)
39
+
40
+
41
+ def collate_fn(batch):
42
+     """
43
+     Collate function for a PyTorch DataLoader.
44
+     Sorts the incoming batch by sequence length in descending order, pads the sequences
45
+     to the same length (batch_first=True, padding_value=0) using torch.nn.utils.rnn.pad_sequence,
46
+     and returns the padded batch together with the sorted original lengths.
47
+     :param batch: Iterable of 1D tensors representing variable-length sequences.
48
+     :type batch: list[torch.Tensor]
49
+     :returns: A tuple (padded_batch, tensor_lengths) where padded_batch is a 2D tensor
50
+         of shape (batch_size, max_seq_len) containing padded sequences, and
51
+         tensor_lengths is a 1D tensor of original sequence lengths sorted in
52
+         descending order.
53
+     :rtype: Tuple[torch.Tensor, torch.Tensor]
54
+     """
55
+
56
+     tensor_lengths = torch.tensor([len(x) for x in batch])
57
+     tensor_lengths, sorted_idx = tensor_lengths.sort(descending=True)
58
+     batch = [batch[i] for i in sorted_idx]
59
+
60
+     padded_batch = pad_sequence(batch, batch_first=True, padding_value=0)
61
+
62
+     return padded_batch, tensor_lengths
63
+
64
+
65
+ def get_dataloader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn):
66
+     """
67
+     Create a PyTorch DataLoader.
68
+
69
+     :param dataset: PyTorch Dataset instance.
70
+     :type dataset: torch.utils.data.Dataset
71
+     :param batch_size: Number of samples per batch.
72
+     :type batch_size: int
73
+     :param shuffle: Whether to shuffle the data each epoch.
74
+     :type shuffle: bool
75
+     :param collate_fn: Function to merge a list of samples to form a mini-batch.
76
+     :type collate_fn: callable
77
+     :return: Configured DataLoader.
78
+     :rtype: torch.utils.data.DataLoader
79
+     """
80
+     return DataLoader(
81
+         dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn
82
+     )
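Note on `collate_fn` above: its docstring describes sorting a batch by length in descending order and padding every sequence with 0 (the `[PAD]` id) to the longest length. The sketch below mirrors that behaviour on plain Python lists so it runs without torch. It is an illustration of the sort-and-pad idea, not the package's implementation, and `sort_and_pad` is a hypothetical name.

```python
def sort_and_pad(batch, padding_value=0):
    """Sort sequences by length (longest first) and right-pad them to equal length."""
    ordered = sorted(batch, key=len, reverse=True)
    lengths = [len(sequence) for sequence in ordered]
    max_len = lengths[0] if lengths else 0
    padded = [
        list(sequence) + [padding_value] * (max_len - len(sequence))
        for sequence in ordered
    ]
    return padded, lengths


# Three toy token sequences; 1 = [START], 2 = [END], 0 = [PAD] as in the mapping above.
batch = [[1, 64, 2], [1, 64, 34, 77, 2], [1, 2]]
padded, lengths = sort_and_pad(batch)
print(padded)   # [[1, 64, 34, 77, 2], [1, 64, 2, 0, 0], [1, 2, 0, 0, 0]]
print(lengths)  # [5, 3, 2]
```

Returning the original lengths alongside the padded batch lets downstream code ignore the padding positions, for example when packing sequences for a recurrent layer.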