joltax 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
joltax-0.1.1/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Swedish Biodiversity in Time and Space (SweBiTS)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
joltax-0.1.1/PKG-INFO ADDED
@@ -0,0 +1,74 @@
1
+ Metadata-Version: 2.4
2
+ Name: joltax
3
+ Version: 0.1.1
4
+ Summary: A high-performance, vectorized taxonomy library for Python.
5
+ Author-email: Daniel Svensson <daniel.svensson@umu.se>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/SweBiTS/JolTax
8
+ Project-URL: Bug Tracker, https://github.com/SweBiTS/JolTax/issues
9
+ Project-URL: Source Code, https://github.com/SweBiTS/JolTax
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Operating System :: OS Independent
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
15
+ Requires-Python: >=3.8
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: numpy>=1.20.0
19
+ Requires-Dist: polars>=0.20.0
20
+ Requires-Dist: rapidfuzz>=3.0.0
21
+ Dynamic: license-file
22
+
23
+ <p align="center">
24
+ <img src="assets/logo.png" alt="joltax logo" width="300">
25
+ </p>
26
+
27
+ # joltax
28
+
29
+ **High-performance, vectorized taxonomy library for Python.**
30
+
31
+ `JolTax` is a Python library designed to handle massive taxonomies with extreme efficiency. By representing taxonomy trees as contiguous NumPy arrays and leveraging Polars for mass data handling, it achieves lightning-fast traversals, constant-time clade queries, and rapid mass annotation of large datasets.
32
+
33
+ ## Key Features
34
+ - **Vectorized Performance:** Uses hardware-accelerated NumPy operations for million-scale property lookups.
35
+ - **Memory Efficient:** Optimized string store using Polars/Arrow reduces RAM footprint.
36
+ - **Fuzzy Name Search:** Rapid fuzzy matching using RapidFuzz to find TaxIDs from names.
37
+ - **Instant Clade Queries:** Quickly find all descendants of any node (even millions) using optimized range indexing.
38
+ - **Hyper-Vectorized LCA search:** Lowest Common Ancestor (LCA) search and node-to-node distance calculations at lightning speeds.
39
+ - **Mass Annotation:** Annotate massive TaxID tables with 2,000,000+ rows in under a second using Polars.
40
+
41
+ ## Quick Start
42
+
43
+ ```python
44
+ from joltax.joltree import JolTree
45
+
46
+ # Build and process the NCBI taxonomy
47
+ tree = JolTree(nodes_file='nodes.dmp', names_file='names.dmp')
48
+
49
+ # Save for instant loading next time
50
+ tree.save('my_taxonomy_cache')
51
+
52
+ # Re-load in milliseconds (using zero-copy Arrow IPC)
53
+ tree = JolTree.load('my_taxonomy_cache')
54
+
55
+ # Batch LCA (process 10,000 pairs in <10ms)
56
+ lcas = tree.get_lca_batch(ids1, ids2)
57
+
58
+ # Fuzzy search for a name (returns a Polars DataFrame)
59
+ results = tree.search_name('Escherchia', fuzzy=True)
60
+ print(results)
61
+ ```
62
+
63
+ ## Installation
64
+
65
+ ```bash
66
+ cd joltax
67
+ pip install .
68
+ ```
69
+
70
+ Requires: `numpy`, `polars`, `rapidfuzz`.
71
+
72
+ ## Documentation
73
+
74
+ For a detailed API reference and a comprehensive "How-To" guide with example workflows, please see [USAGE.md](./USAGE.md).
joltax-0.1.1/README.md ADDED
@@ -0,0 +1,52 @@
1
+ <p align="center">
2
+ <img src="assets/logo.png" alt="joltax logo" width="300">
3
+ </p>
4
+
5
+ # joltax
6
+
7
+ **High-performance, vectorized taxonomy library for Python.**
8
+
9
+ `JolTax` is a Python library designed to handle massive taxonomies with extreme efficiency. By representing taxonomy trees as contiguous NumPy arrays and leveraging Polars for mass data handling, it achieves lightning-fast traversals, constant-time clade queries, and rapid mass annotation of large datasets.
10
+
11
+ ## Key Features
12
+ - **Vectorized Performance:** Uses hardware-accelerated NumPy operations for million-scale property lookups.
13
+ - **Memory Efficient:** Optimized string store using Polars/Arrow reduces RAM footprint.
14
+ - **Fuzzy Name Search:** Rapid fuzzy matching using RapidFuzz to find TaxIDs from names.
15
+ - **Instant Clade Queries:** Quickly find all descendants of any node (even millions) using optimized range indexing.
16
+ - **Hyper-Vectorized LCA search:** Lowest Common Ancestor (LCA) search and node-to-node distance calculations at lightning speeds.
17
+ - **Mass Annotation:** Annotate massive TaxID tables with 2,000,000+ rows in under a second using Polars.
18
+
19
+ ## Quick Start
20
+
21
+ ```python
22
+ from joltax.joltree import JolTree
23
+
24
+ # Build and process the NCBI taxonomy
25
+ tree = JolTree(nodes_file='nodes.dmp', names_file='names.dmp')
26
+
27
+ # Save for instant loading next time
28
+ tree.save('my_taxonomy_cache')
29
+
30
+ # Re-load in milliseconds (using zero-copy Arrow IPC)
31
+ tree = JolTree.load('my_taxonomy_cache')
32
+
33
+ # Batch LCA (process 10,000 pairs in <10ms)
34
+ lcas = tree.get_lca_batch(ids1, ids2)
35
+
36
+ # Fuzzy search for a name (returns a Polars DataFrame)
37
+ results = tree.search_name('Escherchia', fuzzy=True)
38
+ print(results)
39
+ ```
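+
+ The same `JolTree` object also exposes clade and mass-annotation helpers. A minimal sketch (the TaxIDs below are illustrative placeholders, and `tree` is the object built or loaded above):
+
+ ```python
+ # Clade queries: all descendants of a node, optionally filtered by rank
+ descendants = tree.get_clade(561)
+ species_only = tree.get_clade_at_rank(561, 'species')
+
+ # Mass annotation: one Polars DataFrame row per input TaxID,
+ # with canonical rank columns plus scientific_name and rank
+ table = tree.annotate_table([562, 561, 2])
+ print(table)
+ ```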
40
+
41
+ ## Installation
42
+
43
+ ```bash
44
+ cd joltax
45
+ pip install .
46
+ ```
47
+
48
+ Requires: `numpy`, `polars`, `rapidfuzz`.
49
+
50
+ ## Documentation
51
+
52
+ For a detailed API reference and a comprehensive "How-To" guide with example workflows, please see [USAGE.md](./USAGE.md).
joltax-0.1.1/joltax/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .joltree import JolTree
joltax-0.1.1/joltax/joltree.py ADDED
@@ -0,0 +1,734 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ joltax/joltree.py
4
+ Implementation of a high-performance, vectorized taxonomy tree.
5
+ """
6
+
7
+ __version__ = "0.1.1"
8
+
9
+ # The minimum version of a saved taxonomy cache that is compatible with this software.
10
+ # Increment this when making breaking changes to the binary layout or metadata structure.
11
+ MINIMUM_CACHE_VERSION = "0.1.1"
12
+
13
+ import logging
14
+ import os
15
+ import datetime
16
+ from typing import Dict, List, Optional, Set, Union, Tuple
17
+ from collections import namedtuple
18
+
19
+ import numpy as np
20
+ import polars as pl
21
+ from rapidfuzz import process, fuzz, utils
22
+
23
+ # Set up logging for the module
24
+ logging.basicConfig(
25
+ format='%(asctime)s %(levelname)-8s %(message)s',
26
+ level=logging.INFO,
27
+ datefmt='%Y-%m-%d [%H:%M:%S]'
28
+ )
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # Standard canonical ranks in order (highest to lowest)
32
+ # Including both superkingdom and domain for compatibility with pre/post-2025 taxonomies
33
+ CANONICAL_RANKS = [
34
+ 'superkingdom', 'domain', 'kingdom', 'phylum',
35
+ 'class', 'order', 'family', 'genus', 'species'
36
+ ]
37
+
38
+ # Mapping rank names to standard Kraken-style codes
39
+ RANK_TO_CODE = {
40
+ 'superkingdom': 'D',
41
+ 'domain': 'D',
42
+ 'kingdom': 'K',
43
+ 'phylum': 'P',
44
+ 'class': 'C',
45
+ 'order': 'O',
46
+ 'family': 'F',
47
+ 'genus': 'G',
48
+ 'species': 'S'
49
+ }
50
+
51
+ class JolTree:
52
+ """
53
+ A high-performance taxonomy representation using vectorized arrays.
54
+
55
+ This class replaces traditional object-oriented trees with contiguous
56
+ NumPy arrays for lightning-fast lookups, traversals, and mass annotations.
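+
+ TaxIDs are mapped to a dense internal index (their position in the sorted
+ TaxID array); parents, depths, ranks and the Euler-tour timestamps are all
+ NumPy arrays indexed by that dense index rather than by raw TaxID.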
57
+ """
58
+
59
+ def __init__(self, nodes_file: Optional[str] = None, names_file: Optional[str] = None):
60
+ """
61
+ Initialize the taxonomy tree. If files are provided, it builds from DMP files.
62
+ Otherwise, it can be loaded from a binary cache using `load()`.
63
+
64
+ Args:
65
+ nodes_file: Path to NCBI nodes.dmp
66
+ names_file: Path to NCBI names.dmp
67
+ """
68
+ # Vectorized internal index mapping (sorted array of TaxIDs)
69
+ self._index_to_id: np.ndarray = np.array([], dtype=np.int32)
70
+
71
+ # Primary arrays (indexed by the dense internal index)
72
+ self.parents: np.ndarray = np.array([], dtype=np.int32)
73
+ self.depths: np.ndarray = np.array([], dtype=np.int32)
74
+ self.ranks: np.ndarray = np.array([], dtype=np.uint8)
75
+
76
+ # Metadata storage (Polars Series for memory efficiency)
77
+ self._scientific_names: pl.Series = pl.Series("scientific_name", [], dtype=pl.String)
78
+ self._common_names: pl.Series = pl.Series("common_name", [], dtype=pl.String)
79
+ self.rank_names: List[str] = []
80
+ self.top_rank: str = "domain" # Default, will be detected
81
+ self._source_nodes: Optional[str] = None
82
+ self._source_names: Optional[str] = None
83
+ self._build_time: Optional[str] = None
84
+
85
+ # Clade query support (Euler Tour timestamps)
86
+ self.entry_times: np.ndarray = np.array([], dtype=np.int32)
87
+ self.exit_times: np.ndarray = np.array([], dtype=np.int32)
88
+
89
+ # Binary lifting table for LCA (initialized on demand)
90
+ self._up_table: Optional[np.ndarray] = None
91
+
92
+ # Pre-calculated canonical rank maps (dense internal index -> dense internal index)
93
+ # Values are internal indices, not TaxIDs. -1 means no ancestor at that rank.
94
+ self.canonical_maps: Dict[str, np.ndarray] = {}
95
+
96
+ # Search index (Polars DataFrame: name -> tax_id)
97
+ self._search_index: pl.DataFrame = pl.DataFrame(schema={"name": pl.String, "tax_id": pl.Int32})
98
+
99
+ # Caches for vectorized lookup (prepared during build/load)
100
+ self._sci_names_lookup: Optional[pl.Series] = None
101
+ self._rank_names_series: Optional[pl.Series] = None
102
+ self._ranks_extended: Optional[np.ndarray] = None
103
+
104
+ if nodes_file and names_file:
105
+ self.build_from_dmp(nodes_file, names_file)
106
+
107
+ def build_from_dmp(self, nodes_file: str, names_file: str) -> None:
108
+ """
109
+ Parses NCBI DMP files and builds the vectorized internal structure.
110
+
111
+ Args:
112
+ nodes_file: Path to NCBI nodes.dmp
113
+ names_file: Path to NCBI names.dmp
114
+ """
115
+ self._source_nodes = os.path.abspath(nodes_file)
116
+ self._source_names = os.path.abspath(names_file)
117
+ self._build_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
118
+ logger.info(f"Starting taxonomy build at {self._build_time}...")
119
+
120
+ # 1. Parse Names
121
+ logger.info(f"Parsing names from {names_file}...")
122
+ scientific_names = {}
123
+ common_names = {}
124
+ search_data = [] # List of (name, tax_id)
125
+
126
+ with open(names_file, 'r') as f:
127
+ for name_line in f:
128
+ parts = name_line.split('|')
129
+ name_type = parts[3].strip()
130
+
131
+ # Only care about scientific and genbank common names for now
132
+ if name_type not in ['scientific name', 'genbank common name']:
133
+ continue
134
+
135
+ tax_id = int(parts[0].strip())
136
+ name_txt = parts[1].strip()
137
+
138
+ if name_type == 'scientific name':
139
+ scientific_names[tax_id] = name_txt
140
+ elif name_type == 'genbank common name':
141
+ common_names[tax_id] = name_txt
142
+
143
+ search_data.append({"name": name_txt, "tax_id": tax_id})
144
+
145
+ # Build search index
146
+ self._search_index = pl.DataFrame(search_data).sort("name")
147
+
148
+ # 2. Parse Nodes and initial parent structure
149
+ logger.info(f"Parsing nodes from {nodes_file}...")
150
+ temp_parents = {}
151
+ temp_ranks = {}
152
+ all_ranks = set()
153
+
154
+ with open(nodes_file, 'r') as f:
155
+ for line in f:
156
+ parts = line.split('|')
157
+ tax_id = int(parts[0].strip())
158
+ parent_id = int(parts[1].strip())
159
+ rank = parts[2].strip()
160
+
161
+ temp_parents[tax_id] = parent_id
162
+ temp_ranks[tax_id] = rank
163
+ all_ranks.add(rank)
164
+
165
+ # 2.1 Detect top rank (superkingdom vs domain)
166
+ has_sk = 'superkingdom' in all_ranks
167
+ has_dm = 'domain' in all_ranks
168
+ if has_sk and has_dm:
169
+ raise ValueError("Found both 'superkingdom' and 'domain' ranks. The taxonomy must use only one as the top rank.")
170
+ self.top_rank = 'superkingdom' if has_sk else 'domain'
171
+ logger.info(f"Detected top rank: {self.top_rank}")
172
+
173
+ # 3. Create dense mapping
174
+ logger.info("Creating dense mapping and vectorized arrays...")
175
+ sorted_tax_ids = sorted(temp_parents.keys())
176
+ num_nodes = len(sorted_tax_ids)
177
+ self._index_to_id = np.array(sorted_tax_ids, dtype=np.int32)
178
+
179
+ # Mapping rank names to indices
180
+ self.rank_names = sorted(list(all_ranks))
181
+ rank_to_idx = {r: i for i, r in enumerate(self.rank_names)}
182
+
183
+ self.parents = np.zeros(num_nodes, dtype=np.int32)
184
+ self.ranks = np.zeros(num_nodes, dtype=np.uint8)
185
+
186
+ # Temporary dict for building parent connections (will be discarded)
187
+ id_to_index_temp = {tid: i for i, tid in enumerate(sorted_tax_ids)}
188
+
189
+ for tid, i in id_to_index_temp.items():
190
+ parent_id = temp_parents[tid]
191
+ # Handle root (1) which is its own parent in NCBI
192
+ if tid == 1:
193
+ self.parents[i] = i
194
+ else:
195
+ self.parents[i] = id_to_index_temp[parent_id]
196
+
197
+ self.ranks[i] = rank_to_idx[temp_ranks[tid]]
198
+
199
+ # Populate names aligned with indices
200
+ logger.info("Aligning names and ranks...")
201
+ sci_names_list = [scientific_names.get(tid, f"Unknown_{tid}") for tid in sorted_tax_ids]
202
+ com_names_list = [common_names.get(tid) for tid in sorted_tax_ids]
203
+ self._scientific_names = pl.Series("scientific_name", sci_names_list)
204
+ self._common_names = pl.Series("common_name", com_names_list)
205
+
206
+ # 4. Calculate depths
207
+ logger.info("Calculating node depths...")
208
+ self.depths = np.zeros(num_nodes, dtype=np.int32)
209
+ for i in range(num_nodes):
210
+ self._calculate_depth(i)
211
+
212
+ # 5. Build Euler Tour for clade queries
213
+ self._build_euler_tour()
214
+
215
+ # 6. Pre-calculate canonical rank maps
216
+ self._build_canonical_maps()
217
+
218
+ # 7. Prepare caches for vectorized lookups
219
+ self._prepare_vectorized_caches()
220
+
221
+ logger.info("Taxonomy build complete.")
222
+
223
+ def _prepare_vectorized_caches(self) -> None:
224
+ """Initializes caches used for high-performance vectorized lookups."""
225
+ logger.info("Preparing vectorized lookup caches...")
226
+ # Scientific names lookup: aligned with the dense internal index, with one trailing null entry for unknown TaxIDs
227
+ self._sci_names_lookup = self._scientific_names.append(pl.Series([None]))
228
+
229
+ # Rank names lookup
230
+ self._rank_names_series = pl.Series(self.rank_names).append(pl.Series(["unclassified"]))
231
+
232
+ # Ranks extended with a pointer to "unclassified" for unknown nodes
233
+ self._ranks_extended = np.append(self.ranks, [len(self.rank_names)]).astype(np.int32)
234
+
235
+ def _build_canonical_maps(self) -> None:
236
+ """Pre-calculates canonical rank ancestors for all nodes."""
237
+ logger.info("Pre-calculating canonical rank maps...")
238
+ num_nodes = len(self._index_to_id)
239
+
240
+ # Identify all canonical ranks to track
241
+ canonical_columns = [self.top_rank] + [r for r in CANONICAL_RANKS if r not in ['superkingdom', 'domain']]
242
+
243
+ # Initialize maps with -1 (meaning no ancestor at that rank)
244
+ self.canonical_maps = {rank: np.full(num_nodes, -1, dtype=np.int32) for rank in canonical_columns}
245
+
246
+ # For each node, walk up its ancestor chain and record the ancestor's internal index at each canonical rank
247
+ for i in range(num_nodes):
248
+ curr_idx = i
249
+ root_idx = 0 # TaxID 1 is always the first in sorted_tax_ids
250
+ while True:
251
+ rank_name = self.rank_names[self.ranks[curr_idx]]
252
+
253
+ # Normalize superkingdom/domain based on detected top_rank
254
+ mapped_rank = rank_name
255
+ if rank_name in ['superkingdom', 'domain']:
256
+ mapped_rank = self.top_rank
257
+
258
+ if mapped_rank in self.canonical_maps:
259
+ self.canonical_maps[mapped_rank][i] = curr_idx
260
+
261
+ if curr_idx == root_idx:
262
+ break
263
+ curr_idx = self.parents[curr_idx]
264
+
265
+ def _calculate_depth(self, index: int) -> int:
266
+ """Recursive depth calculation with memoization."""
267
+ if index == 0: # TaxID 1 is always index 0
268
+ return 0
269
+ if self.depths[index] != 0:
270
+ return self.depths[index]
271
+
272
+ d = self._calculate_depth(self.parents[index]) + 1
273
+ self.depths[index] = d
274
+ return d
275
+
276
+ def _build_euler_tour(self) -> None:
277
+ """Assigns entry/exit times to enable instant clade queries."""
278
+ logger.info("Building Euler Tour index for clade queries...")
279
+ num_nodes = len(self._index_to_id)
280
+ self.entry_times = np.zeros(num_nodes, dtype=np.int32)
281
+ self.exit_times = np.zeros(num_nodes, dtype=np.int32)
282
+
283
+ # Build adjacency list (children)
284
+ children = [[] for _ in range(num_nodes)]
285
+ root_idx = 0 # TaxID 1
286
+ for i, p in enumerate(self.parents):
287
+ if i != root_idx:
288
+ children[p].append(i)
289
+
290
+ timer = 0
291
+ stack = [(root_idx, False)] # (index, is_processed)
292
+
293
+ while stack:
294
+ idx, processed = stack.pop()
295
+ if not processed:
296
+ self.entry_times[idx] = timer
297
+ timer += 1
298
+ stack.append((idx, True))
299
+ for child in reversed(children[idx]):
300
+ stack.append((child, False))
301
+ else:
302
+ self.exit_times[idx] = timer - 1
303
+
304
+ def _get_index(self, tax_id: int) -> int:
305
+ """Returns the internal index for a TaxID, or -1 if not found."""
306
+ idx = np.searchsorted(self._index_to_id, tax_id)
307
+ if idx < len(self._index_to_id) and self._index_to_id[idx] == tax_id:
308
+ return int(idx)
309
+ return -1
310
+
311
+ def _get_indices(self, tax_ids: np.ndarray) -> np.ndarray:
312
+ """Returns internal indices for an array of TaxIDs, with -1 for missing."""
313
+ indices = np.searchsorted(self._index_to_id, tax_ids)
314
+ # Handle out of bounds
315
+ mask = indices < len(self._index_to_id)
316
+ # Check for actual equality
317
+ valid = np.zeros(len(tax_ids), dtype=bool)
318
+ valid[mask] = self._index_to_id[indices[mask]] == tax_ids[mask]
319
+ return np.where(valid, indices, -1)
320
+
321
+ def get_lineage(self, tax_id: int) -> List[int]:
322
+ """Returns the full lineage from root to the given TaxID."""
323
+ idx = self._get_index(tax_id)
324
+ if idx == -1:
325
+ logger.warning(f"TaxID {tax_id} not found in taxonomy tree.")
326
+ return []
327
+
328
+ lineage = []
329
+ root_idx = 0
330
+
331
+ while True:
332
+ lineage.append(int(self._index_to_id[idx]))
333
+ if idx == root_idx:
334
+ break
335
+ idx = self.parents[idx]
336
+
337
+ return lineage[::-1]
338
+
339
+ def get_name(self, tax_id: int) -> str:
340
+ """Returns the scientific name of the given TaxID."""
341
+ idx = self._get_index(tax_id)
342
+ if idx != -1:
343
+ return self._scientific_names[idx]
344
+ return f"Unknown_{tax_id}"
345
+
346
+ def get_common_name(self, tax_id: int) -> Optional[str]:
347
+ """Returns the genbank common name of the given TaxID, if available."""
348
+ idx = self._get_index(tax_id)
349
+ if idx != -1:
350
+ return self._common_names[idx]
351
+ return None
352
+
353
+ def get_rank(self, tax_id: int) -> str:
354
+ """Returns the taxonomic rank of the given TaxID."""
355
+ idx = self._get_index(tax_id)
356
+ if idx == -1:
357
+ logger.warning(f"TaxID {tax_id} not found in taxonomy tree.")
358
+ return "unknown"
359
+ return self.rank_names[self.ranks[idx]]
360
+
361
+ def search_name(self, query: str, fuzzy: bool = False, limit: int = 10, score_cutoff: float = 60.0) -> pl.DataFrame:
362
+ """
363
+ Searches for TaxIDs by name.
364
+ """
365
+ if not fuzzy:
366
+ matches = self._search_index.filter(pl.col("name") == query)
367
+ if matches.is_empty():
368
+ return pl.DataFrame(schema=["tax_id", "name", "rank", "score"])
369
+
370
+ # Vectorized rank lookup for matches
371
+ tids = matches["tax_id"].to_numpy()
372
+ indices = self._get_indices(tids)
373
+ ranks = [self.rank_names[self.ranks[i]] if i != -1 else "unknown" for i in indices]
374
+
375
+ return matches.with_columns([
376
+ pl.Series("rank", ranks),
377
+ pl.lit(100.0).alias("score")
378
+ ])
379
+
380
+ # Fuzzy matching path
381
+ unique_names = self._search_index["name"].unique().to_list()
382
+
383
+ # rapidfuzz extract returns (matched_name, score, index) tuples above the score cutoff
384
+ matches = process.extract(
385
+ query,
386
+ unique_names,
387
+ scorer=fuzz.WRatio,
388
+ limit=limit,
389
+ processor=utils.default_process,
390
+ score_cutoff=score_cutoff
391
+ )
392
+
393
+ data = []
394
+ for match_str, score, _ in matches:
395
+ # Find all TaxIDs associated with this name
396
+ tids = self._search_index.filter(pl.col("name") == match_str)["tax_id"].to_list()
397
+ for tid in tids:
398
+ idx = self._get_index(tid)
399
+ rank = self.rank_names[self.ranks[idx]] if idx != -1 else "unknown"
400
+
401
+ # Smart Ranking: Boost scores for canonical ranks
402
+ rank_boost = 0.0
403
+ if rank in CANONICAL_RANKS:
404
+ rank_boost = 2.0
405
+
406
+ data.append({
407
+ "tax_id": tid,
408
+ "matched_name": match_str,
409
+ "scientific_name": self.get_name(tid),
410
+ "rank": rank,
411
+ "score": score + rank_boost
412
+ })
413
+
414
+ if not data:
415
+ return pl.DataFrame(schema=["tax_id", "matched_name", "scientific_name", "rank", "score"])
416
+
417
+ return pl.DataFrame(data).sort("score", descending=True)
418
+
419
+ def get_clade(self, tax_id: int) -> List[int]:
420
+ """Returns all TaxIDs in the clade rooted at the given TaxID."""
421
+ idx = self._get_index(tax_id)
422
+ if idx == -1:
423
+ logger.warning(f"TaxID {tax_id} not found in taxonomy tree.")
424
+ return []
425
+
426
+ entry = self.entry_times[idx]
427
+ exit = self.exit_times[idx]
428
+
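+ # Descendants are exactly the nodes whose entry time falls in [entry, exit]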
429
+ mask = (self.entry_times >= entry) & (self.entry_times <= exit)
430
+ return self._index_to_id[mask].astype(int).tolist()
431
+
432
+ def get_clade_at_rank(self, tax_id: int, rank_name: str) -> List[int]:
433
+ """
434
+ Returns all TaxIDs of a specific rank within the clade rooted at tax_id.
435
+ """
436
+ idx = self._get_index(tax_id)
437
+ if idx == -1:
438
+ logger.warning(f"TaxID {tax_id} not found in taxonomy tree.")
439
+ return []
440
+
441
+ try:
442
+ target_rank_idx = self.rank_names.index(rank_name)
443
+ except ValueError:
444
+ logger.warning(f"Rank '{rank_name}' not found in taxonomy. Available ranks: {self.rank_names}")
445
+ return []
446
+
447
+ entry = self.entry_times[idx]
448
+ exit = self.exit_times[idx]
449
+
450
+ mask = (self.entry_times >= entry) & (self.entry_times <= exit) & (self.ranks == target_rank_idx)
451
+ return self._index_to_id[mask].astype(int).tolist()
452
+
453
+ def get_lca(self, tax_id_1: int, tax_id_2: int) -> int:
454
+ """Finds the Lowest Common Ancestor using Binary Lifting."""
455
+ idx1 = self._get_index(tax_id_1)
456
+ idx2 = self._get_index(tax_id_2)
457
+
458
+ if idx1 == -1 or idx2 == -1:
459
+ logger.warning(f"One or both TaxIDs ({tax_id_1}, {tax_id_2}) not found.")
460
+ return 1
461
+
462
+ self._ensure_up_table()
463
+
464
+ if self.depths[idx1] < self.depths[idx2]:
465
+ idx1, idx2 = idx2, idx1
466
+
467
+ diff = self.depths[idx1] - self.depths[idx2]
468
+ max_log = self._up_table.shape[0]
469
+
470
+ for i in range(max_log):
471
+ if (diff >> i) & 1:
472
+ idx1 = self._up_table[i, idx1]
473
+
474
+ if idx1 == idx2:
475
+ return int(self._index_to_id[idx1])
476
+
477
+ for i in reversed(range(max_log)):
478
+ up1 = self._up_table[i, idx1]
479
+ up2 = self._up_table[i, idx2]
480
+ if up1 != up2:
481
+ idx1 = up1
482
+ idx2 = up2
483
+
484
+ return int(self._index_to_id[self.parents[idx1]])
485
+
486
+ def get_distance(self, tax_id_1: int, tax_id_2: int) -> int:
487
+ """Calculates distance (number of edges) between two TaxIDs."""
488
+ lca_id = self.get_lca(tax_id_1, tax_id_2)
489
+ idx1 = self._get_index(tax_id_1)
490
+ idx2 = self._get_index(tax_id_2)
491
+ idx_lca = self._get_index(lca_id)
492
+ return int(self.depths[idx1] + self.depths[idx2] - 2 * self.depths[idx_lca])
493
+
494
+ def get_lca_batch(self, ids1: Union[List[int], np.ndarray], ids2: Union[List[int], np.ndarray]) -> np.ndarray:
495
+ """
496
+ Calculates Lowest Common Ancestor for arrays of TaxIDs.
497
+ Hyper-vectorized implementation for peak performance.
498
+ """
499
+ ids1 = np.array(ids1, dtype=np.int32)
500
+ ids2 = np.array(ids2, dtype=np.int32)
501
+
502
+ if ids1.shape != ids2.shape:
503
+ raise ValueError("Input arrays must have the same shape.")
504
+
505
+ self._ensure_up_table()
506
+
507
+ idx1 = self._get_indices(ids1)
508
+ idx2 = self._get_indices(ids2)
509
+
510
+ # Handle missing IDs by pointing to root (index 0)
511
+ valid_mask = (idx1 != -1) & (idx2 != -1)
512
+ s_idx1 = np.where(valid_mask, idx1, 0)
513
+ s_idx2 = np.where(valid_mask, idx2, 0)
514
+
515
+ # 1. Bring both nodes to the same depth
516
+ d1 = self.depths[s_idx1]
517
+ d2 = self.depths[s_idx2]
518
+
519
+ # Ensure s_idx1 is the deeper one
520
+ swap = d1 < d2
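+ # The fancy-indexed reads on the right-hand side are copies, so this
+ # element-wise swap is safe to do in place.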
521
+ s_idx1[swap], s_idx2[swap] = s_idx2[swap], s_idx1[swap]
522
+
523
+ diff = np.abs(d1 - d2)
524
+ max_log = self._up_table.shape[0]
525
+
526
+ for i in range(max_log):
527
+ mask = (diff >> i) & 1 == 1
528
+ if np.any(mask):
529
+ s_idx1[mask] = self._up_table[i, s_idx1[mask]]
530
+
531
+ # 2. Binary search for the LCA
532
+ lca_indices = s_idx1.copy()
533
+ not_same = s_idx1 != s_idx2
534
+
535
+ if np.any(not_same):
536
+ sub1 = s_idx1[not_same]
537
+ sub2 = s_idx2[not_same]
538
+
539
+ for i in reversed(range(max_log)):
540
+ up1 = self._up_table[i, sub1]
541
+ up2 = self._up_table[i, sub2]
542
+
543
+ diff_up = up1 != up2
544
+ sub1[diff_up] = up1[diff_up]
545
+ sub2[diff_up] = up2[diff_up]
546
+
547
+ lca_indices[not_same] = self.parents[sub1]
548
+
549
+ results = self._index_to_id[lca_indices]
550
+ results[~valid_mask] = 1
551
+ return results
552
+
553
+ def get_distance_batch(self, ids1: Union[List[int], np.ndarray], ids2: Union[List[int], np.ndarray]) -> np.ndarray:
554
+ """Vectorized distance calculation for arrays of TaxIDs."""
555
+ ids1 = np.array(ids1, dtype=np.int32)
556
+ ids2 = np.array(ids2, dtype=np.int32)
557
+
558
+ lca_ids = self.get_lca_batch(ids1, ids2)
559
+
560
+ idx1 = self._get_indices(ids1)
561
+ idx2 = self._get_indices(ids2)
562
+ idx_lca = self._get_indices(lca_ids)
563
+
564
+ # Mask invalid lookups to avoid OOB errors
565
+ valid = (idx1 != -1) & (idx2 != -1) & (idx_lca != -1)
566
+
567
+ dists = np.zeros(len(ids1), dtype=np.int32)
568
+ if np.any(valid):
569
+ v1, v2, vl = idx1[valid], idx2[valid], idx_lca[valid]
570
+ dists[valid] = self.depths[v1] + self.depths[v2] - 2 * self.depths[vl]
571
+
572
+ return dists
573
+
574
+ def annotate_table(self, tax_ids: Union[List[int], np.ndarray]) -> pl.DataFrame:
575
+ """
576
+ Annotates a list of TaxIDs with scientific names and canonical rank columns.
578
+ Extremely efficient for large tables (e.g. 200k+ rows) using Polars and vectorized lookups.
578
+ """
579
+ logger.info(f"Annotating {len(tax_ids)} taxa...")
580
+ canonical_columns = [self.top_rank] + [r for r in CANONICAL_RANKS if r not in ['superkingdom', 'domain']]
581
+
582
+ tax_ids_arr = np.array(tax_ids, dtype=np.int32)
583
+ indices = self._get_indices(tax_ids_arr)
584
+ valid_mask = indices != -1
585
+
586
+ # dummy_idx points to the "Unknown/None" entry at the end of the lookup series
587
+ dummy_idx = len(self._index_to_id)
588
+ safe_indices = np.where(valid_mask, indices, dummy_idx)
589
+
590
+ # Ensure caches are ready
591
+ if self._sci_names_lookup is None:
592
+ self._prepare_vectorized_caches()
593
+
594
+ df_dict = {"tax_id": tax_ids_arr}
595
+
596
+ for rank in canonical_columns:
597
+ # canonical_maps now store internal indices
598
+ ancestor_indices = np.full(len(tax_ids_arr), -1, dtype=np.int32)
599
+ # Map input tax_ids to their ancestor's internal index
600
+ ancestor_indices[valid_mask] = self.canonical_maps[rank][indices[valid_mask]]
601
+
602
+ # Use dummy_idx for missing ancestors
603
+ safe_anc_indices = np.where(ancestor_indices != -1, ancestor_indices, dummy_idx)
604
+
605
+ # Vectorized gather from Polars
606
+ df_dict[rank] = self._sci_names_lookup.gather(safe_anc_indices.astype(np.int32))
607
+
608
+ # Scientific name for the input TaxID
609
+ df_dict["scientific_name"] = self._sci_names_lookup.gather(safe_indices.astype(np.int32))
610
+
611
+ # Rank for the input TaxID
612
+ target_rank_indices = self._ranks_extended[safe_indices]
613
+ df_dict["rank"] = self._rank_names_series.gather(target_rank_indices.astype(np.int32))
614
+
615
+ df = pl.DataFrame(df_dict)
616
+ final_order = ['tax_id'] + canonical_columns + ['scientific_name', 'rank']
617
+ return df.select(final_order)
618
+
619
+ def _ensure_up_table(self) -> None:
620
+ """Lazy initialization of binary lifting table."""
621
+ if self._up_table is not None:
622
+ return
623
+
624
+ logger.info("Initializing binary lifting table (Hyper-Vectorized)...")
625
+ num_nodes = len(self._index_to_id)
626
+ max_log = int(np.ceil(np.log2(np.max(self.depths) + 1)))
627
+
628
+ # Shape: (max_log, num_nodes) - optimized for contiguous column access
629
+ self._up_table = np.zeros((max_log, num_nodes), dtype=np.int32)
630
+
631
+ # Power 2^0 is just the parents
632
+ self._up_table[0, :] = self.parents
633
+
634
+ # The 2^j-th ancestor is the 2^(j-1)-th ancestor of the 2^(j-1)-th ancestor
635
+ # Fully vectorized initialization
636
+ for j in range(1, max_log):
637
+ prev_ancestors = self._up_table[j-1, :]
638
+ self._up_table[j, :] = self._up_table[j-1, prev_ancestors]
639
+
640
+ def save(self, directory: str) -> None:
641
+ """Saves the vectorized tree to a directory for fast loading."""
642
+ if not os.path.exists(directory):
643
+ os.makedirs(directory)
644
+
645
+ logger.info(f"Saving binary cache to {directory}...")
646
+ np.save(os.path.join(directory, "index_to_id.npy"), self._index_to_id)
647
+ np.save(os.path.join(directory, "parents.npy"), self.parents)
648
+ np.save(os.path.join(directory, "depths.npy"), self.depths)
649
+ np.save(os.path.join(directory, "ranks.npy"), self.ranks)
650
+ np.save(os.path.join(directory, "entry_times.npy"), self.entry_times)
651
+ np.save(os.path.join(directory, "exit_times.npy"), self.exit_times)
652
+
653
+ # Save Polars metadata
654
+ self._scientific_names.to_frame().write_ipc(os.path.join(directory, "scientific_names.ipc"))
655
+ self._common_names.to_frame().write_ipc(os.path.join(directory, "common_names.ipc"))
656
+ self._search_index.write_ipc(os.path.join(directory, "search_index.ipc"))
657
+
658
+ maps_dir = os.path.join(directory, "canonical_maps")
659
+ if not os.path.exists(maps_dir):
660
+ os.makedirs(maps_dir)
661
+ for rank, arr in self.canonical_maps.items():
662
+ np.save(os.path.join(maps_dir, f"{rank}.npy"), arr)
663
+
664
+ import pickle
665
+ with open(os.path.join(directory, "metadata.pkl"), 'wb') as f:
666
+ pickle.dump({
667
+ "rank_names": self.rank_names,
668
+ "top_rank": self.top_rank,
669
+ "provenance": {
670
+ "build_time": self._build_time,
671
+ "source_nodes": self._source_nodes,
672
+ "source_names": self._source_names,
673
+ "package_version": __version__,
674
+ "node_count": len(self._index_to_id),
675
+ "max_depth": int(np.max(self.depths))
676
+ }
677
+ }, f)
678
+
679
+ @classmethod
680
+ def load(cls, directory: str) -> 'JolTree':
681
+ """Loads the vectorized tree from a binary cache directory."""
682
+ logger.info(f"Loading binary cache from {directory}...")
683
+
684
+ import pickle
685
+ with open(os.path.join(directory, "metadata.pkl"), 'rb') as f:
686
+ meta = pickle.load(f)
687
+ prov = meta.get("provenance", {})
688
+ saved_version = prov.get("package_version", "unknown")
689
+ def version_to_tuple(v):
690
+ try:
691
+ return tuple(map(int, v.split('.')))
692
+ except (ValueError, AttributeError):
693
+ return (0, 0, 0)
694
+ if version_to_tuple(saved_version) < version_to_tuple(MINIMUM_CACHE_VERSION):
695
+ raise RuntimeError(
696
+ f"Incompatible taxonomy cache. Saved version: {saved_version}, "
697
+ f"Minimum required: {MINIMUM_CACHE_VERSION}. Please rebuild with build_from_dmp()."
698
+ )
699
+
700
+ tree = cls()
701
+ tree.rank_names = meta["rank_names"]
702
+ tree.top_rank = meta.get("top_rank", "domain")
703
+ tree._build_time = prov.get("build_time")
704
+ tree._source_nodes = prov.get("source_nodes")
705
+ tree._source_names = prov.get("source_names")
706
+
707
+ tree._index_to_id = np.load(os.path.join(directory, "index_to_id.npy"))
708
+ tree.parents = np.load(os.path.join(directory, "parents.npy"))
709
+ tree.depths = np.load(os.path.join(directory, "depths.npy"))
710
+ tree.ranks = np.load(os.path.join(directory, "ranks.npy"))
711
+ tree.entry_times = np.load(os.path.join(directory, "entry_times.npy"))
712
+ tree.exit_times = np.load(os.path.join(directory, "exit_times.npy"))
713
+
714
+ # Load Polars metadata
715
+ tree._scientific_names = pl.read_ipc(os.path.join(directory, "scientific_names.ipc"))["scientific_name"]
716
+ tree._common_names = pl.read_ipc(os.path.join(directory, "common_names.ipc"))["common_name"]
717
+ tree._search_index = pl.read_ipc(os.path.join(directory, "search_index.ipc"))
718
+
719
+ maps_dir = os.path.join(directory, "canonical_maps")
720
+ if os.path.exists(maps_dir):
721
+ for filename in os.listdir(maps_dir):
722
+ if filename.endswith(".npy"):
723
+ rank = filename[:-4]
724
+ tree.canonical_maps[rank] = np.load(os.path.join(maps_dir, filename))
725
+
726
+ # Re-initialize vectorized caches
727
+ tree._prepare_vectorized_caches()
728
+
729
+ logger.info("Loaded taxonomy cache:")
730
+ logger.info(f" Version: {saved_version}")
731
+ logger.info(f" Build time: {tree._build_time}")
732
+ logger.info(f" Node count: {prov.get('node_count', 'Unknown'):,}")
733
+ logger.info(f" Top rank: {tree.top_rank}")
734
+ return tree
joltax-0.1.1/joltax.egg-info/PKG-INFO ADDED
@@ -0,0 +1,74 @@
1
+ Metadata-Version: 2.4
2
+ Name: joltax
3
+ Version: 0.1.1
4
+ Summary: A high-performance, vectorized taxonomy library for Python.
5
+ Author-email: Daniel Svensson <daniel.svensson@umu.se>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/SweBiTS/JolTax
8
+ Project-URL: Bug Tracker, https://github.com/SweBiTS/JolTax/issues
9
+ Project-URL: Source Code, https://github.com/SweBiTS/JolTax
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Operating System :: OS Independent
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
15
+ Requires-Python: >=3.8
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: numpy>=1.20.0
19
+ Requires-Dist: polars>=0.20.0
20
+ Requires-Dist: rapidfuzz>=3.0.0
21
+ Dynamic: license-file
22
+
23
+ <p align="center">
24
+ <img src="assets/logo.png" alt="joltax logo" width="300">
25
+ </p>
26
+
27
+ # joltax
28
+
29
+ **High-performance, vectorized taxonomy library for Python.**
30
+
31
+ `JolTax` is a Python library designed to handle massive taxonomies with extreme efficiency. By representing taxonomy trees as contiguous NumPy arrays and leveraging Polars for mass data handling, it achieves lightning-fast traversals, constant-time clade queries, and rapid mass annotation of large datasets.
32
+
33
+ ## Key Features
34
+ - **Vectorized Performance:** Uses hardware-accelerated NumPy operations for million-scale property lookups.
35
+ - **Memory Efficient:** Optimized string store using Polars/Arrow reduces RAM footprint.
36
+ - **Fuzzy Name Search:** Rapid fuzzy matching using RapidFuzz to find TaxIDs from names.
37
+ - **Instant Clade Queries:** Quickly find all descendants of any node (even millions) using optimized range indexing.
38
+ - **Hyper-Vectorized LCA search:** Lowest Common Ancestor (LCA) search and node-to-node distance calculations at lightning speeds.
39
+ - **Mass Annotation:** Annotate massive TaxID tables with 2,000,000+ rows in under a second using Polars.
40
+
41
+ ## Quick Start
42
+
43
+ ```python
44
+ from joltax.joltree import JolTree
45
+
46
+ # Build and process the NCBI taxonomy
47
+ tree = JolTree(nodes_file='nodes.dmp', names_file='names.dmp')
48
+
49
+ # Save for instant loading next time
50
+ tree.save('my_taxonomy_cache')
51
+
52
+ # Re-load in milliseconds (using zero-copy Arrow IPC)
53
+ tree = JolTree.load('my_taxonomy_cache')
54
+
55
+ # Batch LCA (process 10,000 pairs in <10ms)
56
+ lcas = tree.get_lca_batch(ids1, ids2)
57
+
58
+ # Fuzzy search for a name (returns a Polars DataFrame)
59
+ results = tree.search_name('Escherchia', fuzzy=True)
60
+ print(results)
61
+ ```
62
+
63
+ ## Installation
64
+
65
+ ```bash
66
+ cd joltax
67
+ pip install .
68
+ ```
69
+
70
+ Requires: `numpy`, `polars`, `rapidfuzz`.
71
+
72
+ ## Documentation
73
+
74
+ For a detailed API reference and a comprehensive "How-To" guide with example workflows, please see [USAGE.md](./USAGE.md).
joltax-0.1.1/joltax.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,11 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ joltax/__init__.py
5
+ joltax/joltree.py
6
+ joltax.egg-info/PKG-INFO
7
+ joltax.egg-info/SOURCES.txt
8
+ joltax.egg-info/dependency_links.txt
9
+ joltax.egg-info/requires.txt
10
+ joltax.egg-info/top_level.txt
11
+ tests/test_tree.py
joltax-0.1.1/joltax.egg-info/requires.txt ADDED
@@ -0,0 +1,3 @@
1
+ numpy>=1.20.0
2
+ polars>=0.20.0
3
+ rapidfuzz>=3.0.0
joltax-0.1.1/joltax.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ joltax
joltax-0.1.1/pyproject.toml ADDED
@@ -0,0 +1,34 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "joltax"
7
+ version = "0.1.1"
8
+ authors = [
9
+ { name="Daniel Svensson", email="daniel.svensson@umu.se" },
10
+ ]
11
+ description = "A high-performance, vectorized taxonomy library for Python."
12
+ readme = "README.md"
13
+ requires-python = ">=3.8"
14
+ license = {text = "MIT"}
15
+ classifiers = [
16
+ "Programming Language :: Python :: 3",
17
+ "License :: OSI Approved :: MIT License",
18
+ "Operating System :: OS Independent",
19
+ "Intended Audience :: Science/Research",
20
+ "Topic :: Scientific/Engineering :: Bio-Informatics",
21
+ ]
22
+ dependencies = [
23
+ "numpy>=1.20.0",
24
+ "polars>=0.20.0",
25
+ "rapidfuzz>=3.0.0",
26
+ ]
27
+
28
+ [project.urls]
29
+ "Homepage" = "https://github.com/SweBiTS/JolTax"
30
+ "Bug Tracker" = "https://github.com/SweBiTS/JolTax/issues"
31
+ "Source Code" = "https://github.com/SweBiTS/JolTax"
32
+
33
+ [tool.setuptools]
34
+ packages = ["joltax"]
joltax-0.1.1/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
joltax-0.1.1/tests/test_tree.py ADDED
@@ -0,0 +1,142 @@
1
+ import unittest
2
+ import os
3
+ import sys
4
+ import numpy as np
5
+ import polars as pl
6
+
7
+ # Add the project root to sys.path
8
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9
+
10
+ from joltax.joltree import JolTree
11
+
12
+ class TestJolTree(unittest.TestCase):
13
+ @classmethod
14
+ def setUpClass(cls):
15
+ cls.names_file = 'tests/data/names.dmp'
16
+ cls.nodes_file = 'tests/data/nodes.dmp'
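+ # tests/data is expected to hold a minimal NCBI-style taxdump covering the
+ # E. coli lineage (root 1 down to species 562) that the assertions below rely on.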
17
+ # The DMP test fixtures must already be present; fail loudly if they are missing
18
+ if not os.path.exists(cls.names_file):
19
+ raise FileNotFoundError(f"Missing test data: {cls.names_file}")
20
+
21
+ cls.tree = JolTree(nodes_file=cls.nodes_file, names_file=cls.names_file)
22
+
23
+ def test_lineage(self):
24
+ # 562 (E. coli) -> 561 (Escherichia) -> 543 -> 91347 -> 1236 -> 1224 -> 2 -> 1
25
+ lineage = self.tree.get_lineage(562)
26
+ expected = [1, 2, 1224, 1236, 91347, 543, 561, 562]
27
+ self.assertEqual(lineage, expected)
28
+
29
+ def test_clade(self):
30
+ # Clade of 561 (genus) should contain 561 and 562 (species)
31
+ clade = self.tree.get_clade(561)
32
+ self.assertIn(561, clade)
33
+ self.assertIn(562, clade)
34
+ self.assertEqual(len(clade), 2)
35
+
36
+ def test_lca(self):
37
+ # LCA of 562 and 561 is 561
38
+ lca = self.tree.get_lca(562, 561)
39
+ self.assertEqual(lca, 561)
40
+
41
+ # LCA of 562 and 2 (Bacteria) is 2
42
+ lca = self.tree.get_lca(562, 2)
43
+ self.assertEqual(lca, 2)
44
+
45
+ def test_distance(self):
46
+ # 562 to 561 is 1 step
47
+ self.assertEqual(self.tree.get_distance(562, 561), 1)
48
+ # 562 to 2 is 6 steps
49
+ self.assertEqual(self.tree.get_distance(562, 2), 6)
50
+
51
+ def test_get_name_and_rank(self):
52
+ self.assertEqual(self.tree.get_name(562), 'Escherichia coli')
53
+ self.assertEqual(self.tree.get_rank(562), 'species')
54
+ self.assertEqual(self.tree.get_name(2), 'Bacteria')
55
+ self.assertEqual(self.tree.get_rank(2), 'superkingdom')
56
+ # Test unknown
57
+ self.assertEqual(self.tree.get_name(999999), 'Unknown_999999')
58
+ self.assertEqual(self.tree.get_rank(999999), 'unknown')
59
+
60
+ def test_annotate_table(self):
61
+ tax_ids = [562, 561, 2]
62
+ df = self.tree.annotate_table(tax_ids)
63
+ self.assertIsInstance(df, pl.DataFrame)
64
+ self.assertEqual(len(df), 3)
65
+ self.assertIn('species', df.columns)
66
+ self.assertIn('genus', df.columns)
67
+
68
+ # Check first row (562)
69
+ row0 = df.row(0, named=True)
70
+ self.assertEqual(row0['species'], 'Escherichia coli')
71
+ self.assertEqual(row0['genus'], 'Escherichia')
72
+ self.assertEqual(row0['scientific_name'], 'Escherichia coli')
73
+
74
+ def test_name_search(self):
75
+ # Search by scientific name
76
+ df = self.tree.search_name('Escherichia coli')
77
+ self.assertIn(562, df['tax_id'].to_list())
78
+
79
+ # Search by common name
80
+ df = self.tree.search_name('all')
81
+ self.assertIn(1, df['tax_id'].to_list())
82
+
83
+ def test_fuzzy_search(self):
84
+ # Typo: "Escherchia"
85
+ df = self.tree.search_name('Escherchia', fuzzy=True)
86
+ self.assertIsInstance(df, pl.DataFrame)
87
+ self.assertTrue(len(df) > 0)
88
+ # Top result should be Escherichia or Escherichia coli
89
+ top_name = df.row(0, named=True)['matched_name']
90
+ self.assertIn('Escherichia', top_name)
91
+
92
+ def test_save_load(self):
93
+ import shutil
94
+ cache_dir = 'tests/cache_test'
95
+ if os.path.exists(cache_dir):
96
+ shutil.rmtree(cache_dir)
97
+
98
+ self.tree.save(cache_dir)
99
+ new_tree = JolTree.load(cache_dir)
100
+
101
+ self.assertEqual(new_tree.get_lineage(562), self.tree.get_lineage(562))
102
+ # Check name index loaded
103
+ df = new_tree.search_name('Escherichia coli')
104
+ self.assertIn(562, df['tax_id'].to_list())
105
+
106
+ shutil.rmtree(cache_dir)
107
+
108
+ def test_version_validation(self):
109
+ import shutil
110
+ import pickle
111
+ cache_dir = 'tests/version_test'
112
+ if os.path.exists(cache_dir):
113
+ shutil.rmtree(cache_dir)
114
+
115
+ self.tree.save(cache_dir)
116
+
117
+ # Manually corrupt metadata with old version
118
+ meta_path = os.path.join(cache_dir, "metadata.pkl")
119
+ with open(meta_path, 'rb') as f:
120
+ meta = pickle.load(f)
121
+
122
+ meta["provenance"]["package_version"] = "0.0.1" # Older than 0.1.0
123
+
124
+ with open(meta_path, 'wb') as f:
125
+ pickle.dump(meta, f)
126
+
127
+ # Should raise RuntimeError
128
+ with self.assertRaises(RuntimeError) as cm:
129
+ JolTree.load(cache_dir)
130
+
131
+ self.assertIn("Incompatible taxonomy cache", str(cm.exception))
132
+ shutil.rmtree(cache_dir)
133
+
134
+ if __name__ == '__main__':
135
+ # We skip tests if dependencies aren't installed
136
+ try:
137
+ import numpy
138
+ import polars
139
+ import rapidfuzz
140
+ unittest.main()
141
+ except ImportError:
142
+ print("Skipping tests due to missing dependencies.")