gffkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1001 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ agat_sp_complement_annotations.pl 的 Python 改写版(纯 Python,不调用 Perl)。
6
+
7
+ 功能概要:
8
+ 1. 读取一个参考注释文件(--ref);
9
+ 2. 依次读取一个或多个补充注释文件(--add);
10
+ 3. 按原脚本说明的规则,把可补充的一级特征(通常是 gene)加入参考注释;
11
+ 4. 输出合并后的 GFF3 文件。
12
+ 5. 可选地对指定染色体区间启用“主参考/补充来源互换”模式。
13
+ 6. 支持从 detect_bridge_merged_genes.py 产生的 suspicious.tsv 自动读取交换区间,
14
+ 并对区间边界自动扩展一定 bp(默认 100 bp)。
15
+
16
+ 说明:
17
+ - 原 Perl 脚本依赖 AGAT::AGAT 中的内部函数;这里给出的是不依赖 Perl/AGAT 的 Python 实现。
18
+ - 为了尽可能接近原始行为,脚本支持常见的 GFF3 / GTF 解析,并按描述中的规则做补充。
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import argparse
24
+ import copy
25
+ import re
26
+ import sys
27
+ from collections import defaultdict
28
+ from dataclasses import dataclass, field
29
+ from typing import Dict, Iterable, List, Optional, Tuple
30
+
31
+
32
# Feature types recognised as level-1 (gene-like) records.
GENE_LIKE_TYPES = set(
    (
        "gene",
        "pseudogene",
        "nc_gene",
        "lnc_RNA_gene",
        "miRNA_gene",
        "snRNA_gene",
        "snoRNA_gene",
        "tRNA_gene",
        "rRNA_gene",
    )
)

# Feature types recognised as level-2 (transcript-like) records.
TRANSCRIPT_LIKE_TYPES = set(
    (
        "mRNA",
        "transcript",
        "lnc_RNA",
        "ncRNA",
        "miRNA",
        "snRNA",
        "snoRNA",
        "tRNA",
        "rRNA",
        "pre_miRNA",
        "pseudogenic_transcript",
    )
)

# Everything that is NOT a level-3 (exon/CDS/UTR/...) feature.
NON_L3_TYPES = GENE_LIKE_TYPES.union(TRANSCRIPT_LIKE_TYPES)

# Per-prefix counters used by make_synthetic_id() to number generated IDs.
_SYNTHETIC_COUNTER = dict.fromkeys(("gene", "transcript"), 0)
48
+
49
+
50
@dataclass(eq=True, frozen=True)
class SwapRegion:
    """
    A genomic interval in which the roles of the two inputs are swapped.

    Inside the interval the ``-a`` file acts as the primary reference and the
    ``-r`` file as the supplement; outside it the normal roles apply (``-r``
    primary, ``-a`` supplement).

    A level-1 feature (usually a gene) counts as "inside" as soon as it
    overlaps the interval at all, so features straddling a boundary are
    treated as inside.
    """

    # Sequence / chromosome name the interval lies on.
    seqid: str
    # Interval start coordinate (inclusive).
    start: int
    # Interval end coordinate (inclusive).
    end: int
66
+
67
+
68
def eprint(*args, **kwargs):
    """Print like ``print()`` but to standard error (used for log output)."""
    stream = sys.stderr
    print(*args, **kwargs, file=stream)
71
+
72
+
73
def overlaps_1d(a_start: int, a_end: int, b_start: int, b_end: int) -> bool:
    """Return True when closed intervals [a_start, a_end] and [b_start, b_end] share a point."""
    # Each interval must start no later than the other one ends.
    return a_start <= b_end and b_start <= a_end
76
+
77
+
78
def strand_compatible(a: str, b: str) -> bool:
    """
    Return True when two strand values may be compared.

    Identical strands are compatible, and an unknown strand ('.') on either
    side is compatible with anything.
    """
    return "." in (a, b) or a == b
85
+
86
+
87
def merge_intervals(intervals: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """
    Merge overlapping or directly adjacent closed intervals.

    Used so that CDS segments repeated across isoforms are not counted twice.
    Returns a new, start-sorted list of (start, end) tuples; [] for no input.
    """
    merged: List[Tuple[int, int]] = []
    for lo, hi in sorted(intervals):
        # Touching intervals (next start == previous end + 1) are fused too.
        if merged and lo <= merged[-1][1] + 1:
            prev_lo, prev_hi = merged[-1]
            merged[-1] = (prev_lo, max(prev_hi, hi))
        else:
            merged.append((lo, hi))
    return merged
99
+
100
+
101
def parse_gff3_attributes(attr_text: str) -> Dict[str, str]:
    """
    Parse a GFF3 column-9 attribute string into an ordered dict.

    Items are separated by ``;`` and written as ``key=value``; a bare token
    without ``=`` is stored with an empty-string value. Keys and values are
    whitespace-stripped but otherwise kept verbatim (no percent-decoding), so
    they are written back unchanged.

    Empty items are ignored, including the whitespace left behind by a
    trailing ``"; "``. A column consisting only of the GFF3 placeholder ``.``
    yields no attributes.
    """
    attrs: Dict[str, str] = {}
    text = attr_text.strip()
    if not text or text == ".":
        return attrs
    for item in text.split(";"):
        # Strip before the emptiness test so " " (e.g. from "ID=x; ") does
        # not turn into a bogus empty-string key.
        item = item.strip()
        if not item:
            continue
        if "=" in item:
            key, value = item.split("=", 1)
            attrs[key.strip()] = value.strip()
        else:
            attrs[item] = ""
    return attrs
115
+
116
+
117
# key "quoted value"  |  key unquoted-value   (items separated by ';').
# Keys may not contain whitespace or ';' so the separator can never be
# mistaken for a key.
_GTF_ATTR_RE = re.compile(r'([^\s;]+)\s+(?:"([^"]*)"|([^;]+))')


def parse_gtf_attributes(attr_text: str) -> Dict[str, str]:
    """
    Parse a GTF column-9 attribute string (``key "value"; key "value";``).

    Quoted values may contain ``;``. Unquoted values (a common deviation)
    run up to the next ``;`` and are stripped of whitespace and stray quotes.
    When a key repeats, the last value wins.
    """
    attrs: Dict[str, str] = {}
    for match in _GTF_ATTR_RE.finditer(attr_text):
        key, quoted, bare = match.groups()
        if quoted is not None:
            attrs[key] = quoted
        else:
            attrs[key] = bare.strip().strip('"')
    return attrs
127
+
128
+
129
def detect_format(path: str) -> str:
    """
    Guess whether an annotation file is GFF3 or GTF from its content.

    A ``##gff-version 3`` directive settles it immediately. Otherwise the
    first 9-column record that actually carries attributes decides:
    - a ``=`` in column 9 means GFF3 (``key=value``);
    - anything else means GTF (``key "value";``).

    Records whose column 9 is empty or ``.`` are skipped because they give no
    signal either way. Scanning stops at a ``##FASTA`` section.

    Returns:
        ``'gff3'`` or ``'gtf'``. Returns ``'gff3'`` when nothing decisive is
        found.
    """
    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if not line:
                continue
            if line.startswith("##gff-version"):
                fields = line.split()
                if len(fields) > 1 and fields[1].startswith("3"):
                    return "gff3"
                continue
            if line.startswith("##FASTA"):
                break
            if line.startswith("#"):
                continue
            parts = line.split("\t")
            if len(parts) != 9:
                continue
            attr_text = parts[8].strip()
            if not attr_text or attr_text == ".":
                # e.g. a GFF3 "region" line with no attributes: undecidable.
                continue
            return "gff3" if "=" in attr_text else "gtf"
    return "gff3"
147
+
148
+
149
@dataclass
class Feature:
    """
    A single annotation record (gene, transcript, exon, CDS, ...) together
    with its place in the parent/child feature tree.

    The first eight fields mirror GFF columns 1-8 (coordinates as int).
    ``attrs`` holds column 9 as an ordered key -> value mapping.
    """

    seqid: str
    source: str
    feature_type: str
    start: int
    end: int
    score: str
    strand: str
    phase: str
    attrs: Dict[str, str] = field(default_factory=dict)
    children: List["Feature"] = field(default_factory=list)
    # Upward link to the tree parent. It is kept out of repr() and ==: it
    # points back up the tree, so comparing two distinct but equal trees would
    # otherwise recurse parent -> children -> parent without end.
    parent: Optional["Feature"] = field(default=None, repr=False, compare=False)

    @property
    def id(self) -> Optional[str]:
        """Value of the ``ID`` attribute, or None when absent."""
        return self.attrs.get("ID")

    @id.setter
    def id(self, value: str) -> None:
        self.attrs["ID"] = value

    @property
    def name(self) -> Optional[str]:
        """Value of the ``Name`` attribute, or None when absent."""
        return self.attrs.get("Name")

    @name.setter
    def name(self, value: str) -> None:
        self.attrs["Name"] = value

    @property
    def length(self) -> int:
        """Span length, counting both ends (end - start + 1)."""
        return self.end - self.start + 1

    def add_child(self, child: "Feature") -> None:
        """
        Append ``child`` and point its ``parent`` link at this node.

        A child attached to several parents (GFF3 multi-parent features) is
        shared between them. Its ``parent`` link records the most recent one.
        """
        child.parent = self
        self.children.append(child)

    def iter_descendants(self) -> Iterable["Feature"]:
        """Yield all descendants depth-first (pre-order), excluding self."""
        for child in self.children:
            yield child
            yield from child.iter_descendants()

    def iter_all(self) -> Iterable["Feature"]:
        """Yield this node followed by all of its descendants."""
        yield self
        yield from self.iter_descendants()

    def deep_clone(self) -> "Feature":
        """Return an independent deep copy of this subtree, detached (parent None)."""
        # Pre-seed the deepcopy memo so this node's upward ``parent`` link is
        # copied as None. Otherwise the ancestors, and through them the whole
        # enclosing tree, would be dragged into the copy.
        memo = {id(self.parent): None} if self.parent is not None else {}
        cloned = copy.deepcopy(self, memo)

        def fix_parent(node: "Feature", parent: Optional["Feature"] = None) -> None:
            node.parent = parent
            for c in node.children:
                fix_parent(c, node)

        fix_parent(cloned, None)
        return cloned

    def overlaps(self, other: "Feature") -> bool:
        """True when both are on the same seqid, strands are compatible and coordinates overlap."""
        return (
            self.seqid == other.seqid
            and strand_compatible(self.strand, other.strand)
            and overlaps_1d(self.start, self.end, other.start, other.end)
        )

    def to_gff3_lines(self) -> List[str]:
        """
        Render this feature and its subtree as GFF3 lines (no newlines).

        Output is always GFF3, even for GTF input. ``Parent`` is rebuilt from
        the tree rather than copied from ``attrs``. Children are written
        sorted by (start, end, type, ID). A child shared by several parents
        is written once under each parent, each copy naming that parent.
        """
        return self._gff3_lines_under(self.parent)

    def _gff3_lines_under(self, parent: Optional["Feature"]) -> List[str]:
        # ``parent`` is the node this copy of the feature is being written
        # under. A multi-parent child is reachable from several parents but its
        # ``parent`` link only remembers the last one. So the Parent value must
        # come from the traversal, not from ``self.parent``.
        attrs = dict(self.attrs)
        if parent is not None and parent.id:
            attrs["Parent"] = parent.id
        else:
            attrs.pop("Parent", None)

        attr_text = ";".join(
            f"{k}={v}" if v != "" else k
            for k, v in attrs.items()
        )
        row = [
            self.seqid,
            self.source,
            self.feature_type,
            str(self.start),
            str(self.end),
            self.score,
            self.strand,
            self.phase,
            attr_text if attr_text else ".",
        ]
        lines = ["\t".join(row)]

        # Children in a stable coordinate/type order.
        for child in sorted(
            self.children,
            key=lambda x: (x.start, x.end, x.feature_type, x.id or "")
        ):
            lines.extend(child._gff3_lines_under(self))
        return lines
255
+
256
+
257
@dataclass
class AnnotationSet:
    """
    A whole annotation held as a flat list of top-level features.

    Each root is the top of its own feature tree, usually a gene.
    """

    roots: List[Feature] = field(default_factory=list)

    def all_ids(self) -> set:
        """Every non-empty ``ID`` found anywhere in the trees."""
        return {node.id for top in self.roots for node in top.iter_all() if node.id}

    def all_names(self) -> set:
        """Every non-empty ``Name`` found anywhere in the trees."""
        return {node.name for top in self.roots for node in top.iter_all() if node.name}

    def level_counts(self) -> Dict[str, Dict[str, int]]:
        """
        Count feature types per level (quick-stat style summary).

        level1 counts the roots; level2 counts their direct children.
        """
        level1: Dict[str, int] = defaultdict(int)
        level2: Dict[str, int] = defaultdict(int)
        for top in self.roots:
            level1[top.feature_type] += 1
            for kid in top.children:
                level2[kid.feature_type] += 1
        return {"level1": level1, "level2": level2}

    def info(self) -> None:
        """Log the level1/level2 type counts to stderr."""
        summary = self.level_counts()
        eprint("Current annotation summary:")
        for level_name in ("level1", "level2"):
            eprint(f" {level_name}:")
            per_type = summary[level_name]
            for type_name in sorted(per_type):
                eprint(f" {type_name}: {per_type[type_name]}")

    def write_gff3(self, output_path: Optional[str]) -> None:
        """
        Write the annotation as GFF3.

        The destination is ``output_path``, or stdout when it is falsy. Roots
        are sorted by (seqid, start, end, type, ID).
        """
        owns_handle = bool(output_path)
        handle = open(output_path, "w", encoding="utf-8") if owns_handle else sys.stdout
        try:
            print("##gff-version 3", file=handle)
            ordered = sorted(
                self.roots,
                key=lambda f: (f.seqid, f.start, f.end, f.feature_type, f.id or ""),
            )
            for top in ordered:
                for text in top.to_gff3_lines():
                    print(text, file=handle)
        finally:
            # Never close stdout; only a file we opened ourselves.
            if owns_handle:
                handle.close()
315
+
316
+
317
def make_synthetic_id(prefix: str) -> str:
    """
    Return the next generated ID for ``prefix`` ("gene" or "transcript").

    These IDs stand in for gene/transcript identifiers missing from GTF
    input. Numbering is sequential per prefix within the process.
    """
    serial = _SYNTHETIC_COUNTER[prefix] + 1
    _SYNTHETIC_COUNTER[prefix] = serial
    return f"synthetic_{prefix}_{serial}"
321
+
322
+
323
def parse_gff3(path: str) -> AnnotationSet:
    """
    Parse a GFF3 file into feature trees linked by ID/Parent.

    Behaviour:
    - Parents may appear after their children; linking is done in a second
      pass.
    - A feature listing several parents is shared by all of them.
    - A feature whose parents are all unknown is kept as a root rather than
      dropped.
    - Reading stops at a ``##FASTA`` directive (embedded sequence).
    - Lines sharing an ID are accepted when they have the same type and
      seqid (GFF3 multi-line features, e.g. split CDS). The first such line is
      the one children link to.

    Raises:
        ValueError: for a record without exactly 9 tab-separated columns, or
            when one ID is reused by features of a different type/seqid.
    """
    features_by_id: Dict[str, Feature] = {}
    all_features: List[Tuple[Feature, List[str]]] = []

    with open(path, "r", encoding="utf-8") as fh:
        for line_num, line in enumerate(fh, start=1):
            line = line.rstrip("\r\n")
            if line.startswith("##FASTA"):
                # Everything after this directive is sequence, not features.
                break
            if not line.strip() or line.startswith("#"):
                continue

            parts = line.split("\t")
            if len(parts) != 9:
                raise ValueError(f"{path}:{line_num} 不是合法的 9 列 GFF/GTF 记录")

            seqid, source, ftype, start, end, score, strand, phase, attr_text = parts
            attrs = parse_gff3_attributes(attr_text)
            parent_ids = attrs.get("Parent", "")
            # De-duplicate while keeping order so "Parent=t1,t1" attaches once.
            parent_list = list(dict.fromkeys(
                x.strip() for x in parent_ids.split(",") if x.strip()
            )) if parent_ids else []

            feat = Feature(
                seqid=seqid,
                source=source,
                feature_type=ftype,
                start=int(start),
                end=int(end),
                score=score,
                strand=strand,
                phase=phase,
                attrs=attrs,
            )

            if feat.id:
                first = features_by_id.get(feat.id)
                if first is None:
                    features_by_id[feat.id] = feat
                elif first.feature_type != feat.feature_type or first.seqid != feat.seqid:
                    raise ValueError(f"{path}:{line_num} 出现重复 ID:{feat.id}")
                # Otherwise: another line of a multi-line feature; keep it as
                # a separate segment under the same parent(s).

            all_features.append((feat, parent_list))

    roots: List[Feature] = []

    # Second pass: link children to parents.
    for feat, parent_ids in all_features:
        if not parent_ids:
            roots.append(feat)
            continue

        attached = False
        for pid in parent_ids:
            parent = features_by_id.get(pid)
            if parent is None:
                continue
            parent.add_child(feat)
            attached = True

        # Parent(s) missing from the file: keep as a root to avoid data loss.
        if not attached:
            roots.append(feat)

    return AnnotationSet(roots=roots)
383
+
384
+
385
def parse_gtf(path: str) -> AnnotationSet:
    """
    Parse a GTF file into the same gene -> transcript -> sub-feature tree used
    for GFF3 input.

    GTF has no ID/Parent links, so the hierarchy is rebuilt from the
    ``gene_id`` / ``transcript_id`` attributes:
    - Genes and transcripts never declared by a line of their own are
      synthesised as placeholders that grow to span their children.
    - When the real gene/transcript record turns up after its placeholder,
      its attributes are merged in (existing keys such as ID/Name are kept).
    - A line lacking ``gene_id`` reuses the gene that already owns its
      ``transcript_id`` instead of spawning a new, empty synthetic gene.

    Returns:
        An AnnotationSet whose roots are the gene nodes in first-seen order.

    Raises:
        ValueError: if a non-comment line lacks exactly 9 tab-separated columns.
    """
    genes: Dict[str, Feature] = {}
    transcripts: Dict[str, Feature] = {}

    def grow(node: Feature, start_i: int, end_i: int) -> None:
        # Widen an existing gene/transcript so it spans a newly seen record.
        node.start = min(node.start, start_i)
        node.end = max(node.end, end_i)

    with open(path, "r", encoding="utf-8") as fh:
        for line_num, line in enumerate(fh, start=1):
            line = line.rstrip("\r\n")
            if not line.strip() or line.startswith("#"):
                continue

            parts = line.split("\t")
            if len(parts) != 9:
                raise ValueError(f"{path}:{line_num} 不是合法的 9 列 GFF/GTF 记录")

            seqid, source, ftype, start, end, score, strand, phase, attr_text = parts
            attrs = parse_gtf_attributes(attr_text)
            start_i = int(start)
            end_i = int(end)

            gene_id = attrs.get("gene_id") or attrs.get("geneID") or attrs.get("ID")
            transcript_id = attrs.get("transcript_id") or attrs.get("transcriptID")

            # Explicit gene record.
            if ftype in GENE_LIKE_TYPES:
                if not gene_id:
                    gene_id = make_synthetic_id("gene")
                gene = genes.get(gene_id)
                if gene is None:
                    gene_attrs = dict(attrs)
                    gene_attrs["ID"] = gene_id
                    gene_attrs.setdefault("Name", gene_id)
                    genes[gene_id] = Feature(
                        seqid=seqid,
                        source=source,
                        feature_type="gene",
                        start=start_i,
                        end=end_i,
                        score=score,
                        strand=strand,
                        phase=".",
                        attrs=gene_attrs,
                    )
                else:
                    grow(gene, start_i, end_i)
                    # The node may be a placeholder built from earlier child
                    # lines: keep the real record's attributes too.
                    for key, value in attrs.items():
                        gene.attrs.setdefault(key, value)
                continue

            # A line without gene_id can still be placed via its transcript.
            if not gene_id and transcript_id in transcripts:
                owner = transcripts[transcript_id].parent
                if owner is not None and owner.id:
                    gene_id = owner.id

            # Make sure the owning gene exists (placeholder if needed).
            if not gene_id:
                gene_id = make_synthetic_id("gene")
            gene = genes.get(gene_id)
            if gene is None:
                gene = Feature(
                    seqid=seqid,
                    source=source,
                    feature_type="gene",
                    start=start_i,
                    end=end_i,
                    score=".",
                    strand=strand,
                    phase=".",
                    attrs={"ID": gene_id, "Name": gene_id},
                )
                genes[gene_id] = gene
            else:
                grow(gene, start_i, end_i)

            # Explicit transcript record.
            if ftype in TRANSCRIPT_LIKE_TYPES:
                if not transcript_id:
                    transcript_id = make_synthetic_id("transcript")
                tx = transcripts.get(transcript_id)
                if tx is None:
                    tx_attrs = dict(attrs)
                    tx_attrs["ID"] = transcript_id
                    tx_attrs.setdefault("Name", transcript_id)
                    tx = Feature(
                        seqid=seqid,
                        source=source,
                        feature_type="transcript",
                        start=start_i,
                        end=end_i,
                        score=score,
                        strand=strand,
                        phase=".",
                        attrs=tx_attrs,
                    )
                    gene.add_child(tx)
                    transcripts[transcript_id] = tx
                else:
                    grow(tx, start_i, end_i)
                    for key, value in attrs.items():
                        tx.attrs.setdefault(key, value)
                continue

            # Remaining sub-features (exon / CDS / UTR / codons ...).
            if transcript_id:
                tx = transcripts.get(transcript_id)
                if tx is None:
                    tx = Feature(
                        seqid=seqid,
                        source=source,
                        feature_type="transcript",
                        start=start_i,
                        end=end_i,
                        score=".",
                        strand=strand,
                        phase=".",
                        attrs={"ID": transcript_id, "Name": transcript_id},
                    )
                    gene.add_child(tx)
                    transcripts[transcript_id] = tx
                else:
                    grow(tx, start_i, end_i)
                parent_node = tx
            else:
                parent_node = gene

            # Sub-features rarely carry an ID in GTF; build a unique one
            # (line number included) for output and later renaming.
            child_attrs = dict(attrs)
            child_attrs["ID"] = attrs.get("ID") or f"{ftype}_{seqid}_{start_i}_{end_i}_{strand}_{line_num}"
            parent_node.add_child(Feature(
                seqid=seqid,
                source=source,
                feature_type=ftype,
                start=start_i,
                end=end_i,
                score=score,
                strand=strand,
                phase=phase,
                attrs=child_attrs,
            ))

    return AnnotationSet(roots=list(genes.values()))
525
+
526
+
527
def parse_annotation_file(path: str) -> AnnotationSet:
    """Parse ``path`` as GTF or GFF3, whichever detect_format() reports."""
    return parse_gtf(path) if detect_format(path) == "gtf" else parse_gff3(path)
533
+
534
+
535
def get_cds_features(root: Feature) -> List[Feature]:
    """Collect every CDS feature in the tree rooted at ``root`` (root included)."""
    return list(filter(lambda node: node.feature_type == "CDS", root.iter_all()))
538
+
539
+
540
def get_l3_features(root: Feature) -> List[Feature]:
    """
    Return the level-3 descendants of ``root`` (everything that is not gene-
    or transcript-like).

    Used for the overlap test when neither side has a CDS.
    """
    return [
        node
        for node in root.iter_descendants()
        if node.feature_type not in NON_L3_TYPES
    ]
550
+
551
+
552
def gene_has_cds(root: Feature) -> bool:
    """Return True if the tree rooted at ``root`` contains at least one CDS."""
    for node in root.iter_all():
        if node.feature_type == "CDS":
            return True
    return False
555
+
556
+
557
def cds_size_nt(root: Feature) -> int:
    """
    Total CDS length in nucleotides for the tree rooted at ``root``.

    CDS segments are grouped by (seqid, strand) and merged first, so bases
    shared by several isoforms are counted once.
    """
    by_locus: Dict[Tuple[str, str], List[Tuple[int, int]]] = defaultdict(list)
    for cds in get_cds_features(root):
        by_locus[(cds.seqid, cds.strand)].append((cds.start, cds.end))
    return sum(
        hi - lo + 1
        for spans in by_locus.values()
        for lo, hi in merge_intervals(spans)
    )
571
+
572
+
573
def any_feature_overlap(features_a: List[Feature], features_b: List[Feature]) -> bool:
    """Return True as soon as some feature of ``features_a`` overlaps some feature of ``features_b``."""
    return any(fa.overlaps(fb) for fa in features_a for fb in features_b)
580
+
581
+
582
def should_add_root(add_root: Feature, ref_roots: List[Feature], size_min: int) -> bool:
    """
    Decide whether the level-1 feature ``add_root`` may join the reference.

    Rules:
    1. No overlap with any reference root: accept only if its total CDS
       length is >= ``size_min``.
    2. For every overlapping reference root:
       - exactly one side has a CDS: fine;
       - both have CDS: their CDS features must not overlap;
       - neither has CDS: their level-3 features must not overlap.
       A single violation rejects the whole level-1 feature, so one
       overlapping isoform blocks the entire gene.
    """
    conflicts = [ref for ref in ref_roots if add_root.overlaps(ref)]
    if not conflicts:
        return cds_size_nt(add_root) >= size_min

    add_is_coding = gene_has_cds(add_root)
    add_cds = get_cds_features(add_root)
    add_l3 = get_l3_features(add_root)

    for ref in conflicts:
        ref_is_coding = gene_has_cds(ref)
        if add_is_coding != ref_is_coding:
            # Coding vs non-coding: allowed to coexist.
            continue
        if add_is_coding:
            clash = any_feature_overlap(add_cds, get_cds_features(ref))
        else:
            clash = any_feature_overlap(add_l3, get_l3_features(ref))
        if clash:
            return False
    return True
633
+
634
+
635
def feature_overlaps_region(feature: Feature, region: SwapRegion) -> bool:
    """True if ``feature`` lies on ``region.seqid`` and overlaps its span (strand ignored)."""
    if feature.seqid != region.seqid:
        return False
    return overlaps_1d(feature.start, feature.end, region.start, region.end)
641
+
642
+
643
def feature_in_swap_regions(feature: Feature, swap_regions: List[SwapRegion]) -> bool:
    """True if ``feature`` overlaps at least one of ``swap_regions``."""
    for region in swap_regions:
        if feature_overlaps_region(feature, region):
            return True
    return False
646
+
647
+
648
def build_priority_sets_for_region_swap(
    ref_set: AnnotationSet,
    add_set: AnnotationSet,
    swap_regions: List[SwapRegion],
) -> Tuple[AnnotationSet, AnnotationSet, Dict[str, int]]:
    """
    Split the current ref/add annotations into a primary set and a
    supplemental set according to the swap regions.

    Assignment:
    - ref root outside every region -> primary
    - ref root inside a region      -> supplemental
    - add root inside a region      -> primary (as a renamed clone, see below)
    - add root outside every region -> supplemental

    Supplemental candidates are later checked against *all* primary roots,
    so a feature straddling a region boundary is still tested against the
    primary features on the other side.

    Add roots promoted to primary go straight into the output. Their IDs and
    Names are therefore de-duplicated against the ref roots kept as primary,
    the same way complement_annotations() treats supplemental roots.
    Otherwise two inputs that both use IDs such as "g1" would produce
    duplicate IDs. Clones are renamed, so ``add_set`` itself is not modified.

    Returns:
        (primary_set, supplemental_set, stats). ``stats`` counts ref/add
        roots inside/outside the regions, for logging.
    """
    primary_roots: List[Feature] = []
    supplemental_roots: List[Feature] = []

    stats = {
        "ref_inside": 0,
        "ref_outside": 0,
        "add_inside": 0,
        "add_outside": 0,
    }

    for root in ref_set.roots:
        if feature_in_swap_regions(root, swap_regions):
            supplemental_roots.append(root)
            stats["ref_inside"] += 1
        else:
            primary_roots.append(root)
            stats["ref_outside"] += 1

    kept_ref = AnnotationSet(roots=list(primary_roots))
    taken_ids = kept_ref.all_ids()
    taken_names = kept_ref.all_names()

    for root in add_set.roots:
        if feature_in_swap_regions(root, swap_regions):
            promoted = root.deep_clone()
            uniquify_feature_tree(promoted, taken_ids, taken_names)
            primary_roots.append(promoted)
            stats["add_inside"] += 1
        else:
            supplemental_roots.append(root)
            stats["add_outside"] += 1

    return AnnotationSet(roots=primary_roots), AnnotationSet(roots=supplemental_roots), stats
698
+
699
+
700
def uniquify_feature_tree(root: Feature, existing_ids: set, existing_names: set) -> None:
    """
    Rename IDs/Names in the tree rooted at ``root`` that collide with
    ``existing_ids`` / ``existing_names``, then register the final values.

    Collisions get the first free ``<value>_dup<N>`` suffix. Renaming is done
    per distinct value, so:
    - lines of one multi-line feature that share an ID keep sharing the
      (possibly renamed) ID;
    - a child reachable from several parents is processed only once.
    Neither case is treated as a conflict with itself.

    Both sets are updated in place; ``root`` is modified in place.
    """
    id_renames: Dict[str, str] = {}
    name_renames: Dict[str, str] = {}
    visited: set = set()

    def claim(value: str, taken: set, renames: Dict[str, str]) -> str:
        # Resolve each distinct original value once and reuse the answer.
        if value not in renames:
            new_value = value
            if new_value in taken:
                i = 1
                while f"{value}_dup{i}" in taken:
                    i += 1
                new_value = f"{value}_dup{i}"
            taken.add(new_value)
            renames[value] = new_value
        return renames[value]

    for feat in root.iter_all():
        # Shared (multi-parent) children are yielded once per parent.
        if id(feat) in visited:
            continue
        visited.add(id(feat))

        if feat.id:
            feat.id = claim(feat.id, existing_ids, id_renames)
        if feat.name:
            feat.name = claim(feat.name, existing_names, name_renames)
725
+
726
+
727
def complement_annotations(ref_set: AnnotationSet, add_set: AnnotationSet, size_min: int) -> int:
    """
    Add the eligible level-1 features of ``add_set`` to ``ref_set`` in place.

    Candidates are visited in (seqid, start, end, type, ID) order. Each
    accepted candidate:
    - is judged by should_add_root() against the reference roots, including
      those accepted earlier in this call;
    - is deep-cloned;
    - has its ID/Name conflicts resolved;
    - is appended to ``ref_set.roots``.

    Returns:
        Number of level-1 features added.
    """
    added_count = 0
    existing_ids = ref_set.all_ids()
    existing_names = ref_set.all_names()

    # should_add_root() can only match roots on the same seqid, so bucket the
    # reference once instead of scanning every root for every candidate.
    # Accepted roots join their bucket, so later candidates still see them.
    ref_by_seqid: Dict[str, List[Feature]] = defaultdict(list)
    for root in ref_set.roots:
        ref_by_seqid[root.seqid].append(root)

    for add_root in sorted(
        add_set.roots,
        key=lambda x: (x.seqid, x.start, x.end, x.feature_type, x.id or "")
    ):
        candidates = ref_by_seqid[add_root.seqid]
        if should_add_root(add_root, candidates, size_min):
            cloned = add_root.deep_clone()
            uniquify_feature_tree(cloned, existing_ids, existing_names)
            ref_set.roots.append(cloned)
            candidates.append(cloned)
            added_count += 1

    return added_count
747
+
748
+
749
def print_complement_resume(before_counts: Dict[str, Dict[str, int]],
                            after_counts: Dict[str, Dict[str, int]]) -> None:
    """
    Log, per level and feature type, how many features were added or removed.

    Removals are possible in swap mode: reference features inside a swap
    region can be displaced by the add file. Ends with either "Nothing has
    been changed" or a "Now the data contains:" header.
    """
    anything_changed = False
    for level in ("level1", "level2"):
        old = before_counts[level]
        new = after_counts[level]
        for tag in sorted(set(old).union(new)):
            delta = new.get(tag, 0) - old.get(tag, 0)
            if delta == 0:
                continue
            anything_changed = True
            if delta > 0:
                eprint(f"We added {delta} {tag}(s)")
            else:
                eprint(f"We removed {-delta} {tag}(s)")

    eprint("\nNow the data contains:" if anything_changed else "\nNothing has been changed")
778
+
779
+
780
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="用一个或多个注释文件去补充参考注释(Python 版,纯 Python,不调用 Perl)。"
    )
    add = parser.add_argument

    # Inputs.
    add("--ref", "-r", "-i", required=True, help="参考 GFF/GTF 文件")
    add("--add", "-a", action="append", required=True,
        help="用于补充参考注释的 GFF/GTF 文件;可重复指定多次")

    # Size filter for candidates that overlap nothing.
    add("--size_min", "-s", type=int, default=0,
        help="仅对“完全不重叠”的一级特征生效:其 CDS 总长度必须 >= 该值,默认 0")

    # Region swapping: manual regions, regions from a TSV, and TSV padding.
    swap_region_help = (
        "手动指定一个需要交换主次参考关系的区间。\n"
        "在该区间内:-a 文件视为主参考,-r 文件视为补充来源;\n"
        "在该区间外:仍按原规则处理。\n"
        "可重复指定多次,例如:--swap_region chr09 160889718 200000000"
    )
    add("--swap_region", action="append", nargs=3, metavar=("SEQID", "START", "END"),
        help=swap_region_help)
    swap_tsv_help = (
        "从 detect_bridge_merged_genes.py 产生的 suspicious.tsv 读取交换区间。"
        "会自动忽略第一行表头;优先按表头 chrom/start/end 读取;"
        "若无标准表头,则兼容按第2/3/4列读取(即 gene_id, chrom, start, end, ...)。"
    )
    add("--swap_region_tsv", help=swap_tsv_help)
    add("--swap_region_flank", type=int, default=100,
        help="从 suspicious.tsv 读取区间时,start/end 两端各扩展的 bp 数,默认 100")

    # Output and logging.
    add("--output", "--out", "-o", default=None, help="输出文件路径;默认输出到 STDOUT")
    add("-v", "--verbose", type=int, default=1,
        help="日志详细程度(0~4),这里只简单保留该参数接口,默认 1")
    return parser
823
+
824
+
825
def parse_swap_regions(raw_regions: Optional[List[List[str]]]) -> List[SwapRegion]:
    """
    Turn raw ``--swap_region SEQID START END`` triples into SwapRegion objects.

    START and END may be given in either order; they are normalised so that
    start <= end. Returns [] when no regions were given.

    Raises:
        ValueError: if START/END are not integers or are not positive.
    """
    regions: List[SwapRegion] = []
    for seqid, start_text, end_text in raw_regions or []:
        try:
            first, second = int(start_text), int(end_text)
        except ValueError as exc:
            raise ValueError(
                f"--swap_region 的 START/END 必须是整数:{seqid} {start_text} {end_text}"
            ) from exc

        if min(first, second) <= 0:
            raise ValueError(
                f"--swap_region 的坐标必须为正整数:{seqid} {first} {second}"
            )

        low, high = sorted((first, second))
        regions.append(SwapRegion(seqid=seqid, start=low, end=high))
    return regions
853
+
854
+
855
def parse_swap_regions_from_tsv(tsv_path: str, flank_bp: int = 100) -> List[SwapRegion]:
    """
    Read swap regions from the suspicious.tsv written by
    detect_bridge_merged_genes.py.

    The first line is always treated as a header and skipped. If it names
    ``chrom``, ``start`` and ``end`` columns (case-insensitive), those are
    used. Otherwise columns 2/3/4 are read (layout: gene_id, chrom, start,
    end, ...).

    Each region is padded by ``flank_bp`` on both sides, with start clamped
    to >= 1. Reversed start/end are swapped. Blank lines and lines starting
    with ``#`` are ignored; CRLF line endings are accepted.

    Raises:
        ValueError: if ``flank_bp`` is negative, or a data line cannot be
            parsed (missing columns, non-integer or non-positive
            coordinates).
    """
    if flank_bp < 0:
        raise ValueError("--swap_region_flank 不能小于 0")

    regions: List[SwapRegion] = []

    with open(tsv_path, "r", encoding="utf-8") as fh:
        header = fh.readline().rstrip("\r\n")
        if not header:
            return regions

        lower_header = [x.strip().lower() for x in header.split("\t")]
        has_named_columns = all(
            name in lower_header for name in ("chrom", "start", "end")
        )
        if has_named_columns:
            chrom_idx = lower_header.index("chrom")
            start_idx = lower_header.index("start")
            end_idx = lower_header.index("end")
        else:
            # Fallback layout: gene_id, chrom, start, end, ...
            chrom_idx, start_idx, end_idx = 1, 2, 3

        for line_num, line in enumerate(fh, start=2):
            line = line.rstrip("\r\n")
            if not line.strip() or line.startswith("#"):
                continue

            cols = line.split("\t")
            try:
                if not has_named_columns and len(cols) < 4:
                    raise ValueError("列数不足,至少需要 4 列")
                seqid = cols[chrom_idx].strip()
                start = int(cols[start_idx])
                end = int(cols[end_idx])
                if start <= 0 or end <= 0:
                    raise ValueError("start/end 必须为正整数")
            except (ValueError, IndexError) as exc:
                raise ValueError(
                    f"{tsv_path}:{line_num} 解析失败,请确认文件格式正确。"
                    f" 当前行内容为:{line}"
                ) from exc

            if start > end:
                start, end = end, start

            regions.append(SwapRegion(
                seqid=seqid,
                start=max(1, start - flank_bp),
                end=end + flank_bp,
            ))

    return regions
932
+
933
+
934
def main() -> int:
    """
    Command-line entry point.

    Steps:
    1. Parse options and swap regions.
    2. Load the reference annotation.
    3. Fold in each ``--add`` file in the order given (with region swapping
       when regions were configured).
    4. Log what changed to stderr.
    5. Write the merged GFF3.

    Returns:
        0 on success. Invalid options, and unreadable or malformed input
        files, terminate through ``parser.error`` (exit status 2) rather than
        with a traceback.
    """
    parser = build_arg_parser()
    args = parser.parse_args()

    def load(path: str) -> AnnotationSet:
        # Report unreadable / malformed annotation files as a CLI error,
        # the same way invalid swap-region input is reported below.
        try:
            return parse_annotation_file(path)
        except (OSError, ValueError) as exc:
            parser.error(str(exc))
            raise  # not reached: parser.error() exits

    try:
        swap_regions = parse_swap_regions(args.swap_region)

        # Regions listed in a suspicious.tsv are added to the manual ones.
        if args.swap_region_tsv:
            swap_regions_from_tsv = parse_swap_regions_from_tsv(
                args.swap_region_tsv,
                flank_bp=args.swap_region_flank
            )
            swap_regions.extend(swap_regions_from_tsv)

    except (OSError, ValueError) as exc:
        parser.error(str(exc))

    # 1) Reference annotation.
    ref_set = load(args.ref)
    eprint(f"{args.ref} parsed")
    ref_set.info()

    if swap_regions:
        eprint("Configured swap regions:")
        for region in swap_regions:
            eprint(f" - {region.seqid}:{region.start}-{region.end}")

    # 2) Complement with each --add file, in the order given.
    for next_file in args.add:
        add_set = load(next_file)
        eprint(f"{next_file} parsed")
        add_set.info()

        before_counts = ref_set.level_counts()

        if swap_regions:
            primary_set, supplemental_set, region_stats = build_priority_sets_for_region_swap(
                ref_set, add_set, swap_regions
            )
            eprint(
                "\nRegion swap summary: "
                f"ref(outside->primary)={region_stats['ref_outside']}, "
                f"ref(inside->supplement)={region_stats['ref_inside']}, "
                f"add(inside->primary)={region_stats['add_inside']}, "
                f"add(outside->supplement)={region_stats['add_outside']}"
            )

            added_count = complement_annotations(primary_set, supplemental_set, args.size_min)
            # The merged primary set becomes the reference for the next file.
            ref_set = primary_set
        else:
            added_count = complement_annotations(ref_set, add_set, args.size_min)

        eprint("\nComplement done !")

        after_counts = ref_set.level_counts()
        print_complement_resume(before_counts, after_counts)

        if added_count > 0 or before_counts != after_counts:
            ref_set.info()

    # 3) Write the result.
    ref_set.write_gff3(args.output)
    return 0
998
+
999
+
1000
if __name__ == "__main__":
    # Run the CLI and hand its return value to the interpreter as exit status.
    raise SystemExit(main())