zipstrain 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zipstrain/database.py ADDED
@@ -0,0 +1,871 @@
1
+ """zipstrain.database
2
+ ========================
3
+ This module provides classes and functions to manage profile and comparison databases for efficient data handling.
4
+ The ProfileDatabase class manages profiles, while the GenomeComparisonDatabase class handles comparisons between profiles.
5
+ See the documentation of each class for more details.
6
+
7
+ """
8
+ from __future__ import annotations
9
+ import polars as pl
10
+ import pathlib
11
+ import os
12
+ import tempfile
13
+ import json
14
+ import copy
15
+ from pydantic import BaseModel, Field, field_validator,ConfigDict
16
+
17
+
18
class ProfileItem(BaseModel):
    """
    Validated description of a single profile entry.

    Unknown keys are rejected (``extra="forbid"``), both location fields must
    point to paths that exist, and database IDs that are explicitly supplied
    must be non-empty.
    """
    model_config = ConfigDict(extra="forbid")
    profile_name: str = Field(description="An arbitrary name given to the profile (Usually sample name or name of the parquet file)")
    profile_location: str = Field(description="The location of the profile")
    scaffold_location: str = Field(description="The location of the scaffold")
    reference_db_id: str = Field(description="The ID of the reference database. This could be the name or any other identifier for the database that the reads are mapped to.")
    # NOTE(review): the default "" is only checked by the validator below when
    # default validation is enabled, which this model does not request -- confirm
    # whether an omitted gene_db_id is meant to be accepted.
    gene_db_id: str = Field(default="", description="The ID of the gene database in fasta format. This could be the name or any other identifier for the database that the reads are mapped to.")

    @field_validator("profile_location", "scaffold_location")
    @classmethod
    def check_file_exists(cls, v: str) -> str:
        """Reject a location that does not exist on the filesystem.

        Raises:
            ValueError: If ``v`` does not exist.
        """
        if not os.path.exists(v):
            raise ValueError(f"The file {v} does not exist.")
        return v

    @field_validator("reference_db_id", "gene_db_id")
    @classmethod
    def check_reference_db_id(cls, v: str) -> str:
        """Reject an empty database identifier.

        Raises:
            ValueError: If ``v`` is empty.
        """
        if not v:
            raise ValueError("The reference_db_id and gene_db_id cannot be empty.")
        return v
40
+
41
+
42
+
43
class ProfileDatabase:
    """
    The profile database simply holds profile information. Does not need to be specific to a comparison database.
    The data behind a profile is stored in a parquet file. It is basically a table with the following columns:

    - profile_name: An arbitrary name given to the profile (Usually sample name or name of the parquet file)

    - profile_location: The location of the profile

    - scaffold_location: The location of the scaffold

    - reference_db_id: The ID of the reference database. This could be the name or any other identifier for the database that the reads are mapped to.

    - gene_db_id: The ID of the gene database in fasta format. This could be the name or any other identifier for the database that the reads are mapped to.

    Args:
        db_loc (str|None): The location of the profile database parquet file. If None, an empty database is created.

    """
    # Columns every profile database must expose.
    _REQUIRED_COLUMNS = (
        "profile_name",
        "profile_location",
        "scaffold_location",
        "reference_db_id",
        "gene_db_id",
    )

    def __init__(self,
                 db_loc: str | None = None,
                 ):
        if db_loc is not None:
            self.db_loc = pathlib.Path(db_loc)
            self._db = pl.scan_parquet(self.db_loc)
        else:
            self.db_loc = None
            # Empty table that already carries the expected string schema.
            self._db = pl.LazyFrame(
                {col: [] for col in self._REQUIRED_COLUMNS},
                schema={col: pl.Utf8 for col in self._REQUIRED_COLUMNS},
            )

    @property
    def db(self):
        """The underlying polars LazyFrame."""
        return self._db

    def _missing_paths(self, column: str) -> list:
        """Return the values of ``column`` that do not exist on disk.

        Null entries are skipped rather than reported.
        """
        locations = self.db.select(pl.col(column)).collect(engine="streaming")[column].to_list()
        return [loc for loc in locations if loc is not None and not pathlib.Path(loc).exists()]

    def _validate_db(self, check_profile_exists: bool = True, check_scaffold_exists: bool = True) -> None:
        """Simple method to see if the database has the minimum required structure.

        Args:
            check_profile_exists (bool): Also verify that every profile_location exists on disk.
            check_scaffold_exists (bool): Also verify that every scaffold_location exists on disk.

        Raises:
            ValueError: If a required column is missing or a checked path does not exist.
        """
        # Resolve the schema once instead of once per required column.
        available = set(self.db.collect_schema().names())
        for col in self._REQUIRED_COLUMNS:
            if col not in available:
                raise ValueError(f"Missing required column: {col}")

        if check_profile_exists:
            missing = self._missing_paths("profile_location")
            if missing:
                raise ValueError(f"There are {len(missing)} profiles that do not exist: {missing}")
        if check_scaffold_exists:
            missing = self._missing_paths("scaffold_location")
            if missing:
                raise ValueError(f"There are {len(missing)} scaffolds that do not exist: {missing}")

    def add_profile(self,
                    data: dict
                    ) -> None:
        """Add a profile to the database.
        The data dictionary must contain the following and only the following keys:

        - profile_name

        - profile_location

        - scaffold_location

        - reference_db_id

        - gene_db_id

        Args:
            data (dict): The profile data to add.

        Raises:
            ValueError: If the data is rejected by ProfileItem or the resulting
                database fails validation. The database is left unchanged in that case.
        """
        previous = self._db
        try:
            profile_item = ProfileItem(**data)
            lf = pl.LazyFrame({
                "profile_name": [profile_item.profile_name],
                "profile_location": [profile_item.profile_location],
                "scaffold_location": [profile_item.scaffold_location],
                "reference_db_id": [profile_item.reference_db_id],
                "gene_db_id": [profile_item.gene_db_id]
            })
            self._db = pl.concat([previous, lf]).unique()
            self._validate_db()
        except Exception as e:
            # Roll back so a rejected profile does not stay in the table.
            self._db = previous
            raise ValueError(f"The profile data provided is not valid: {e}") from e

    def add_database(self, profile_database: ProfileDatabase) -> None:
        """Merge the provided profile database into the current database.

        Args:
            profile_database (ProfileDatabase): The profile database to merge.

        Raises:
            ValueError: If the provided database fails validation.
        """
        try:
            profile_database._validate_db()
        except Exception as e:
            raise ValueError(f"The profile database provided is not valid: {e}") from e

        self._db = pl.concat([self._db, profile_database.db]).unique()

    def save_as_new_database(self, output_path: str) -> None:
        """Save the database to a parquet file.

        On success ``db_loc`` is updated to ``output_path``. A failed write is
        logged and leaves ``db_loc`` untouched; it does not raise.

        Args:
            output_path (str): The path to save the database to.

        Raises:
            ValueError: If ``output_path`` refers to the current database file.
        """
        target = pathlib.Path(output_path)
        # The new database must be written to a new location. Resolve both sides
        # so that "..", "." and symlinks cannot disguise the same file.
        if self.db_loc is not None and self.db_loc.resolve() == target.resolve():
            raise ValueError("The output path must be different from the current database location.")

        try:
            self.db.sink_parquet(output_path)
        except Exception:
            # Best-effort save: report the failure instead of hiding it.
            import logging
            logging.getLogger(__name__).exception(
                "Failed to save the profile database to %s", output_path
            )
            return
        self.db_loc = target

    def update_database(self) -> None:
        """Overwrites the database saved on the disk with the current database object.

        The table is written to a temporary file next to ``db_loc`` and then
        moved into place, so a failed write does not damage the existing file.

        Raises:
            Exception: If ``db_loc`` is not set or the write fails.
        """
        if self.db_loc is None:
            raise Exception("db_loc attribute is not determined yet!")
        tmp_path = self.db_loc.with_name(f"tmp_profile_db_{os.getpid()}_{self.db_loc.name}")
        try:
            self.db.collect(engine="streaming").write_parquet(tmp_path)
            os.replace(tmp_path, self.db_loc)
        except Exception as e:
            try:
                tmp_path.unlink()
            except OSError:
                # Nothing to clean up (or cleanup impossible); report the original error.
                pass
            raise Exception(f"Something went wrong when updating the database:{e}") from e

    @classmethod
    def from_csv(cls, csv_path: str) -> ProfileDatabase:
        """Create a ProfileDatabase instance from a CSV file with exactly same columns as the required columns for a profile database.

        Args:
            csv_path (str): The path to the CSV file.

        Returns:
            ProfileDatabase: The created ProfileDatabase instance.

        Raises:
            ValueError: If the loaded table fails validation.
        """
        # Materialise first to avoid a clash when to_csv is later used on the same file.
        lf = pl.scan_csv(csv_path).collect().lazy()
        prof_db = cls()
        prof_db._db = lf
        prof_db._validate_db()
        return prof_db

    def to_csv(self, output_dir: str) -> None:
        """Writes the current database object to a csv file.

        Args:
            output_dir (str): The path to save the CSV file.

        Returns:
            None
        """
        self.db.sink_csv(output_dir, engine="streaming")
217
+
218
+
219
+
220
+
221
+
222
+
223
class GenomeComparisonConfig(BaseModel):
    """
    This class defines objects which hold all options needed to describe
    the parameters used to compare profiles:

    Attributes:
        gene_db_id (str): The ID of the gene fasta database to use for the comparison. The file name is perfect.
        reference_id (str): The ID of the reference fasta database to use for the comparison. The file name is perfect.
        scope (str): The scope of the comparison- 'all' if all covered positions are desired. Otherwise, a bunch of genome names separated by commas.
        min_cov (int): Minimum coverage a base on the reference fasta that must have in order to be compared.
        null_model_p_value(float): P_value above which a base call is counted as sequencing error
        min_gene_compare_len (int): Minimum length of a gene that needs to be covered at min_cov to be considered for gene similarity calculations
        stb_file_loc (str): The location of the scaffold to bin file.
        null_model_loc (str): The location of the null model file.
    """
    model_config = ConfigDict(extra="forbid")
    gene_db_id: str = Field(default="", description="An ID given to the gene fasta file used for profiling. IMPORTANT: Make sure that this is in agreement with gene database IDs in the Profile Database.")
    reference_id: str = Field(description="An ID given to the reference fasta file used for profiling. IMPORTANT: Make sure that this is in agreement with reference IDs in the Profile Database.")
    scope: str = Field(description="The scope of the comparison: 'all' if all covered positions are desired; otherwise genome names separated by commas.")
    min_cov: int = Field(description="Minimum coverage a base on the reference fasta that must have in order to be compared.")
    min_gene_compare_len: int = Field(description="Minimum length of a gene that needs to be covered at min_cov to be considered for gene similarity calculations")
    null_model_p_value: float = Field(default=0.05, description="P_value above which a base call is counted as sequencing error")
    stb_file_loc: str = Field(description="The location of the scaffold to bin file.")
    null_model_loc: str = Field(description="The location of the null model file.")

    @staticmethod
    def _scope_members(scope: str) -> set[str]:
        """Split a comma-separated scope into genome names, ignoring blanks and surrounding whitespace."""
        return {name.strip() for name in scope.split(",") if name.strip()}

    def is_compatible(self, other: GenomeComparisonConfig) -> bool:
        """
        Check if this comparison configuration is compatible with another. Two configurations are compatible if they have the same parameters, except for scope.
        Scope can be different as long as they are not disjoint. Also, all is compatible with any scope.

        Args:
            other (GenomeComparisonConfig): The other comparison configuration to check compatibility with.
        Returns:
            bool: True if the configurations are compatible, False otherwise.
        """
        for key, value in self.__dict__.items():
            if key != "scope" and value != getattr(other, key):
                return False
        if self.scope == "all" or other.scope == "all":
            return True
        return not self._scope_members(self.scope).isdisjoint(self._scope_members(other.scope))

    @classmethod
    def from_json(cls, json_file_dir: str) -> GenomeComparisonConfig:
        """Create a GenomeComparisonConfig instance from a json file."""
        with open(json_file_dir, 'r') as f:
            config_dict = json.load(f)
        return cls(**config_dict)

    def to_json(self, json_file_dir: str) -> None:
        """Writes the current object to a json file."""
        with open(json_file_dir, "w") as f:
            json.dump(self.__dict__, f)

    def to_dict(self) -> dict:
        """Returns a shallow copy of the dictionary representation of the current object."""
        return copy.copy(self.__dict__)

    def get_maximal_scope_config(self, other: GenomeComparisonConfig) -> GenomeComparisonConfig:
        """
        Get a new GenomeComparisonConfig object whose scope is covered by both configurations.

        If both scopes are 'all' the result is 'all'; if only one is 'all' the
        other scope is used; otherwise the result is the set of genomes present
        in both scopes. Genome names are emitted sorted and comma-separated.

        Args:
            other (GenomeComparisonConfig): The other comparison configuration to get the maximal scope with.
        Returns:
            GenomeComparisonConfig: The new comparison configuration with the maximal scope.
        Raises:
            ValueError: If the two configurations are not compatible.
        """
        if not self.is_compatible(other):
            raise ValueError("The two comparison configurations are not compatible.")

        if self.scope == "all" and other.scope == "all":
            new_scope = "all"
        else:
            if other.scope == "all":
                members = self._scope_members(self.scope)
            elif self.scope == "all":
                members = self._scope_members(other.scope)
            else:
                members = self._scope_members(self.scope) & self._scope_members(other.scope)
            new_scope = ",".join(sorted(members))

        curr_config_dict = self.to_dict()
        curr_config_dict["scope"] = new_scope
        return GenomeComparisonConfig(**curr_config_dict)
310
+
311
class GeneComparisonConfig(BaseModel):
    """
    Configuration for gene-level comparisons between profiles.

    Attributes:
        scope (str): The scope of the comparison in format "GENOME:GENE" (e.g., "all:gene1" compares gene1 across all genomes, "genome1:gene1" compares gene1 only in genome1 across samples).
        null_model_loc (str): Location of the null model parquet file.
        stb_file_loc (str): Location of the scaffold-to-genome mapping file.
        min_cov (int): Minimum coverage threshold for considering a position.
        min_gene_compare_len (int): Minimum gene length required for comparison.
    """
    model_config = ConfigDict(extra="forbid")
    scope: str = Field(description="Scope in format GENOME:GENE (e.g., 'all:gene1', 'genome1:gene1')")
    null_model_loc: str = Field(description="Location of the null model parquet file")
    stb_file_loc: str = Field(description="Location of the scaffold-to-genome mapping file")
    min_cov: int = Field(default=5, description="Minimum coverage threshold")
    min_gene_compare_len: int = Field(default=100, description="Minimum gene length for comparison")

    @field_validator("scope")
    @classmethod
    def validate_scope(cls, v: str) -> str:
        """Validate that scope follows GENOME:GENE format.

        Raises:
            ValueError: If the separator is missing, appears more than once, or
                either side of it is empty.
        """
        if ":" not in v:
            raise ValueError("Scope must be in format 'GENOME:GENE' (e.g., 'all:gene1' or 'genome1:gene1')")
        parts = v.split(":")
        if len(parts) != 2:
            raise ValueError("Scope must have exactly one ':' separator")
        genome_part, gene_part = parts
        if not genome_part or not gene_part:
            raise ValueError("Both genome and gene parts must be non-empty")
        return v

    @staticmethod
    def _merge_scope_part(first: str, second: str) -> str | None:
        """Return the narrower of two scope parts, or None when they conflict.

        'all' acts as a wildcard: it yields the other part.
        """
        if first == "all":
            return second
        if second == "all" or first == second:
            return first
        return None

    def is_compatible(self, other: GeneComparisonConfig) -> bool:
        """
        Check if this gene comparison configuration is compatible with another.
        Two configurations are compatible if they have the same parameters, except for scope.
        Scope can be different as long as they are not disjoint. Also, 'all' is compatible with any scope.

        The genome part and the gene part are checked independently: each pair
        must be equal or contain 'all'.

        Args:
            other (GeneComparisonConfig): The other gene comparison configuration to check compatibility with.

        Returns:
            bool: True if the configurations are compatible, False otherwise.
        """
        for key, value in self.__dict__.items():
            if key != "scope" and value != getattr(other, key):
                return False
        self_genome_scope, self_gene_scope = self.scope.split(":")
        other_genome_scope, other_gene_scope = other.scope.split(":")
        return (
            self._merge_scope_part(self_genome_scope, other_genome_scope) is not None
            and self._merge_scope_part(self_gene_scope, other_gene_scope) is not None
        )

    @classmethod
    def from_json(cls, json_file_dir: str) -> GeneComparisonConfig:
        """Create a GeneComparisonConfig instance from a json file."""
        with open(json_file_dir, 'r') as f:
            config_dict = json.load(f)
        return cls(**config_dict)

    def to_json(self, json_file_dir: str) -> None:
        """Writes the current object to a json file."""
        with open(json_file_dir, "w") as f:
            json.dump(self.__dict__, f)

    def to_dict(self) -> dict:
        """Returns a shallow copy of the dictionary representation of the current object."""
        return copy.copy(self.__dict__)

    def get_maximal_scope_config(self, other: GeneComparisonConfig) -> GeneComparisonConfig:
        """
        Get a new GeneComparisonConfig object with the maximal scope that is compatible with the two configurations.

        For each of the genome and gene parts, 'all' gives way to the other
        configuration's value; otherwise the (equal) value is kept.

        Args:
            other (GeneComparisonConfig): The other gene comparison configuration to get the maximal scope with.

        Returns:
            GeneComparisonConfig: The new gene comparison configuration with the maximal scope.

        Raises:
            ValueError: If the two configurations are not compatible.
        """
        if not self.is_compatible(other):
            raise ValueError("The two comparison configurations are not compatible.")

        self_genome_scope, self_gene_scope = self.scope.split(":")
        other_genome_scope, other_gene_scope = other.scope.split(":")

        # Compatibility was checked above, so neither merge can return None here.
        new_genome_scope = self._merge_scope_part(self_genome_scope, other_genome_scope)
        new_gene_scope = self._merge_scope_part(self_gene_scope, other_gene_scope)

        curr_config_dict = self.to_dict()
        curr_config_dict["scope"] = f"{new_genome_scope}:{new_gene_scope}"
        return GeneComparisonConfig(**curr_config_dict)
416
+
417
+
418
class GenomeComparisonDatabase:
    """
    GenomeComparisonDatabase object holds a reference to a comparison parquet file. The methods in this class serve to provide
    functionality for working with the comparison data in an easy and efficient manner.
    The comparison parquet file is the result of running compare, and optionally concatenating multiple compare parquet files from single comparisons.
    This parquet file must contain the following columns:

    - genome

    - total_positions

    - share_allele_pos

    - genome_pop_ani

    - max_consecutive_length

    - shared_genes_count

    - identical_gene_count

    - perc_id_genes

    - sample_1

    - sample_2

    A ComparisonDatabase object needs a ComparisonConfig object to specify the parameters used for the comparison.

    Args:
        profile_db (ProfileDatabase): The profile database used for the comparison.
        config (GenomeComparisonConfig): The comparison configuration used for the comparison.
        comp_db_loc (str|None): The location of the comparison database parquet file. If
            None, an empty comparison database is created.

    """
    COLUMN_NAMES = [
        "genome",
        "total_positions",
        "share_allele_pos",
        "genome_pop_ani",
        "max_consecutive_length",
        "shared_genes_count",
        "identical_gene_count",
        "perc_id_genes",
        "sample_1",
        "sample_2"
    ]

    def __init__(self,
                 profile_db: ProfileDatabase,
                 config: GenomeComparisonConfig,
                 comp_db_loc: str | None = None,
                 ):
        self.profile_db = profile_db
        self.config = config
        if comp_db_loc is not None:
            self.comp_db_loc = pathlib.Path(comp_db_loc)
            self._comp_db = pl.scan_parquet(self.comp_db_loc)
        else:
            self.comp_db_loc = None
            # Empty table that already carries the expected column types.
            schema = {
                "genome": pl.Utf8,
                "total_positions": pl.Int64,
                "share_allele_pos": pl.Int64,
                "genome_pop_ani": pl.Float64,
                "max_consecutive_length": pl.Int64,
                "shared_genes_count": pl.Int64,
                "identical_gene_count": pl.Int64,
                "perc_id_genes": pl.Float64,
                "sample_1": pl.Utf8,
                "sample_2": pl.Utf8
            }
            self._comp_db = pl.LazyFrame({name: [] for name in schema}, schema=schema)

    @property
    def comp_db(self):
        """The underlying polars LazyFrame of comparisons."""
        return self._comp_db

    def _validate_db(self) -> None:
        """Validate the profile database, the comparison columns and the sample names.

        Raises:
            ValueError: If the column set differs from COLUMN_NAMES, or a sample
                name is absent from the profile database.
        """
        self.profile_db._validate_db()

        actual = set(self._comp_db.collect_schema().names())
        expected = set(self.COLUMN_NAMES)
        if actual != expected:
            # Report both directions: a surplus column fails this check as well.
            raise ValueError(
                "Your comparison database must provide exactly the expected columns. "
                f"Missing: {sorted(expected - actual)}; unexpected: {sorted(actual - expected)}"
            )
        # check if all profile names exist in the profile database
        profile_names_in_comp_db = set(self.get_all_profile_names())
        profile_names_in_profile_db = set(self.profile_db.db.select("profile_name").collect(engine="streaming").to_series().to_list())
        if not profile_names_in_comp_db.issubset(profile_names_in_profile_db):
            raise ValueError(f"The following profile names are in the comparison database but not in the profile database: {profile_names_in_comp_db - profile_names_in_profile_db}")

    def get_all_profile_names(self) -> set[str]:
        """
        Get all profile names that are in the comparison database.

        Returns:
            set[str]: Every value found in the sample_1 or sample_2 column.
        """
        # One pass over the data yields both sample columns.
        samples = self.comp_db.select(["sample_1", "sample_2"]).collect(engine="streaming")
        return set(samples["sample_1"].to_list()) | set(samples["sample_2"].to_list())

    def get_remaining_pairs(self) -> pl.LazyFrame:
        """
        Get pairs of profiles that are in the profile database but not in the comparison database.

        Returns:
            pl.LazyFrame: Columns profile_1 and profile_2 (profile_1 < profile_2), sorted.
        """
        profiles = self.profile_db.db.select("profile_name")
        # Every unordered pair once: cross join, then keep profile_1 < profile_2.
        pairs = (
            profiles.join(profiles, how="cross")
            .rename({"profile_name": "profile_1", "profile_name_right": "profile_2"})
            .filter(pl.col("profile_1") < pl.col("profile_2"))
        )
        # Normalise compared pairs to the same ordering so the anti-join can match them.
        samplepairs = (
            self.comp_db.group_by("sample_1", "sample_2").agg()
            .with_columns(
                pl.min_horizontal(["sample_1", "sample_2"]).alias("profile_1"),
                pl.max_horizontal(["sample_1", "sample_2"]).alias("profile_2"),
            )
            .select(["profile_1", "profile_2"])
        )

        remaining_pairs = pairs.join(samplepairs, on=["profile_1", "profile_2"], how="anti").sort(["profile_1", "profile_2"])
        return remaining_pairs

    def is_complete(self) -> bool:
        """
        Check if the comparison database is complete, i.e., if all pairs of profiles in the profile database have been compared.
        """
        return self.get_remaining_pairs().collect(engine="streaming").is_empty()

    def add_comp_database(self, comp_database: GenomeComparisonDatabase) -> None:
        """Merge the provided comparison database into the current database.

        Args:
            comp_database (ComparisonDatabase): The comparison database to merge.

        Raises:
            ValueError: If the provided database is invalid or its configuration is incompatible.
        """
        try:
            comp_database._validate_db()
        except Exception as e:
            raise ValueError(f"The comparison database provided is not valid: {e}") from e

        if not self.config.is_compatible(comp_database.config):
            raise ValueError("The comparison database provided is not compatible with the current comparison database.")

        self._comp_db = pl.concat([self._comp_db, comp_database.comp_db]).unique()
        self.config = self.config.get_maximal_scope_config(comp_database.config)

    def save_new_compare_database(self, output_path: str) -> None:
        """Save the database to a parquet file.

        Missing parent directories of ``output_path`` are created.

        Args:
            output_path (str): The path to save the parquet file to.

        Raises:
            ValueError: If ``output_path`` refers to the current database file.
        """
        output_path = pathlib.Path(output_path)

        # The new database must be written to a new location. Resolve both sides
        # so that "..", "." and symlinks cannot disguise the same file.
        if self.comp_db_loc is not None and self.comp_db_loc.resolve() == output_path.resolve():
            raise ValueError("The output path must be different from the current database location.")

        output_path.parent.mkdir(parents=True, exist_ok=True)
        self.comp_db.sink_parquet(output_path)

    def update_compare_database(self) -> None:
        """Overwrites the comparison database saved on the disk with the current comparison database object.

        The table is written to a temporary file next to ``comp_db_loc`` and
        moved into place; afterwards the object re-scans the updated file.

        Raises:
            Exception: If ``comp_db_loc`` is not set or the write fails.
        """
        if self.comp_db_loc is None:
            raise Exception("comp_db_loc attribute is not determined yet!")
        tmp_path = None
        try:
            # NOTE(review): tempfile.mktemp is deprecated; kept because it only
            # reserves a name, so the written file gets default permissions.
            tmp_path = pathlib.Path(tempfile.mktemp(suffix=".parquet", prefix="tmp_comp_db_", dir=str(self.comp_db_loc.parent)))
            self.comp_db.sink_parquet(tmp_path)
            os.replace(tmp_path, self.comp_db_loc)
            self._comp_db = pl.scan_parquet(self.comp_db_loc)
        except Exception as e:
            # Do not leave a partial temporary file behind.
            if tmp_path is not None:
                try:
                    tmp_path.unlink()
                except OSError:
                    pass
            raise Exception(f"Something went wrong when updating the comparison database:{e}") from e

    def dump_obj(self, output_path: str) -> None:
        """Dump the current object to a json file.

        The file stores the absolute profile and comparison database locations
        (or null when unset) together with the configuration dictionary.

        Args:
            output_path (str): The path to save the json file to.
        """
        profile_loc = self.profile_db.db_loc
        obj_dict = {
            "profile_db_loc": str(profile_loc.absolute()) if profile_loc is not None else None,
            "config": self.config.to_dict(),
            "comp_db_loc": str(self.comp_db_loc.absolute()) if self.comp_db_loc is not None else None
        }
        with open(output_path, "w") as f:
            json.dump(obj_dict, f, indent=4)

    @classmethod
    def load_obj(cls, json_path: str) -> GenomeComparisonDatabase:
        """Load a GenomeComparisonDatabase object from a json file.

        Args:
            json_path (str): The path to the json file.

        Returns:
            GenomeComparisonDatabase: The loaded GenomeComparisonDatabase object.
        """
        with open(json_path, "r") as f:
            obj_dict = json.load(f)

        return cls(profile_db=ProfileDatabase(db_loc=obj_dict["profile_db_loc"]),
                   config=GenomeComparisonConfig(**obj_dict["config"]),
                   comp_db_loc=obj_dict["comp_db_loc"])

    def to_complete_input_table(self) -> pl.LazyFrame:
        """This method gives a table of all pairwise comparisons that is needed to make the comparison database complete. The table contains the following columns:

        - sample_name_1

        - sample_name_2

        - profile_location_1

        - scaffold_location_1

        - profile_location_2

        - scaffold_location_2

        Returns:
            pl.LazyFrame: The table of all pairwise comparisons needed to complete the comparison database.
        """
        lf = self.get_remaining_pairs().rename({"profile_1": "sample_name_1", "profile_2": "sample_name_2"})
        locations = self.profile_db.db.select(["profile_name", "profile_location", "scaffold_location"])
        # Attach the locations once per side of the pair, renaming after each join.
        return (lf.join(locations, left_on="sample_name_1", right_on="profile_name", how="left")
                .rename({"profile_location": "profile_location_1", "scaffold_location": "scaffold_location_1"})
                .join(locations, left_on="sample_name_2", right_on="profile_name", how="left")
                .rename({"profile_location": "profile_location_2", "scaffold_location": "scaffold_location_2"})
                )
641
+
642
+
643
+
644
+ class GeneComparisonDatabase:
645
+ """
646
+ GeneComparisonDatabase object holds a reference to a gene comparison parquet file. The methods in this class serve to provide
647
+ functionality for working with the gene comparison data in an easy and efficient manner.
648
+ The comparison parquet file is the result of running gene-level comparisons, and optionally concatenating multiple compare parquet files from single comparisons.
649
+ This parquet file must contain the following columns:
650
+
651
+ - genome
652
+ - gene
653
+ - total_positions
654
+ - share_allele_pos
655
+ - ani
656
+ - sample_1
657
+ - sample_2
658
+
659
+ A GeneComparisonDatabase object needs a GeneComparisonConfig object to specify the parameters used for the comparison.
660
+
661
+ Args:
662
+ profile_db (ProfileDatabase): The profile database used for the comparison.
663
+ config (GeneComparisonConfig): The gene comparison configuration used for the comparison.
664
+ comp_db_loc (str|None): The location of the comparison database parquet file. If
665
+ None, an empty comparison database is created.
666
+ """
667
+ COLUMN_NAMES = [
668
+ "genome",
669
+ "gene",
670
+ "total_positions",
671
+ "share_allele_pos",
672
+ "ani",
673
+ "sample_1",
674
+ "sample_2"
675
+ ]
676
+
677
+ def __init__(self,
678
+ profile_db: ProfileDatabase,
679
+ config: GeneComparisonConfig,
680
+ comp_db_loc: str|None = None,
681
+ ):
682
+ self.profile_db = profile_db
683
+ self.config = config
684
+ if comp_db_loc is not None:
685
+ self.comp_db_loc = pathlib.Path(comp_db_loc)
686
+ self._comp_db = pl.scan_parquet(self.comp_db_loc)
687
+ else:
688
+ self.comp_db_loc = None
689
+ self._comp_db = pl.LazyFrame({
690
+ "genome": [],
691
+ "gene": [],
692
+ "total_positions": [],
693
+ "share_allele_pos": [],
694
+ "ani": [],
695
+ "sample_1": [],
696
+ "sample_2": []
697
+ }, schema={
698
+ "genome": pl.Utf8,
699
+ "gene": pl.Utf8,
700
+ "total_positions": pl.Int64,
701
+ "share_allele_pos": pl.Int64,
702
+ "ani": pl.Float64,
703
+ "sample_1": pl.Utf8,
704
+ "sample_2": pl.Utf8
705
+ })
706
+ self.comp_db_loc = None
707
+
708
+ @property
709
+ def comp_db(self):
710
+ return self._comp_db
711
+
712
+ def _validate_db(self)->None:
713
+ """Validate the gene comparison database structure and content."""
714
+ self.profile_db._validate_db()
715
+
716
+ if set(self._comp_db.collect_schema()) != set(self.COLUMN_NAMES):
717
+ raise ValueError(f"Your comparison database must provide these extra columns: {set(self.COLUMN_NAMES) - set(self._comp_db.collect_schema())}")
718
+
719
+ # Check if all profile names exist in the profile database
720
+ profile_names_in_comp_db = set(self.get_all_profile_names())
721
+ profile_names_in_profile_db = set(self.profile_db.db.select("profile_name").collect(engine="streaming").to_series().to_list())
722
+ if not profile_names_in_comp_db.issubset(profile_names_in_profile_db):
723
+ raise ValueError(f"The following profile names are in the comparison database but not in the profile database: {profile_names_in_comp_db - profile_names_in_profile_db}")
724
+
725
+ def get_all_profile_names(self) -> set[str]:
726
+ """
727
+ Get all profile names that are in the comparison database.
728
+
729
+ Returns:
730
+ set[str]: Set of all profile names in the comparison database.
731
+ """
732
+ return set(self.comp_db.select(pl.col("sample_1")).collect(engine="streaming").to_series().to_list()).union(
733
+ set(self.comp_db.select(pl.col("sample_2")).collect(engine="streaming").to_series().to_list())
734
+ )
735
+
736
def get_remaining_pairs(self) -> pl.LazyFrame:
    """
    Build the profile pairs that still lack an entry in the comparison database.

    Returns:
        pl.LazyFrame: LazyFrame with columns profile_1 and profile_2 containing remaining pairs.
    """
    names = self.profile_db.db.select("profile_name")

    # Every unordered pair of distinct profiles, kept once with profile_1 < profile_2.
    candidate_pairs = (
        names.join(names, how="cross")
        .rename({"profile_name": "profile_1", "profile_name_right": "profile_2"})
        .filter(pl.col("profile_1") < pl.col("profile_2"))
    )

    # Pairs already compared, re-ordered so the smaller name is profile_1
    # whichever of sample_1 / sample_2 it was stored in.
    compared_pairs = (
        self.comp_db.group_by("sample_1", "sample_2")
        .agg()
        .with_columns(
            pl.min_horizontal(["sample_1", "sample_2"]).alias("profile_1"),
            pl.max_horizontal(["sample_1", "sample_2"]).alias("profile_2"),
        )
        .select(["profile_1", "profile_2"])
    )

    key_columns = ["profile_1", "profile_2"]
    return candidate_pairs.join(compared_pairs, on=key_columns, how="anti").sort(key_columns)
752
+
753
def is_complete(self) -> bool:
    """
    Tell whether every pair of profiles in the profile database has been compared.

    Returns:
        bool: True if all pairs have been compared, False otherwise.
    """
    outstanding = self.get_remaining_pairs()
    materialised = outstanding.collect(engine="streaming")
    # No remaining pair means the comparison database is complete.
    return materialised.is_empty()
761
+
762
def add_comp_database(self, comp_database: GeneComparisonDatabase) -> None:
    """Merge the provided gene comparison database into the current database.

    The other database is validated and checked for configuration
    compatibility before anything is modified. On success the two comparison
    tables are concatenated and de-duplicated, and ``config`` is replaced by
    the result of ``get_maximal_scope_config`` for the two configurations.

    Args:
        comp_database (GeneComparisonDatabase): The gene comparison database to merge.

    Raises:
        ValueError: If the provided database is invalid or incompatible. For a
            validation failure the original error is attached as ``__cause__``.
    """
    try:
        comp_database._validate_db()
    except Exception as e:
        # Chain explicitly so the underlying validation error stays attached.
        raise ValueError(f"The comparison database provided is not valid: {e}") from e

    if not self.config.is_compatible(comp_database.config):
        raise ValueError("The comparison database provided is not compatible with the current comparison database.")

    self._comp_db = pl.concat([self._comp_db, comp_database.comp_db]).unique()
    self.config = self.config.get_maximal_scope_config(comp_database.config)
781
+
782
def save_new_compare_database(self, output_path: str) -> None:
    """Save the database to a parquet file at a new location.

    Missing parent directories of ``output_path`` are created, but only after
    the path has been accepted.

    Args:
        output_path (str): The path to save the parquet file to.

    Raises:
        ValueError: If the output path is the same as the current database location.
    """
    output_path = pathlib.Path(output_path)

    # The new database must be written to a new location. Resolved paths are
    # compared so that aliases of the current file ("..", symlinks, relative
    # spellings) are caught too.
    # NOTE(review): presumably the guard exists because comp_db may still be
    # lazily reading from comp_db_loc -- confirm against the constructor.
    if self.comp_db_loc is not None and pathlib.Path(self.comp_db_loc).resolve() == output_path.resolve():
        raise ValueError("The output path must be different from the current database location.")

    # Create directories only once the path is known to be acceptable.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    self.comp_db.sink_parquet(output_path)
799
+
800
def update_compare_database(self) -> None:
    """Overwrite the comparison database saved on disk with the current comparison database object.

    The current table is first written to a scratch file inside a temporary
    directory created next to ``comp_db_loc``, then moved over ``comp_db_loc``
    with ``os.replace``. Finally ``_comp_db`` is re-pointed at a fresh scan of
    the updated file. The temporary directory is removed whether or not the
    update succeeds.

    Raises:
        Exception: If comp_db_loc is not set or if update fails. For a failed
            update the underlying error is attached as ``__cause__``.
    """
    if self.comp_db_loc is None:
        raise Exception("comp_db_loc attribute is not determined yet!")
    try:
        target = pathlib.Path(self.comp_db_loc)
        # TemporaryDirectory replaces the deprecated tempfile.mktemp: the
        # directory is created atomically and cleaned up on exit. It lives in
        # the target's parent directory, next to the file being replaced.
        with tempfile.TemporaryDirectory(prefix="tmp_gene_comp_db_", dir=str(target.parent)) as scratch_dir:
            tmp_path = pathlib.Path(scratch_dir) / "comp_db.parquet"
            self.comp_db.sink_parquet(tmp_path)
            os.replace(tmp_path, target)
        self._comp_db = pl.scan_parquet(self.comp_db_loc)
    except Exception as e:
        raise Exception(f"Something went wrong when updating the comparison database: {e}") from e
815
+
816
def dump_obj(self, output_path: str) -> None:
    """Serialise the state needed to rebuild this object into a json file.

    The file holds three keys: ``profile_db_loc`` and ``comp_db_loc`` (absolute
    paths as strings, or null when the location is unset) and ``config`` (the
    result of ``config.to_dict()``).

    Args:
        output_path (str): The path to save the json file to.
    """
    state = {}

    if self.profile_db.db_loc is not None:
        state["profile_db_loc"] = str(self.profile_db.db_loc.absolute())
    else:
        state["profile_db_loc"] = None

    state["config"] = self.config.to_dict()

    if self.comp_db_loc is not None:
        state["comp_db_loc"] = str(self.comp_db_loc.absolute())
    else:
        state["comp_db_loc"] = None

    with open(output_path, "w") as handle:
        json.dump(state, handle, indent=4)
829
+
830
@classmethod
def load_obj(cls, json_path: str) -> GeneComparisonDatabase:
    """Rebuild a GeneComparisonDatabase from a json description.

    The json file must provide the keys ``profile_db_loc``, ``config`` and
    ``comp_db_loc``; they are fed to ``ProfileDatabase``,
    ``GeneComparisonConfig`` and the class constructor respectively.

    Args:
        json_path (str): The path to the json file.

    Returns:
        GeneComparisonDatabase: The loaded GeneComparisonDatabase object.
    """
    with open(json_path, "r") as handle:
        state = json.load(handle)

    profiles = ProfileDatabase(db_loc=state["profile_db_loc"])
    settings = GeneComparisonConfig(**state["config"])
    return cls(profile_db=profiles, config=settings, comp_db_loc=state["comp_db_loc"])
848
+
849
def to_complete_input_table(self) -> pl.LazyFrame:
    """Build the table of pairwise comparisons still needed to make the comparison database complete.

    The remaining pairs are annotated with the profile and scaffold locations
    of both samples, looked up in the profile database. The table contains
    the following columns:

    - sample_name_1
    - sample_name_2
    - profile_location_1
    - scaffold_location_1
    - profile_location_2
    - scaffold_location_2

    Returns:
        pl.LazyFrame: The table of all pairwise comparisons needed to complete the comparison database.
    """
    table = self.get_remaining_pairs().rename({"profile_1": "sample_name_1", "profile_2": "sample_name_2"})
    locations = self.profile_db.db.select(["profile_name", "profile_location", "scaffold_location"])

    # Attach the locations of sample 1, then of sample 2, renaming the looked-up
    # columns after each join so the second lookup does not collide with the first.
    for side in ("1", "2"):
        table = table.join(
            locations,
            left_on=f"sample_name_{side}",
            right_on="profile_name",
            how="left",
        ).rename({
            "profile_location": f"profile_location_{side}",
            "scaffold_location": f"scaffold_location_{side}",
        })
    return table
871
+