joltax 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- joltax-0.1.1/LICENSE +21 -0
- joltax-0.1.1/PKG-INFO +74 -0
- joltax-0.1.1/README.md +52 -0
- joltax-0.1.1/joltax/__init__.py +1 -0
- joltax-0.1.1/joltax/joltree.py +734 -0
- joltax-0.1.1/joltax.egg-info/PKG-INFO +74 -0
- joltax-0.1.1/joltax.egg-info/SOURCES.txt +11 -0
- joltax-0.1.1/joltax.egg-info/dependency_links.txt +1 -0
- joltax-0.1.1/joltax.egg-info/requires.txt +3 -0
- joltax-0.1.1/joltax.egg-info/top_level.txt +1 -0
- joltax-0.1.1/pyproject.toml +34 -0
- joltax-0.1.1/setup.cfg +4 -0
- joltax-0.1.1/tests/test_tree.py +142 -0
joltax-0.1.1/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2026 Swedish Biodiversity in Time and Space (SweBiTS)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

joltax-0.1.1/PKG-INFO
ADDED
@@ -0,0 +1,74 @@
Metadata-Version: 2.4
Name: joltax
Version: 0.1.1
Summary: A high-performance, vectorized taxonomy library for Python.
Author-email: Daniel Svensson <daniel.svensson@umu.se>
License: MIT
Project-URL: Homepage, https://github.com/SweBiTS/JolTax
Project-URL: Bug Tracker, https://github.com/SweBiTS/JolTax/issues
Project-URL: Source Code, https://github.com/SweBiTS/JolTax
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy>=1.20.0
Requires-Dist: polars>=0.20.0
Requires-Dist: rapidfuzz>=3.0.0
Dynamic: license-file

<p align="center">
  <img src="assets/logo.png" alt="joltax logo" width="300">
</p>

# joltax

**High-performance, vectorized taxonomy library for Python.**

`JolTax` is a Python library designed to handle massive taxonomies with extreme efficiency. By representing taxonomy trees as contiguous NumPy arrays and leveraging Polars for mass data handling, it achieves lightning-fast traversals, constant-time clade queries, and rapid mass annotation of large datasets.

## Key Features
- **Vectorized Performance:** Uses hardware-accelerated NumPy operations for million-scale property lookups.
- **Memory Efficient:** Optimized string store using Polars/Arrow reduces RAM footprint.
- **Fuzzy Name Search:** Rapid fuzzy matching using RapidFuzz to find TaxIDs from names.
- **Instant Clade Queries:** Quickly find all descendants of any node (even millions) using optimized range indexing.
- **Hyper-Vectorized LCA search:** Lowest Common Ancestor (LCA) search and node-to-node distance calculations at lightning speeds.
- **Mass Annotation:** Annotate massive TaxID tables with 2,000,000+ rows in under a second using Polars.

## Quick Start

```python
from joltax.joltree import JolTree

# Build and process the NCBI taxonomy
tree = JolTree(nodes_file='nodes.dmp', names_file='names.dmp')

# Save for instant loading next time
tree.save('my_taxonomy_cache')

# Re-load in milliseconds (using zero-copy Arrow IPC)
tree = JolTree.load('my_taxonomy_cache')

# Batch LCA (process 10,000 pairs in <10ms)
lcas = tree.get_lca_batch(ids1, ids2)

# Fuzzy search for a name (returns a Polars DataFrame)
results = tree.search_name('Escherchia', fuzzy=True)
print(results)
```

## Installation

```bash
cd joltax
pip install .
```

Requires: `numpy`, `polars`, `rapidfuzz`.

## Documentation

For a detailed API reference and a comprehensive "How-To" guide with example workflows, please see [USAGE.md](./USAGE.md).

joltax-0.1.1/README.md
ADDED
@@ -0,0 +1,52 @@
<p align="center">
  <img src="assets/logo.png" alt="joltax logo" width="300">
</p>

# joltax

**High-performance, vectorized taxonomy library for Python.**

`JolTax` is a Python library designed to handle massive taxonomies with extreme efficiency. By representing taxonomy trees as contiguous NumPy arrays and leveraging Polars for mass data handling, it achieves lightning-fast traversals, constant-time clade queries, and rapid mass annotation of large datasets.

## Key Features
- **Vectorized Performance:** Uses hardware-accelerated NumPy operations for million-scale property lookups.
- **Memory Efficient:** Optimized string store using Polars/Arrow reduces RAM footprint.
- **Fuzzy Name Search:** Rapid fuzzy matching using RapidFuzz to find TaxIDs from names.
- **Instant Clade Queries:** Quickly find all descendants of any node (even millions) using optimized range indexing.
- **Hyper-Vectorized LCA search:** Lowest Common Ancestor (LCA) search and node-to-node distance calculations at lightning speeds.
- **Mass Annotation:** Annotate massive TaxID tables with 2,000,000+ rows in under a second using Polars.

## Quick Start

```python
from joltax.joltree import JolTree

# Build and process the NCBI taxonomy
tree = JolTree(nodes_file='nodes.dmp', names_file='names.dmp')

# Save for instant loading next time
tree.save('my_taxonomy_cache')

# Re-load in milliseconds (using zero-copy Arrow IPC)
tree = JolTree.load('my_taxonomy_cache')

# Batch LCA (process 10,000 pairs in <10ms)
lcas = tree.get_lca_batch(ids1, ids2)

# Fuzzy search for a name (returns a Polars DataFrame)
results = tree.search_name('Escherchia', fuzzy=True)
print(results)
```

## Installation

```bash
cd joltax
pip install .
```

Requires: `numpy`, `polars`, `rapidfuzz`.

## Documentation

For a detailed API reference and a comprehensive "How-To" guide with example workflows, please see [USAGE.md](./USAGE.md).

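The Quick Start above covers build, caching, batch LCA and fuzzy search. The same `JolTree` class defined in `joltax/joltree.py` also exposes lineage, clade and mass-annotation methods; the following is a minimal sketch of those calls, assuming a cache was saved as in the Quick Start and using illustrative NCBI TaxIDs (562 for Escherichia coli, 561 for the genus Escherichia):

```python
import numpy as np
from joltax.joltree import JolTree

# Assumes 'my_taxonomy_cache' was produced by tree.save() as in the Quick Start.
tree = JolTree.load('my_taxonomy_cache')

# Lineage, scientific name and rank for a single TaxID
print(tree.get_lineage(562))                   # root -> ... -> 562
print(tree.get_name(562), tree.get_rank(562))  # 'Escherichia coli' 'species'

# All species-rank descendants of the genus Escherichia (TaxID 561)
species_ids = tree.get_clade_at_rank(561, 'species')

# Vectorized annotation of a whole TaxID column (returns a Polars DataFrame)
table = tree.annotate_table(np.array(species_ids, dtype=np.int32))
print(table.head())
```
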
joltax-0.1.1/joltax/__init__.py
ADDED
@@ -0,0 +1 @@
from .joltree import JolTree

joltax-0.1.1/joltax/joltree.py
ADDED
@@ -0,0 +1,734 @@
#!/usr/bin/env python3
"""
joltax/joltree.py
Implementation of a high-performance, vectorized taxonomy tree.
"""

__version__ = "0.1.1"

# The minimum version of a saved taxonomy cache that is compatible with this software.
# Increment this when making breaking changes to the binary layout or metadata structure.
MINIMUM_CACHE_VERSION = "0.1.1"

import logging
import os
import datetime
from typing import Dict, List, Optional, Set, Union, Tuple
from collections import namedtuple

import numpy as np
import polars as pl
from rapidfuzz import process, fuzz, utils

# Set up logging for the module
logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d [%H:%M:%S]'
)
logger = logging.getLogger(__name__)

# Standard canonical ranks in order (highest to lowest)
# Including both superkingdom and domain for compatibility with pre/post-2025 taxonomies
CANONICAL_RANKS = [
    'superkingdom', 'domain', 'kingdom', 'phylum',
    'class', 'order', 'family', 'genus', 'species'
]

# Mapping rank names to standard Kraken-style codes
RANK_TO_CODE = {
    'superkingdom': 'D',
    'domain': 'D',
    'kingdom': 'K',
    'phylum': 'P',
    'class': 'C',
    'order': 'O',
    'family': 'F',
    'genus': 'G',
    'species': 'S'
}

class JolTree:
    """
    A high-performance taxonomy representation using vectorized arrays.

    This class replaces traditional object-oriented trees with contiguous
    NumPy arrays for lightning-fast lookups, traversals, and mass annotations.
    """

    def __init__(self, nodes_file: Optional[str] = None, names_file: Optional[str] = None):
        """
        Initialize the taxonomy tree. If files are provided, it builds from DMP files.
        Otherwise, it can be loaded from a binary cache using `load()`.

        Args:
            nodes_file: Path to NCBI nodes.dmp
            names_file: Path to NCBI names.dmp
        """
        # Vectorized internal index mapping (sorted array of TaxIDs)
        self._index_to_id: np.ndarray = np.array([], dtype=np.int32)

        # Primary arrays (indexed by the dense internal index)
        self.parents: np.ndarray = np.array([], dtype=np.int32)
        self.depths: np.ndarray = np.array([], dtype=np.int32)
        self.ranks: np.ndarray = np.array([], dtype=np.uint8)

        # Metadata storage (Polars Series for memory efficiency)
        self._scientific_names: pl.Series = pl.Series("scientific_name", [], dtype=pl.String)
        self._common_names: pl.Series = pl.Series("common_name", [], dtype=pl.String)
        self.rank_names: List[str] = []
        self.top_rank: str = "domain"  # Default, will be detected
        self._source_nodes: Optional[str] = None
        self._source_names: Optional[str] = None
        self._build_time: Optional[str] = None

        # Clade query support (Euler Tour timestamps)
        self.entry_times: np.ndarray = np.array([], dtype=np.int32)
        self.exit_times: np.ndarray = np.array([], dtype=np.int32)

        # Binary lifting table for LCA (initialized on demand)
        self._up_table: Optional[np.ndarray] = None

        # Pre-calculated canonical rank maps (dense internal index -> dense internal index)
        # Values are internal indices, not TaxIDs. -1 means no ancestor at that rank.
        self.canonical_maps: Dict[str, np.ndarray] = {}

        # Search index (Polars DataFrame: name -> tax_id)
        self._search_index: pl.DataFrame = pl.DataFrame(schema={"name": pl.String, "tax_id": pl.Int32})

        # Caches for vectorized lookup (prepared during build/load)
        self._sci_names_lookup: Optional[pl.Series] = None
        self._rank_names_series: Optional[pl.Series] = None
        self._ranks_extended: Optional[np.ndarray] = None

        if nodes_file and names_file:
            self.build_from_dmp(nodes_file, names_file)

    def build_from_dmp(self, nodes_file: str, names_file: str) -> None:
        """
        Parses NCBI DMP files and builds the vectorized internal structure.

        Args:
            nodes_file: Path to NCBI nodes.dmp
            names_file: Path to NCBI names.dmp
        """
        self._source_nodes = os.path.abspath(nodes_file)
        self._source_names = os.path.abspath(names_file)
        self._build_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        logger.info(f"Starting taxonomy build at {self._build_time}...")

        # 1. Parse Names
        logger.info(f"Parsing names from {names_file}...")
        scientific_names = {}
        common_names = {}
        search_data = []  # List of (name, tax_id)

        with open(names_file, 'r') as f:
            for name_line in f:
                parts = name_line.split('|')
                name_type = parts[3].strip()

                # Only care about scientific and genbank common names for now
                if name_type not in ['scientific name', 'genbank common name']:
                    continue

                tax_id = int(parts[0].strip())
                name_txt = parts[1].strip()

                if name_type == 'scientific name':
                    scientific_names[tax_id] = name_txt
                elif name_type == 'genbank common name':
                    common_names[tax_id] = name_txt

                search_data.append({"name": name_txt, "tax_id": tax_id})

        # Build search index
        self._search_index = pl.DataFrame(search_data).sort("name")

        # 2. Parse Nodes and initial parent structure
        logger.info(f"Parsing nodes from {nodes_file}...")
        temp_parents = {}
        temp_ranks = {}
        all_ranks = set()

        with open(nodes_file, 'r') as f:
            for line in f:
                parts = line.split('|')
                tax_id = int(parts[0].strip())
                parent_id = int(parts[1].strip())
                rank = parts[2].strip()

                temp_parents[tax_id] = parent_id
                temp_ranks[tax_id] = rank
                all_ranks.add(rank)

        # 2.1 Detect top rank (superkingdom vs domain)
        has_sk = 'superkingdom' in all_ranks
        has_dm = 'domain' in all_ranks
        if has_sk and has_dm:
            raise ValueError("Found both 'superkingdom' and 'domain' ranks. The taxonomy must use only one as the top rank.")
        self.top_rank = 'superkingdom' if has_sk else 'domain'
        logger.info(f"Detected top rank: {self.top_rank}")

        # 3. Create dense mapping
        logger.info("Creating dense mapping and vectorized arrays...")
        sorted_tax_ids = sorted(temp_parents.keys())
        num_nodes = len(sorted_tax_ids)
        self._index_to_id = np.array(sorted_tax_ids, dtype=np.int32)

        # Mapping rank names to indices
        self.rank_names = sorted(list(all_ranks))
        rank_to_idx = {r: i for i, r in enumerate(self.rank_names)}

        self.parents = np.zeros(num_nodes, dtype=np.int32)
        self.ranks = np.zeros(num_nodes, dtype=np.uint8)

        # Temporary dict for building parent connections (will be discarded)
        id_to_index_temp = {tid: i for i, tid in enumerate(sorted_tax_ids)}

        for tid, i in id_to_index_temp.items():
            parent_id = temp_parents[tid]
            # Handle root (1) which is its own parent in NCBI
            if tid == 1:
                self.parents[i] = i
            else:
                self.parents[i] = id_to_index_temp[parent_id]

            self.ranks[i] = rank_to_idx[temp_ranks[tid]]

        # Populate names aligned with indices
        logger.info("Aligning names and ranks...")
        sci_names_list = [scientific_names.get(tid, f"Unknown_{tid}") for tid in sorted_tax_ids]
        com_names_list = [common_names.get(tid) for tid in sorted_tax_ids]
        self._scientific_names = pl.Series("scientific_name", sci_names_list)
        self._common_names = pl.Series("common_name", com_names_list)

        # 4. Calculate depths
        logger.info("Calculating node depths...")
        self.depths = np.zeros(num_nodes, dtype=np.int32)
        for i in range(num_nodes):
            self._calculate_depth(i)

        # 5. Build Euler Tour for clade queries
        self._build_euler_tour()

        # 6. Pre-calculate canonical rank maps
        self._build_canonical_maps()

        # 7. Prepare caches for vectorized lookups
        self._prepare_vectorized_caches()

        logger.info("Taxonomy build complete.")

    def _prepare_vectorized_caches(self) -> None:
        """Initializes caches used for high-performance vectorized lookups."""
        logger.info("Preparing vectorized lookup caches...")
        # Scientific names lookup (aligned with dense internal index + 1 for "Unknown")
        self._sci_names_lookup = self._scientific_names.append(pl.Series([None]))

        # Rank names lookup
        self._rank_names_series = pl.Series(self.rank_names).append(pl.Series(["unclassified"]))

        # Ranks extended with a pointer to "unclassified" for unknown nodes
        self._ranks_extended = np.append(self.ranks, [len(self.rank_names)]).astype(np.int32)

    def _build_canonical_maps(self) -> None:
        """Pre-calculates canonical rank ancestors for all nodes."""
        logger.info("Pre-calculating canonical rank maps...")
        num_nodes = len(self._index_to_id)

        # Identify all canonical ranks to track
        canonical_columns = [self.top_rank] + [r for r in CANONICAL_RANKS if r not in ['superkingdom', 'domain']]

        # Initialize maps with -1 (meaning no ancestor at that rank)
        self.canonical_maps = {rank: np.full(num_nodes, -1, dtype=np.int32) for rank in canonical_columns}

        # Sort nodes by depth to ensure parents are processed before children
        for i in range(num_nodes):
            curr_idx = i
            root_idx = 0  # TaxID 1 is always the first in sorted_tax_ids
            while True:
                rank_name = self.rank_names[self.ranks[curr_idx]]

                # Normalize superkingdom/domain based on detected top_rank
                mapped_rank = rank_name
                if rank_name in ['superkingdom', 'domain']:
                    mapped_rank = self.top_rank

                if mapped_rank in self.canonical_maps:
                    self.canonical_maps[mapped_rank][i] = curr_idx

                if curr_idx == root_idx:
                    break
                curr_idx = self.parents[curr_idx]

    def _calculate_depth(self, index: int) -> int:
        """Recursive depth calculation with memoization."""
        if index == 0:  # TaxID 1 is always index 0
            return 0
        if self.depths[index] != 0:
            return self.depths[index]

        d = self._calculate_depth(self.parents[index]) + 1
        self.depths[index] = d
        return d

    def _build_euler_tour(self) -> None:
        """Assigns entry/exit times to enable instant clade queries."""
        logger.info("Building Euler Tour index for clade queries...")
        num_nodes = len(self._index_to_id)
        self.entry_times = np.zeros(num_nodes, dtype=np.int32)
        self.exit_times = np.zeros(num_nodes, dtype=np.int32)

        # Build adjacency list (children)
        children = [[] for _ in range(num_nodes)]
        root_idx = 0  # TaxID 1
        for i, p in enumerate(self.parents):
            if i != root_idx:
                children[p].append(i)

        timer = 0
        stack = [(root_idx, False)]  # (index, is_processed)

        while stack:
            idx, processed = stack.pop()
            if not processed:
                self.entry_times[idx] = timer
                timer += 1
                stack.append((idx, True))
                for child in reversed(children[idx]):
                    stack.append((child, False))
            else:
                self.exit_times[idx] = timer - 1

    def _get_index(self, tax_id: int) -> int:
        """Returns the internal index for a TaxID, or -1 if not found."""
        idx = np.searchsorted(self._index_to_id, tax_id)
        if idx < len(self._index_to_id) and self._index_to_id[idx] == tax_id:
            return int(idx)
        return -1

    def _get_indices(self, tax_ids: np.ndarray) -> np.ndarray:
        """Returns internal indices for an array of TaxIDs, with -1 for missing."""
        indices = np.searchsorted(self._index_to_id, tax_ids)
        # Handle out of bounds
        mask = indices < len(self._index_to_id)
        # Check for actual equality
        valid = np.zeros(len(tax_ids), dtype=bool)
        valid[mask] = self._index_to_id[indices[mask]] == tax_ids[mask]
        return np.where(valid, indices, -1)

    def get_lineage(self, tax_id: int) -> List[int]:
        """Returns the full lineage from root to the given TaxID."""
        idx = self._get_index(tax_id)
        if idx == -1:
            logger.warning(f"TaxID {tax_id} not found in taxonomy tree.")
            return []

        lineage = []
        root_idx = 0

        while True:
            lineage.append(int(self._index_to_id[idx]))
            if idx == root_idx:
                break
            idx = self.parents[idx]

        return lineage[::-1]

    def get_name(self, tax_id: int) -> str:
        """Returns the scientific name of the given TaxID."""
        idx = self._get_index(tax_id)
        if idx != -1:
            return self._scientific_names[idx]
        return f"Unknown_{tax_id}"

    def get_common_name(self, tax_id: int) -> Optional[str]:
        """Returns the genbank common name of the given TaxID, if available."""
        idx = self._get_index(tax_id)
        if idx != -1:
            return self._common_names[idx]
        return None

    def get_rank(self, tax_id: int) -> str:
        """Returns the taxonomic rank of the given TaxID."""
        idx = self._get_index(tax_id)
        if idx == -1:
            logger.warning(f"TaxID {tax_id} not found in taxonomy tree.")
            return "unknown"
        return self.rank_names[self.ranks[idx]]

    def search_name(self, query: str, fuzzy: bool = False, limit: int = 10, score_cutoff: float = 60.0) -> pl.DataFrame:
        """
        Searches for TaxIDs by name.
        """
        if not fuzzy:
            matches = self._search_index.filter(pl.col("name") == query)
            if matches.is_empty():
                return pl.DataFrame(schema=["tax_id", "name", "rank", "score"])

            # Vectorized rank lookup for matches
            tids = matches["tax_id"].to_numpy()
            indices = self._get_indices(tids)
            ranks = [self.rank_names[self.ranks[i]] if i != -1 else "unknown" for i in indices]

            return matches.with_columns([
                pl.Series("rank", ranks),
                pl.lit(100.0).alias("score")
            ])

        # Fuzzy matching path
        unique_names = self._search_index["name"].unique().to_list()

        # rapidfuzz extract
        matches = process.extract(
            query,
            unique_names,
            scorer=fuzz.WRatio,
            limit=limit,
            processor=utils.default_process,
            score_cutoff=score_cutoff
        )

        data = []
        for match_str, score, _ in matches:
            # Find all TaxIDs associated with this name
            tids = self._search_index.filter(pl.col("name") == match_str)["tax_id"].to_list()
            for tid in tids:
                idx = self._get_index(tid)
                rank = self.rank_names[self.ranks[idx]] if idx != -1 else "unknown"

                # Smart Ranking: Boost scores for canonical ranks
                rank_boost = 0.0
                if rank in CANONICAL_RANKS:
                    rank_boost = 2.0

                data.append({
                    "tax_id": tid,
                    "matched_name": match_str,
                    "scientific_name": self.get_name(tid),
                    "rank": rank,
                    "score": score + rank_boost
                })

        if not data:
            return pl.DataFrame(schema=["tax_id", "matched_name", "scientific_name", "rank", "score"])

        return pl.DataFrame(data).sort("score", descending=True)

    def get_clade(self, tax_id: int) -> List[int]:
        """Returns all TaxIDs in the clade rooted at the given TaxID."""
        idx = self._get_index(tax_id)
        if idx == -1:
            logger.warning(f"TaxID {tax_id} not found in taxonomy tree.")
            return []

        entry = self.entry_times[idx]
        exit = self.exit_times[idx]

        mask = (self.entry_times >= entry) & (self.entry_times <= exit)
        return self._index_to_id[mask].astype(int).tolist()

    def get_clade_at_rank(self, tax_id: int, rank_name: str) -> List[int]:
        """
        Returns all TaxIDs of a specific rank within the clade rooted at tax_id.
        """
        idx = self._get_index(tax_id)
        if idx == -1:
            logger.warning(f"TaxID {tax_id} not found in taxonomy tree.")
            return []

        try:
            target_rank_idx = self.rank_names.index(rank_name)
        except ValueError:
            logger.warning(f"Rank '{rank_name}' not found in taxonomy. Available ranks: {self.rank_names}")
            return []

        entry = self.entry_times[idx]
        exit = self.exit_times[idx]

        mask = (self.entry_times >= entry) & (self.entry_times <= exit) & (self.ranks == target_rank_idx)
        return self._index_to_id[mask].astype(int).tolist()

    def get_lca(self, tax_id_1: int, tax_id_2: int) -> int:
        """Finds the Lowest Common Ancestor using Binary Lifting."""
        idx1 = self._get_index(tax_id_1)
        idx2 = self._get_index(tax_id_2)

        if idx1 == -1 or idx2 == -1:
            logger.warning(f"One or both TaxIDs ({tax_id_1}, {tax_id_2}) not found.")
            return 1

        self._ensure_up_table()

        if self.depths[idx1] < self.depths[idx2]:
            idx1, idx2 = idx2, idx1

        diff = self.depths[idx1] - self.depths[idx2]
        max_log = self._up_table.shape[0]

        for i in range(max_log):
            if (diff >> i) & 1:
                idx1 = self._up_table[i, idx1]

        if idx1 == idx2:
            return int(self._index_to_id[idx1])

        for i in reversed(range(max_log)):
            up1 = self._up_table[i, idx1]
            up2 = self._up_table[i, idx2]
            if up1 != up2:
                idx1 = up1
                idx2 = up2

        return int(self._index_to_id[self.parents[idx1]])

    def get_distance(self, tax_id_1: int, tax_id_2: int) -> int:
        """Calculates distance (number of edges) between two TaxIDs."""
        lca_id = self.get_lca(tax_id_1, tax_id_2)
        idx1 = self._get_index(tax_id_1)
        idx2 = self._get_index(tax_id_2)
        idx_lca = self._get_index(lca_id)
        return int(self.depths[idx1] + self.depths[idx2] - 2 * self.depths[idx_lca])

    def get_lca_batch(self, ids1: Union[List[int], np.ndarray], ids2: Union[List[int], np.ndarray]) -> np.ndarray:
        """
        Calculates Lowest Common Ancestor for arrays of TaxIDs.
        Hyper-vectorized implementation for peak performance.
        """
        ids1 = np.array(ids1, dtype=np.int32)
        ids2 = np.array(ids2, dtype=np.int32)

        if ids1.shape != ids2.shape:
            raise ValueError("Input arrays must have the same shape.")

        self._ensure_up_table()

        idx1 = self._get_indices(ids1)
        idx2 = self._get_indices(ids2)

        # Handle missing IDs by pointing to root (index 0)
        valid_mask = (idx1 != -1) & (idx2 != -1)
        s_idx1 = np.where(valid_mask, idx1, 0)
        s_idx2 = np.where(valid_mask, idx2, 0)

        # 1. Bring both nodes to the same depth
        d1 = self.depths[s_idx1]
        d2 = self.depths[s_idx2]

        # Ensure s_idx1 is the deeper one
        swap = d1 < d2
        s_idx1[swap], s_idx2[swap] = s_idx2[swap], s_idx1[swap]

        diff = np.abs(d1 - d2)
        max_log = self._up_table.shape[0]

        for i in range(max_log):
            mask = ((diff >> i) & 1) == 1
            if np.any(mask):
                s_idx1[mask] = self._up_table[i, s_idx1[mask]]

        # 2. Binary search for the LCA
        lca_indices = s_idx1.copy()
        not_same = s_idx1 != s_idx2

        if np.any(not_same):
            sub1 = s_idx1[not_same]
            sub2 = s_idx2[not_same]

            for i in reversed(range(max_log)):
                up1 = self._up_table[i, sub1]
                up2 = self._up_table[i, sub2]

                diff_up = up1 != up2
                sub1[diff_up] = up1[diff_up]
                sub2[diff_up] = up2[diff_up]

            lca_indices[not_same] = self.parents[sub1]

        results = self._index_to_id[lca_indices]
        results[~valid_mask] = 1
        return results

    def get_distance_batch(self, ids1: Union[List[int], np.ndarray], ids2: Union[List[int], np.ndarray]) -> np.ndarray:
        """Vectorized distance calculation for arrays of TaxIDs."""
        ids1 = np.array(ids1, dtype=np.int32)
        ids2 = np.array(ids2, dtype=np.int32)

        lca_ids = self.get_lca_batch(ids1, ids2)

        idx1 = self._get_indices(ids1)
        idx2 = self._get_indices(ids2)
        idx_lca = self._get_indices(lca_ids)

        # Mask invalid lookups to avoid OOB errors
        valid = (idx1 != -1) & (idx2 != -1) & (idx_lca != -1)

        dists = np.zeros(len(ids1), dtype=np.int32)
        if np.any(valid):
            v1, v2, vl = idx1[valid], idx2[valid], idx_lca[valid]
            dists[valid] = self.depths[v1] + self.depths[v2] - 2 * self.depths[vl]

        return dists

    def annotate_table(self, tax_ids: Union[List[int], np.ndarray]) -> pl.DataFrame:
        """
        Massively annotates a list of TaxIDs with scientific_names and canonical ranks.
        Extremely efficient for large tables (e.g. 200k+ rows) using Polars and vectorized lookups.
        """
        logger.info(f"Annotating {len(tax_ids)} taxa...")
        canonical_columns = [self.top_rank] + [r for r in CANONICAL_RANKS if r not in ['superkingdom', 'domain']]

        tax_ids_arr = np.array(tax_ids, dtype=np.int32)
        indices = self._get_indices(tax_ids_arr)
        valid_mask = indices != -1

        # dummy_idx points to the "Unknown/None" entry at the end of the lookup series
        dummy_idx = len(self._index_to_id)
        safe_indices = np.where(valid_mask, indices, dummy_idx)

        # Ensure caches are ready
        if self._sci_names_lookup is None:
            self._prepare_vectorized_caches()

        df_dict = {"tax_id": tax_ids_arr}

        for rank in canonical_columns:
            # canonical_maps now store internal indices
            ancestor_indices = np.full(len(tax_ids_arr), -1, dtype=np.int32)
            # Map input tax_ids to their ancestor's internal index
            ancestor_indices[valid_mask] = self.canonical_maps[rank][indices[valid_mask]]

            # Use dummy_idx for missing ancestors
            safe_anc_indices = np.where(ancestor_indices != -1, ancestor_indices, dummy_idx)

            # Vectorized gather from Polars
            df_dict[rank] = self._sci_names_lookup.gather(safe_anc_indices.astype(np.int32))

        # Scientific name for the input TaxID
        df_dict["scientific_name"] = self._sci_names_lookup.gather(safe_indices.astype(np.int32))

        # Rank for the input TaxID
        target_rank_indices = self._ranks_extended[safe_indices]
        df_dict["rank"] = self._rank_names_series.gather(target_rank_indices.astype(np.int32))

        df = pl.DataFrame(df_dict)
        final_order = ['tax_id'] + canonical_columns + ['scientific_name', 'rank']
        return df.select(final_order)

    def _ensure_up_table(self) -> None:
        """Lazy initialization of binary lifting table."""
        if self._up_table is not None:
            return

        logger.info("Initializing binary lifting table (Hyper-Vectorized)...")
        num_nodes = len(self._index_to_id)
        max_log = int(np.ceil(np.log2(np.max(self.depths) + 1)))

        # Shape: (max_log, num_nodes) - optimized for contiguous column access
        self._up_table = np.zeros((max_log, num_nodes), dtype=np.int32)

        # Power 2^0 is just the parents
        self._up_table[0, :] = self.parents

        # Power 2^j = 2^{j-1} jump from the 2^{j-1} ancestor
        # Fully vectorized initialization
        for j in range(1, max_log):
            prev_ancestors = self._up_table[j-1, :]
            self._up_table[j, :] = self._up_table[j-1, prev_ancestors]

    def save(self, directory: str) -> None:
        """Saves the vectorized tree to a directory for fast loading."""
        if not os.path.exists(directory):
            os.makedirs(directory)

        logger.info(f"Saving binary cache to {directory}...")
        np.save(os.path.join(directory, "index_to_id.npy"), self._index_to_id)
        np.save(os.path.join(directory, "parents.npy"), self.parents)
        np.save(os.path.join(directory, "depths.npy"), self.depths)
        np.save(os.path.join(directory, "ranks.npy"), self.ranks)
        np.save(os.path.join(directory, "entry_times.npy"), self.entry_times)
        np.save(os.path.join(directory, "exit_times.npy"), self.exit_times)

        # Save Polars metadata
        self._scientific_names.to_frame().write_ipc(os.path.join(directory, "scientific_names.ipc"))
        self._common_names.to_frame().write_ipc(os.path.join(directory, "common_names.ipc"))
        self._search_index.write_ipc(os.path.join(directory, "search_index.ipc"))

        maps_dir = os.path.join(directory, "canonical_maps")
        if not os.path.exists(maps_dir):
            os.makedirs(maps_dir)
        for rank, arr in self.canonical_maps.items():
            np.save(os.path.join(maps_dir, f"{rank}.npy"), arr)

        import pickle
        with open(os.path.join(directory, "metadata.pkl"), 'wb') as f:
            pickle.dump({
                "rank_names": self.rank_names,
                "top_rank": self.top_rank,
                "provenance": {
                    "build_time": self._build_time,
                    "source_nodes": self._source_nodes,
                    "source_names": self._source_names,
                    "package_version": __version__,
                    "node_count": len(self._index_to_id),
                    "max_depth": int(np.max(self.depths))
                }
            }, f)

    @classmethod
    def load(cls, directory: str) -> 'JolTree':
        """Loads the vectorized tree from a binary cache directory."""
        logger.info(f"Loading binary cache from {directory}...")

        import pickle
        with open(os.path.join(directory, "metadata.pkl"), 'rb') as f:
            meta = pickle.load(f)
        prov = meta.get("provenance", {})
        saved_version = prov.get("package_version", "unknown")
        def version_to_tuple(v):
            try:
                return tuple(map(int, v.split('.')))
            except (ValueError, AttributeError):
                return (0, 0, 0)
        if version_to_tuple(saved_version) < version_to_tuple(MINIMUM_CACHE_VERSION):
            raise RuntimeError(
                f"Incompatible taxonomy cache. Saved version: {saved_version}, "
                f"Minimum required: {MINIMUM_CACHE_VERSION}. Please rebuild with build_from_dmp()."
            )

        tree = cls()
        tree.rank_names = meta["rank_names"]
        tree.top_rank = meta.get("top_rank", "domain")
        tree._build_time = prov.get("build_time")
        tree._source_nodes = prov.get("source_nodes")
        tree._source_names = prov.get("source_names")

        tree._index_to_id = np.load(os.path.join(directory, "index_to_id.npy"))
        tree.parents = np.load(os.path.join(directory, "parents.npy"))
        tree.depths = np.load(os.path.join(directory, "depths.npy"))
        tree.ranks = np.load(os.path.join(directory, "ranks.npy"))
        tree.entry_times = np.load(os.path.join(directory, "entry_times.npy"))
        tree.exit_times = np.load(os.path.join(directory, "exit_times.npy"))

        # Load Polars metadata
        tree._scientific_names = pl.read_ipc(os.path.join(directory, "scientific_names.ipc"))["scientific_name"]
        tree._common_names = pl.read_ipc(os.path.join(directory, "common_names.ipc"))["common_name"]
        tree._search_index = pl.read_ipc(os.path.join(directory, "search_index.ipc"))

        maps_dir = os.path.join(directory, "canonical_maps")
        if os.path.exists(maps_dir):
            for filename in os.listdir(maps_dir):
                if filename.endswith(".npy"):
                    rank = filename[:-4]
                    tree.canonical_maps[rank] = np.load(os.path.join(maps_dir, filename))

        # Re-initialize vectorized caches
        tree._prepare_vectorized_caches()

        logger.info("Loaded taxonomy cache:")
        logger.info(f" Version: {saved_version}")
        logger.info(f" Build time: {tree._build_time}")
        logger.info(f" Node count: {prov.get('node_count', 'Unknown'):,}")
        logger.info(f" Top rank: {tree.top_rank}")
        return tree

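The clade queries above (`get_clade`, `get_clade_at_rank`) rely on the Euler-tour entry/exit times assigned in `_build_euler_tour`: a node's subtree is exactly the set of nodes whose entry time falls inside that node's `[entry, exit]` window, so a clade lookup reduces to one vectorized range check. A self-contained toy sketch of that idea, using a hypothetical five-node tree rather than the package itself:

```python
import numpy as np

# Hypothetical toy tree: node 0 is the root (its own parent), 1 and 2 are its
# children, and 3 and 4 are children of node 1.
parents = np.array([0, 0, 0, 1, 1])

# Build children lists, skipping the self-parented root.
children = [[] for _ in range(len(parents))]
for i, p in enumerate(parents):
    if i != 0:
        children[p].append(i)

# Iterative DFS assigning entry/exit timestamps (mirrors _build_euler_tour).
entry = np.zeros(len(parents), dtype=int)
exit_ = np.zeros(len(parents), dtype=int)
timer = 0
stack = [(0, False)]
while stack:
    node, done = stack.pop()
    if not done:
        entry[node] = timer
        timer += 1
        stack.append((node, True))
        for child in reversed(children[node]):
            stack.append((child, False))
    else:
        exit_[node] = timer - 1

# Subtree of node 1 via a single vectorized range check (mirrors get_clade).
mask = (entry >= entry[1]) & (entry <= exit_[1])
print(np.flatnonzero(mask))  # [1 3 4]
```
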
joltax-0.1.1/joltax.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,74 @@
Metadata-Version: 2.4
Name: joltax
Version: 0.1.1
Summary: A high-performance, vectorized taxonomy library for Python.
Author-email: Daniel Svensson <daniel.svensson@umu.se>
License: MIT
Project-URL: Homepage, https://github.com/SweBiTS/JolTax
Project-URL: Bug Tracker, https://github.com/SweBiTS/JolTax/issues
Project-URL: Source Code, https://github.com/SweBiTS/JolTax
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy>=1.20.0
Requires-Dist: polars>=0.20.0
Requires-Dist: rapidfuzz>=3.0.0
Dynamic: license-file

<p align="center">
  <img src="assets/logo.png" alt="joltax logo" width="300">
</p>

# joltax

**High-performance, vectorized taxonomy library for Python.**

`JolTax` is a Python library designed to handle massive taxonomies with extreme efficiency. By representing taxonomy trees as contiguous NumPy arrays and leveraging Polars for mass data handling, it achieves lightning-fast traversals, constant-time clade queries, and rapid mass annotation of large datasets.

## Key Features
- **Vectorized Performance:** Uses hardware-accelerated NumPy operations for million-scale property lookups.
- **Memory Efficient:** Optimized string store using Polars/Arrow reduces RAM footprint.
- **Fuzzy Name Search:** Rapid fuzzy matching using RapidFuzz to find TaxIDs from names.
- **Instant Clade Queries:** Quickly find all descendants of any node (even millions) using optimized range indexing.
- **Hyper-Vectorized LCA search:** Lowest Common Ancestor (LCA) search and node-to-node distance calculations at lightning speeds.
- **Mass Annotation:** Annotate massive TaxID tables with 2,000,000+ rows in under a second using Polars.

## Quick Start

```python
from joltax.joltree import JolTree

# Build and process the NCBI taxonomy
tree = JolTree(nodes_file='nodes.dmp', names_file='names.dmp')

# Save for instant loading next time
tree.save('my_taxonomy_cache')

# Re-load in milliseconds (using zero-copy Arrow IPC)
tree = JolTree.load('my_taxonomy_cache')

# Batch LCA (process 10,000 pairs in <10ms)
lcas = tree.get_lca_batch(ids1, ids2)

# Fuzzy search for a name (returns a Polars DataFrame)
results = tree.search_name('Escherchia', fuzzy=True)
print(results)
```

## Installation

```bash
cd joltax
pip install .
```

Requires: `numpy`, `polars`, `rapidfuzz`.

## Documentation

For a detailed API reference and a comprehensive "How-To" guide with example workflows, please see [USAGE.md](./USAGE.md).

joltax-0.1.1/joltax.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,11 @@
LICENSE
README.md
pyproject.toml
joltax/__init__.py
joltax/joltree.py
joltax.egg-info/PKG-INFO
joltax.egg-info/SOURCES.txt
joltax.egg-info/dependency_links.txt
joltax.egg-info/requires.txt
joltax.egg-info/top_level.txt
tests/test_tree.py

joltax-0.1.1/joltax.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@


joltax-0.1.1/joltax.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
joltax

joltax-0.1.1/pyproject.toml
ADDED
@@ -0,0 +1,34 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "joltax"
version = "0.1.1"
authors = [
    { name="Daniel Svensson", email="daniel.svensson@umu.se" },
]
description = "A high-performance, vectorized taxonomy library for Python."
readme = "README.md"
requires-python = ">=3.8"
license = {text = "MIT"}
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Intended Audience :: Science/Research",
    "Topic :: Scientific/Engineering :: Bio-Informatics",
]
dependencies = [
    "numpy>=1.20.0",
    "polars>=0.20.0",
    "rapidfuzz>=3.0.0",
]

[project.urls]
"Homepage" = "https://github.com/SweBiTS/JolTax"
"Bug Tracker" = "https://github.com/SweBiTS/JolTax/issues"
"Source Code" = "https://github.com/SweBiTS/JolTax"

[tool.setuptools]
packages = ["joltax"]

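Once the project described by this pyproject.toml is installed (for example with `pip install .`), the declared version and dependency pins can be read back through the standard library; this is generic Python packaging machinery rather than a joltax-specific API:

```python
# Standard-library inspection of the installed distribution (Python >= 3.8).
from importlib.metadata import version, requires

print(version("joltax"))   # expected: 0.1.1
print(requires("joltax"))  # ['numpy>=1.20.0', 'polars>=0.20.0', 'rapidfuzz>=3.0.0']
```
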
joltax-0.1.1/setup.cfg
ADDED

joltax-0.1.1/tests/test_tree.py
ADDED
@@ -0,0 +1,142 @@
import unittest
import os
import sys
import numpy as np
import polars as pl

# Add the project root to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from joltax.joltree import JolTree

class TestJolTree(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.names_file = 'tests/data/names.dmp'
        cls.nodes_file = 'tests/data/nodes.dmp'
        # Check if files exist, if not, create them (should be copied already)
        if not os.path.exists(cls.names_file):
            raise FileNotFoundError(f"Missing test data: {cls.names_file}")

        cls.tree = JolTree(nodes_file=cls.nodes_file, names_file=cls.names_file)

    def test_lineage(self):
        # 562 (E. coli) -> 561 (Escherichia) -> 543 -> 91347 -> 1236 -> 1224 -> 2 -> 1
        lineage = self.tree.get_lineage(562)
        expected = [1, 2, 1224, 1236, 91347, 543, 561, 562]
        self.assertEqual(lineage, expected)

    def test_clade(self):
        # Clade of 561 (genus) should contain 561 and 562 (species)
        clade = self.tree.get_clade(561)
        self.assertIn(561, clade)
        self.assertIn(562, clade)
        self.assertEqual(len(clade), 2)

    def test_lca(self):
        # LCA of 562 and 561 is 561
        lca = self.tree.get_lca(562, 561)
        self.assertEqual(lca, 561)

        # LCA of 562 and 2 (Bacteria) is 2
        lca = self.tree.get_lca(562, 2)
        self.assertEqual(lca, 2)

    def test_distance(self):
        # 562 to 561 is 1 step
        self.assertEqual(self.tree.get_distance(562, 561), 1)
        # 562 to 2 is 6 steps
        self.assertEqual(self.tree.get_distance(562, 2), 6)

    def test_get_name_and_rank(self):
        self.assertEqual(self.tree.get_name(562), 'Escherichia coli')
        self.assertEqual(self.tree.get_rank(562), 'species')
        self.assertEqual(self.tree.get_name(2), 'Bacteria')
        self.assertEqual(self.tree.get_rank(2), 'superkingdom')
        # Test unknown
        self.assertEqual(self.tree.get_name(999999), 'Unknown_999999')
        self.assertEqual(self.tree.get_rank(999999), 'unknown')

    def test_annotate_table(self):
        tax_ids = [562, 561, 2]
        df = self.tree.annotate_table(tax_ids)
        self.assertIsInstance(df, pl.DataFrame)
        self.assertEqual(len(df), 3)
        self.assertIn('species', df.columns)
        self.assertIn('genus', df.columns)

        # Check first row (562)
        row0 = df.row(0, named=True)
        self.assertEqual(row0['species'], 'Escherichia coli')
        self.assertEqual(row0['genus'], 'Escherichia')
        self.assertEqual(row0['scientific_name'], 'Escherichia coli')

    def test_name_search(self):
        # Search by scientific name
        df = self.tree.search_name('Escherichia coli')
        self.assertIn(562, df['tax_id'].to_list())

        # Search by common name
        df = self.tree.search_name('all')
        self.assertIn(1, df['tax_id'].to_list())

    def test_fuzzy_search(self):
        # Typo: "Escherchia"
        df = self.tree.search_name('Escherchia', fuzzy=True)
        self.assertIsInstance(df, pl.DataFrame)
        self.assertTrue(len(df) > 0)
        # Top result should be Escherichia or Escherichia coli
        top_name = df.row(0, named=True)['matched_name']
        self.assertIn('Escherichia', top_name)

    def test_save_load(self):
        import shutil
        cache_dir = 'tests/cache_test'
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)

        self.tree.save(cache_dir)
        new_tree = JolTree.load(cache_dir)

        self.assertEqual(new_tree.get_lineage(562), self.tree.get_lineage(562))
        # Check name index loaded
        df = new_tree.search_name('Escherichia coli')
        self.assertIn(562, df['tax_id'].to_list())

        shutil.rmtree(cache_dir)

    def test_version_validation(self):
        import shutil
        import pickle
        cache_dir = 'tests/version_test'
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)

        self.tree.save(cache_dir)

        # Manually corrupt metadata with old version
        meta_path = os.path.join(cache_dir, "metadata.pkl")
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)

        meta["provenance"]["package_version"] = "0.0.1"  # Older than 0.1.0

        with open(meta_path, 'wb') as f:
            pickle.dump(meta, f)

        # Should raise RuntimeError
        with self.assertRaises(RuntimeError) as cm:
            JolTree.load(cache_dir)

        self.assertIn("Incompatible taxonomy cache", str(cm.exception))
        shutil.rmtree(cache_dir)

if __name__ == '__main__':
    # We skip tests if dependencies aren't installed
    try:
        import numpy
        import polars
        import rapidfuzz
        unittest.main()
    except ImportError:
        print("Skipping tests due to missing dependencies.")

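The tests above exercise the scalar `get_lca` and `get_distance` calls but not their vectorized counterparts. A short usage sketch of `get_lca_batch` and `get_distance_batch` from `joltax/joltree.py`, assuming a tree built or loaded as in the tests and reusing the same illustrative TaxIDs:

```python
import numpy as np
from joltax.joltree import JolTree

# Assumes a cache directory produced by tree.save(), as in test_save_load.
tree = JolTree.load('my_taxonomy_cache')

ids1 = np.array([562, 562, 561], dtype=np.int32)  # E. coli, E. coli, Escherichia
ids2 = np.array([561, 2, 2], dtype=np.int32)      # Escherichia, Bacteria, Bacteria

print(tree.get_lca_batch(ids1, ids2))       # per-pair lowest common ancestors
print(tree.get_distance_batch(ids1, ids2))  # per-pair edge distances
```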