rdrpcatch-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rdrpcatch/__init__.py ADDED
File without changes
rdrpcatch/cli/args.py ADDED
@@ -0,0 +1,358 @@
+ import warnings
+ # Filter numpy warnings before any imports that might trigger them
+ warnings.filterwarnings("ignore", category=UserWarning, module="numpy")
+ warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy")
+ warnings.filterwarnings("ignore", message=".*subnormal.*")
+
+ import rich_click as click
+ from rich.console import Console
+ from rich.table import Table
+ from rich.panel import Panel
+ from rich.syntax import Syntax
+ from rich.progress import Progress, BarColumn, TextColumn, DownloadColumn, TimeRemainingColumn
+ from pathlib import Path
+ import datetime
+ from ..rdrpcatch_wrapper import run_scan
+ from ..rdrpcatch_scripts.fetch_dbs import ZenodoDownloader, db_fetcher
+ import os
+ import shutil
+ import requests
+
+ console = Console()
+
+ ## FUNCTIONS
+ def parse_comma_separated_options(ctx, param, value):
+     if not value:
+         return ['all']
+
+     allowed_choices = ['RVMT', 'NeoRdRp', 'NeoRdRp.2.1', 'TSA_Olendraite_fam', 'TSA_Olendraite_gen', 'RDRP-scan',
+                        'Lucaprot', 'all']
+     lower_choices = [choice.lower() for choice in allowed_choices]
+     options = value.split(',')
+     lower_options = [option.lower() for option in options]
+
+     for option in options:
+         if option.lower() not in lower_choices:
+             raise click.BadParameter(f"Invalid choice: '{option}' (choose from {', '.join(allowed_choices)})")
+
+     return lower_options
+
+
+ def format_size(bytes_size: int) -> str:
+     """Convert bytes to human-readable format without external dependencies"""
+     units = ["B", "KB", "MB", "GB", "TB"]
+     unit_idx = 0
+     size = float(bytes_size)
+
+     while size >= 1024 and unit_idx < len(units) - 1:
+         size /= 1024
+         unit_idx += 1
+
+     return f"{size:.2f} {units[unit_idx]}"
+
+
+ ## CLI ENTRY POINT
+
+ @click.group()
+ def cli():
+     """RdRpCATCH - RNA-dependent RNA polymerase Collaborative Analysis Tool with Collections of pHMMs"""
+     pass
+
+ @cli.command("scan", help="Scan sequences for RdRps.")
+ @click.option("-i", "--input",
+               help="Path to the input FASTA file.",
+               type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path), required=True)
+ @click.option("-o", "--output",
+               help="Path to the output directory.",
+               type=click.Path(exists=False, file_okay=False, writable=True, path_type=Path), required=True)
+ @click.option("-db_dir", "--db_dir",
+               help="Path to the directory containing RdRpCATCH databases.",
+               type=click.Path(exists=True, dir_okay=True, readable=True, path_type=Path), required=True)
+ @click.option("-dbs", "--db_options",
+               callback=parse_comma_separated_options,
+               default="all",
+               help="Comma-separated list of databases to search against. Valid options: RVMT, NeoRdRp, NeoRdRp.2.1,"
+                    " TSA_Olendraite_fam, TSA_Olendraite_gen, RDRP-scan, Lucaprot, all")
+ @click.option("--custom-dbs",
+               help="Path to directory containing custom MSAs/pHMM files to use as additional databases",
+               type=click.Path(exists=True, path_type=Path))
+ @click.option("-seq_type", "--seq_type",
+               type=click.STRING,
+               default=None,
+               help="Type of sequence to search against (prot, nuc). Default: unknown")
+ @click.option("-v", "--verbose",
+               is_flag=True,
+               help="Print verbose output.")
+ @click.option('-e', '--evalue',
+               type=click.FLOAT,
+               default=1e-5,
+               help="E-value threshold for HMMsearch. (default: 1e-5)")
+ @click.option('-incE', '--incevalue',
+               type=click.FLOAT,
+               default=1e-5,
+               help="Inclusion E-value threshold for HMMsearch. (default: 1e-5)")
+ @click.option('-domE', '--domevalue',
+               type=click.FLOAT,
+               default=1e-5,
+               help="Domain E-value threshold for HMMsearch. (default: 1e-5)")
+ @click.option('-incdomE', '--incdomevalue',
+               type=click.FLOAT,
+               default=1e-5,
+               help="Inclusion domain E-value threshold for HMMsearch. (default: 1e-5)")
+ @click.option('-z', '--zvalue',
+               type=click.INT,
+               default=1000000,
+               help="Number of sequences to search against. (default: 1000000)")
+ @click.option('-cpus', '--cpus',
+               type=click.INT,
+               default=1,
+               help="Number of CPUs to use for HMMsearch. (default: 1)")
+ @click.option('-length_thr', '--length_thr',
+               type=click.INT,
+               default=400,
+               help="Minimum length threshold for seqkit seq. (default: 400)")
+ @click.option('-gen_code', '--gen_code',
+               type=click.INT,
+               default=1,
+               help='Genetic code to use for translation. (default: 1) Possible genetic codes (supported by seqkit translate): 1: The Standard Code \n'
+                    '2: The Vertebrate Mitochondrial Code \n'
+                    '3: The Yeast Mitochondrial Code \n'
+                    '4: The Mold, Protozoan, and Coelenterate Mitochondrial Code and the Mycoplasma/Spiroplasma Code \n'
+                    '5: The Invertebrate Mitochondrial Code \n'
+                    '6: The Ciliate, Dasycladacean and Hexamita Nuclear Code \n'
+                    '9: The Echinoderm and Flatworm Mitochondrial Code \n'
+                    '10: The Euplotid Nuclear Code \n'
+                    '11: The Bacterial, Archaeal and Plant Plastid Code \n'
+                    '12: The Alternative Yeast Nuclear Code \n'
+                    '13: The Ascidian Mitochondrial Code \n'
+                    '14: The Alternative Flatworm Mitochondrial Code \n'
+                    '16: Chlorophycean Mitochondrial Code \n'
+                    '21: Trematode Mitochondrial Code \n'
+                    '22: Scenedesmus obliquus Mitochondrial Code \n'
+                    '23: Thraustochytrium Mitochondrial Code \n'
+                    '24: Pterobranchia Mitochondrial Code \n'
+                    '25: Candidate Division SR1 and Gracilibacteria Code \n'
+                    '26: Pachysolen tannophilus Nuclear Code \n'
+                    '27: Karyorelict Nuclear \n'
+                    '28: Condylostoma Nuclear \n'
+                    '29: Mesodinium Nuclear \n'
+                    '30: Peritrich Nuclear \n'
+                    '31: Blastocrithidia Nuclear \n')
+ @click.option('-bundle', '--bundle',
+               is_flag=True,
+               default=False,
+               help="Bundle the output files into a single archive. (default: False)")
+ @click.option('-keep_tmp', '--keep_tmp',
+               is_flag=True,
+               default=False,
+               help="Keep temporary files (Expert users) (default: False)")
+ @click.pass_context
+ def scan(ctx, input, output, db_options, db_dir, custom_dbs, seq_type, verbose, evalue,
+          incevalue, domevalue, incdomevalue, zvalue, cpus, length_thr, gen_code, bundle, keep_tmp):
+     """Scan sequences for RdRps."""
+
+     # Create a rich table for displaying parameters
+     table = Table(title="Scan Parameters")
+     table.add_column("Parameter", style="cyan")
+     table.add_column("Value", style="green")
+
+     table.add_row("Input File", str(input))
+     table.add_row("Output Directory", str(output))
+     table.add_row("Databases", ", ".join(db_options))
+     table.add_row("Database Directory", str(db_dir))
+     if custom_dbs:
+         table.add_row("Custom Databases", str(custom_dbs))
+     table.add_row("Sequence Type", seq_type or "unknown")
+     table.add_row("Verbose Mode", "ON" if verbose else "OFF")
+     table.add_row("E-value", str(evalue))
+     table.add_row("Inclusion E-value", str(incevalue))
+     table.add_row("Domain E-value", str(domevalue))
+     table.add_row("Inclusion Domain E-value", str(incdomevalue))
+     table.add_row("Z-value", str(zvalue))
+     table.add_row("CPUs", str(cpus))
+     table.add_row("Length Threshold", str(length_thr))
+     table.add_row("Genetic Code", str(gen_code))
+     table.add_row("Bundle Output", "ON" if bundle else "OFF")
+     table.add_row("Save Temporary Files", "ON" if keep_tmp else "OFF")
+
+     console.print(Panel(table, title="Scan Configuration"))
+
+     # Add custom databases if provided
+     if custom_dbs:
+         db = db_fetcher(db_dir)
+         if os.path.isfile(custom_dbs):
+             db.add_custom_db(custom_dbs)
+         else:
+             for item in os.listdir(custom_dbs):
+                 item_path = os.path.join(custom_dbs, item)
+                 if os.path.isfile(item_path) and item_path.endswith(('.hmm', '.h3m', '.msa', '.sto', '.fasta', '.fa')):
+                     db.add_custom_db(item_path)
+                 elif os.path.isdir(item_path):
+                     db.add_custom_db(item_path, item)
+
+     run_scan(
+         input_file=input,
+         output_dir=output,
+         db_options=db_options,
+         db_dir=db_dir,
+         seq_type=seq_type,
+         verbose=verbose,
+         e=evalue,
+         incE=incevalue,
+         domE=domevalue,
+         incdomE=incdomevalue,
+         z=zvalue,
+         cpus=cpus,
+         length_thr=length_thr,
+         gen_code=gen_code,
+         bundle=bundle,
+         keep_tmp=keep_tmp
+     )
+
+ # @cli.command("download", help="Download RdRpCATCH databases.")
+ # @click.option("--destination_dir", "-dest",
+ #               help="Path to the directory to download HMM databases.",
+ #               type=click.Path(exists=False, file_okay=False, writable=True, path_type=Path), required=True)
+ # @click.option("--check-updates", "-u",
+ #               is_flag=True,
+ #               help="Check for database updates")
+ # @click.pass_context
+ # def download(ctx, destination_dir, check_updates):
+ #     """Download RdRpCATCH databases."""
+ #
+ #     # if check_updates:
+ #     #     db = db_fetcher(destination_dir)
+ #     #     version_info = db.check_db_updates()
+ #     #     if version_info:
+ #     #         console.print("Current database versions:")
+ #     #         for db_name, info in version_info.items():
+ #     #             console.print(f"- {db_name}: {info}")
+ #     #     else:
+ #     #         console.print("No version information available")
+ #     #     return
+ #
+ #     run_download(destination_dir)
+ #
+ # # @cli.command("gui", help="Launch the GUI.")
+ # # @click.pass_context
+ # # def gui(ctx):
+ # #     """Launch the GUI."""
+ # #
+ # #     console.print(Panel("Starting ColabScan GUI...", title="GUI Launch"))
+ # #     run_gui()
+
+
+ @cli.command("download", help="Download & update RdRpCATCH databases. If databases are already installed in the "
+                               "specified directory,"
+                               " it will check for updates and download the latest version if available.")
+ @click.option("--destination_dir", "-dest",
+               help="Path to directory to download databases",
+               type=click.Path(path_type=Path, file_okay=False, writable=True),
+               required=True)
+ @click.option("--concept-doi", default="10.5281/zenodo.14358348",
+               help="Zenodo Concept DOI for database repository")
+ def download(destination_dir: Path, concept_doi: str):
+     """Handle database download/update workflow"""
+     downloader = ZenodoDownloader(concept_doi, destination_dir)
+
+     try:
+
+         current_version = downloader.get_current_version()
+         if downloader.lock_file.exists():
+             console.print("[red]× Another download is already in progress[/red]")
+             raise click.Abort()
+
+         if downloader.needs_update() or not current_version:
+             downloader.lock_file.touch(exist_ok=False)
+             with Progress(
+                 TextColumn("[progress.description]{task.description}"),
+                 BarColumn(),
+                 TextColumn("{task.completed:.2f}/{task.total:.2f} MB"),
+                 TimeRemainingColumn(),
+                 transient=True
+             ) as progress:
+                 # Setup main download task
+                 main_task = progress.add_task("[cyan]Database Manager", total=4)
+
+                 # Phase 1: Metadata fetching
+                 progress.update(main_task, description="Fetching Zenodo metadata...")
+                 metadata = downloader._fetch_latest_metadata()
+                 progress.advance(main_task)
+
+                 # Phase 2: Prepare download
+                 progress.update(main_task, description="Analyzing package...")
+                 tarball_info = downloader._get_tarball_info()
+                 file_size_mb = tarball_info["size"] / (1024 * 1024)
+                 progress.advance(main_task)
+
+                 # Phase 3: Download with progress
+                 progress.update(main_task,
+                                 description="Downloading RdRpCATCH databases...",
+                                 total=file_size_mb)
+
+                 if not downloader.temp_dir.exists():
+                     downloader.temp_dir.mkdir(parents=True, exist_ok=True)
+
+                 temp_tar = downloader.temp_dir / "download.tmp"
+
+                 with requests.get(tarball_info["url"], stream=True) as response:
+                     response.raise_for_status()
+                     with open(temp_tar, "wb") as f:
+                         downloaded = 0
+                         for chunk in response.iter_content(chunk_size=8192):
+                             f.write(chunk)
+                             downloaded += len(chunk)
+                             progress.update(main_task, advance=len(chunk) / (1024 * 1024))
+
+                 # Phase 4: Verification & installation
+                 progress.update(main_task, description="Verifying checksum...")
+                 if not downloader._verify_checksum(temp_tar, tarball_info["checksum"]):
+                     raise ValueError("Checksum verification failed")
+
+                 progress.update(main_task, description="Installing databases...")
+                 downloader.extract_and_verify(temp_tar)
+                 version_info = downloader.get_latest_version_info()
+                 downloader.atomic_write_version(version_info)
+                 progress.advance(main_task)
+
+                 # Success message
+                 size_str = format_size(tarball_info["size"])
+                 console.print(
+                     f"\n[bold green]✓ Successfully downloaded version {version_info['record_id']}[/bold green]",
+                     f"Release date: {version_info['created']}",
+                     f"Size: {size_str}",
+                     sep="\n"
+                 )
+
+         else:
+             installed_date = current_version["downloaded"]
+             console.print(
+                 f"[green]✓ Databases are current[/green]",
+                 f"Version ID: {current_version['record_id']}",
+                 f"Installed: {installed_date}",
+                 sep="\n"
+             )
+     except FileExistsError:
+         console.print("[red]× Another download is already in progress![/red]")
+         console.print(f"Lock file exists: {downloader.lock_file}")
+         raise click.Abort()
+
+     except Exception as e:
+         console.print(f"\n[red]× Download failed: {str(e)}[/red]")
+         if downloader.temp_dir.exists():
+             shutil.rmtree(downloader.temp_dir)
+         raise click.Abort()
+
+     finally:
+         # Cleanup operations
+         if downloader.lock_file.exists():
+             downloader.lock_file.unlink()
+         if downloader.temp_dir.exists():
+             shutil.rmtree(downloader.temp_dir)
+
+
+ if __name__ == '__main__':
+     cli(obj={})
+
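The scan and download commands defined above can also be driven programmatically through click's test runner. The following is a minimal sketch, not part of the packaged code; the input/output paths and database selections are placeholders.

    # Usage sketch (placeholder paths); exercises the CLI group from rdrpcatch/cli/args.py.
    from click.testing import CliRunner
    from rdrpcatch.cli.args import cli

    runner = CliRunner()

    # Fetch or update the bundled databases into ./dbs (placeholder directory).
    result = runner.invoke(cli, ["download", "-dest", "dbs"])
    print(result.output)

    # Scan a FASTA file against two of the bundled pHMM collections.
    result = runner.invoke(cli, [
        "scan",
        "-i", "contigs.fasta",      # placeholder input FASTA
        "-o", "rdrpcatch_out",      # output directory
        "-db_dir", "dbs",           # directory populated by `download`
        "-dbs", "RVMT,NeoRdRp",     # parsed by parse_comma_separated_options
        "-cpus", "4",
    ])
    print(result.exit_code, result.output)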
rdrpcatch/rdrpcatch_scripts/fetch_dbs.py ADDED
@@ -0,0 +1,302 @@
+ import os
+ from pathlib import Path
+ from typing import Dict, Optional
+ import requests
+
+ class db_fetcher:
+
+     def __init__(self, db_dir):
+         self.db_dir = db_dir
+         self.version_file = os.path.join(db_dir, "version.json")
+         self.custom_db_dir = os.path.join(db_dir, "custom_dbs")
+
+     def _get_db_version(self):
+         """Get database version information"""
+         import json
+
+         if os.path.exists(self.version_file):
+             with open(self.version_file) as f:
+                 return json.loads(f.read())
+         return {}
+
+     def _save_db_version(self, version_info):
+         """Save database version information"""
+         import json
+
+         os.makedirs(os.path.dirname(self.version_file), exist_ok=True)
+         with open(self.version_file, 'w') as f:
+             json.dump(version_info, f, indent=2)
+
+     def check_db_updates(self):
+         """Check if database updates are available"""
+         current_version = self._get_db_version()
+         # TODO: Implement version checking against remote repository
+         # For now just return the current version
+         return current_version
+
+     def add_custom_db(self, db_path, db_name=None):
+         """Add a custom database (MSA or pHMM file) to the custom_dbs directory"""
+         import shutil
+         import datetime
+
+         if not os.path.exists(self.custom_db_dir):
+             os.makedirs(self.custom_db_dir)
+
+         if db_name is None:
+             db_name = os.path.basename(db_path)
+
+         target_path = os.path.join(self.custom_db_dir, db_name)
+
+         # Copy the database file
+         if os.path.isfile(db_path):
+             shutil.copy2(db_path, target_path)
+         elif os.path.isdir(db_path):
+             if os.path.exists(target_path):
+                 shutil.rmtree(target_path)
+             shutil.copytree(db_path, target_path)
+
+         # Update version info
+         version_info = self._get_db_version()
+         version_info.setdefault('custom_dbs', {})
+         version_info['custom_dbs'][db_name] = {
+             'added': datetime.datetime.now().isoformat(),
+             'path': target_path
+         }
+         self._save_db_version(version_info)
+
+     def _resolve_rdrpcatch_path(self):
+         """Automatically detect correct database path structure"""
+         # Case 1: Direct rdrpcatch_dbs path
+         if self.db_dir.name == "rdrpcatch_dbs":
+             return self.db_dir
+
+         # Case 2: Parent directory containing rdrpcatch_dbs
+         candidate = self.db_dir / "rdrpcatch_dbs"
+         if candidate.exists() and candidate.is_dir():
+             return candidate
+
+         # Case 3: Already contains hmm_dbs subdirectory
+         hmm_check = self.db_dir / "hmm_dbs"
+         if hmm_check.exists():
+             return self.db_dir
+
+         # Case 4: Fallback to original path
+         return self.db_dir
+
+     def fetch_hmm_db_path(self, db_name):
+         """
+         Fetches HMM database from the RdRpCATCH repository or custom databases
+         """
+         if not os.path.exists(self.db_dir):
+             raise FileNotFoundError(f"db_dir does not exist {self.db_dir}")
+
+
+         # First check custom databases
+         if os.path.exists(self.custom_db_dir):
+             custom_path = os.path.join(self.custom_db_dir, db_name)
+             if os.path.exists(custom_path):
+                 if os.path.isfile(custom_path) and custom_path.endswith(('.h3m', '.hmm')):
+                     return os.path.splitext(custom_path)[0]
+                 elif os.path.isdir(custom_path):
+                     for file in os.listdir(custom_path):
+                         if file.endswith(('.h3m', '.hmm')):
+                             return os.path.splitext(os.path.join(custom_path, file))[0]
+
+         # Then check standard databases
+         db_path = None
+         db_dir = self._resolve_rdrpcatch_path()
+         for root, dirs, files in os.walk(db_dir):
+             for name in dirs:
+                 if name == db_name:
+                     for file in os.listdir(os.path.join(root, name)):
+                         if file.endswith(".h3m"):
+                             db_fn = file.rsplit(".", 1)[0]
+                             db_path = os.path.join(root, name, db_fn)
+                         else:
+                             continue
+
+         if not db_path:
+             raise FileNotFoundError(f"{db_name} not found in {db_dir}")
+         else:
+             return db_path
+
+
+     def fetch_mmseqs_db_path(self, db_name):
+         """
+         Fetches MMseqs database from the RdRpCATCH repository
+         """
+         if not os.path.exists(self.db_dir):
+             raise FileNotFoundError(f"db_dir does not exist {self.db_dir}")
+
+         db_dir = self._resolve_rdrpcatch_path()
+
+         db_path = None
+         for root, dirs, files in os.walk(db_dir):
+             for dir in dirs:
+                 if dir == "mmseqs_dbs":
+                     for dir_ in os.listdir(os.path.join(root, dir)):
+                         if dir_ == db_name:
+                             for file in os.listdir(os.path.join(root, dir, dir_)):
+                                 if file.endswith(".lookup"):
+                                     db_fn = file.rsplit(".", 1)[0]
+                                     db_path = os.path.join(root, dir, dir_, db_fn)
+                                 else:
+                                     continue
+
+         if not db_path:
+             raise FileNotFoundError(f"{db_name} not found in {db_dir}")
+         else:
+             return db_path
+
+
+ class ZenodoDownloader:
+     """Handles Zenodo database downloads using record IDs for version tracking"""
+
+     def __init__(self, concept_doi: str, db_dir: Path):
+         self.concept_doi = concept_doi
+         self.db_dir = db_dir
+         self.temp_dir = db_dir / "temp"
+         self.lock_file = db_dir / ".lock"
+         self.version_file = db_dir / "version.json"
+         self._api_base = "https://zenodo.org/api/records"
+
+         self.db_dir.mkdir(parents=True, exist_ok=True)
+
+     def _get_record_id(self) -> str:
+         """Extract numeric concept ID from Concept DOI"""
+         return self.concept_doi.split(".")[-1]
+
+     def _fetch_latest_metadata(self) -> Dict:
+         """Retrieve latest version metadata from Zenodo API"""
+
+         response = requests.get(f"{self._api_base}/{self._get_record_id()}")
+         response.raise_for_status()
+         return response.json()
+
+     def get_latest_version_info(self) -> Dict:
+         """Get version information from Zenodo metadata"""
+         metadata = self._fetch_latest_metadata()
+         return {
+             "record_id": str(metadata["id"]),
+             "doi": metadata["metadata"]["doi"],
+             "created": metadata["metadata"]["publication_date"],
+             "conceptdoi": self.concept_doi
+         }
+
+     def _get_tarball_info(self) -> Dict:
+         """Find database tarball in Zenodo files"""
+
+         metadata = self._fetch_latest_metadata()
+         for file_info in metadata.get("files", []):
+             if file_info["key"].endswith(".tar"):
+                 return {
+                     "url": file_info["links"]["self"],
+                     "checksum": file_info["checksum"],
+                     "size": file_info["size"]
+                 }
+         raise ValueError("No database tarball found in Zenodo record")
+
+
+     def _verify_checksum(self, file_path: Path, expected: str) -> bool:
+         """Validate file checksum (supports MD5 and SHA-256)"""
+         import hashlib
+
+         algorithm, _, expected_hash = expected.partition(":")
+         hasher = hashlib.new(algorithm)
+
+         with open(file_path, "rb") as f:
+             while chunk := f.read(8192):
+                 hasher.update(chunk)
+
+         return hasher.hexdigest() == expected_hash
+
+     def extract_and_verify(self, tar_path: Path) -> None:
+         """Safely extract tarball with proper directory structure handling"""
+         import tarfile
+         import tempfile
+         import shutil
+
+
+         with tempfile.TemporaryDirectory(dir=self.db_dir, prefix=".tmp_") as tmp_extract:
+             tmp_extract = Path(tmp_extract)
+
+             # Extract to temporary subdirectory
+             with tarfile.open(tar_path, "r") as tar:
+                 # Validate all members first
+                 for member in tar.getmembers():
+                     if ".." in member.path:
+                         raise ValueError(f"Invalid path in archive: {member.path}")
+                 tar.extractall(tmp_extract)
+
+             # Handle nested directory structure
+             extracted_items = list(tmp_extract.iterdir())
+             if len(extracted_items) == 1 and extracted_items[0].is_dir() and extracted_items[0].name == "rdrpcatch_dbs":
+                 # If archive contains single rdrpcatch_dbs directory, move contents up
+                 nested_dir = tmp_extract / "rdrpcatch_dbs"
+                 for item in nested_dir.iterdir():
+                     shutil.move(str(item), str(tmp_extract))
+                 nested_dir.rmdir()
+
+
+             # Prepare paths for atomic replacement
+             target = self.db_dir / "rdrpcatch_dbs"
+             backup = self.db_dir / "rdrpcatch_dbs.bak"
+
+             # Atomic replacement sequence
+             try:
+                 if backup.exists():
+                     shutil.rmtree(backup)
+
+                 if target.exists():
+                     target.rename(backup)
+
+                 tmp_extract.rename(target)
+
+             finally:
+                 if backup.exists() and target.exists():
+                     shutil.rmtree(backup)
+
+     def needs_update(self) -> bool:
+         """Check if local databases are outdated using record ID"""
+         import json
+
+         if not self.version_file.exists():
+             return True
+
+         try:
+             with open(self.version_file, "r") as f:
+                 local_version = json.load(f)
+             remote_version = self.get_latest_version_info()
+             return remote_version["record_id"] != local_version["record_id"]
+         except (json.JSONDecodeError, KeyError):
+             return True
+
+
+     def atomic_write_version(self, version_info: Dict) -> None:
+         """Safely update version file with download timestamp"""
+         import json
+         import datetime
+
+         temp_version = self.version_file.with_suffix(".tmp")
+
+         # Add timestamp BEFORE writing to file
+         # (timezone-aware UTC timestamp in ISO 8601 format)
+         version_info["downloaded"] = datetime.datetime.now(datetime.timezone.utc).isoformat()
+
+         with open(temp_version, "w") as f:
+             json.dump(version_info, f, indent=2)
+
+         os.replace(temp_version, self.version_file)
+
+     def get_current_version(self) -> Optional[Dict]:
+         """Read installed database version info"""
+         import json
+
+         if self.version_file.exists():
+             with open(self.version_file, "r") as f:
+                 return json.load(f)
+         return None
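Outside the CLI, the same helpers can be used directly from Python. The sketch below relies only on the methods defined in fetch_dbs.py above; the database directory and the RVMT database name are placeholders, and the update check makes a live request to the Zenodo API.

    # Usage sketch (placeholder paths); relies only on methods shown in fetch_dbs.py above.
    from pathlib import Path
    from rdrpcatch.rdrpcatch_scripts.fetch_dbs import ZenodoDownloader, db_fetcher

    db_dir = Path("dbs")  # placeholder directory for the downloaded databases

    # Version bookkeeping, mirroring what the `download` command does.
    downloader = ZenodoDownloader("10.5281/zenodo.14358348", db_dir)
    if downloader.needs_update():
        print("A newer database release is available on Zenodo.")
    else:
        print("Installed version:", downloader.get_current_version())

    # Resolve the on-disk prefix of one of the bundled pHMM databases
    # (assumes a directory named RVMT containing an .h3m file exists under db_dir).
    fetcher = db_fetcher(db_dir)
    try:
        print("RVMT profiles at:", fetcher.fetch_hmm_db_path("RVMT"))
    except FileNotFoundError as err:
        print(err)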