hidt 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hidt-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Li Junping <lijunping02@qq.com>
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
hidt-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.1
2
+ Name: hidt
3
+ Version: 0.1.0
4
+ Summary: A computational pipeline for identifying differential TADs from 3D genome contact maps
5
+ Home-page: https://github.com/GaoLabXDU/HiDT
6
+ Author: Li Junping
7
+ Author-email: lijunping02@qq.com
8
+ License: MIT Licence
9
+ Keywords: 3D genome,Comparative analysis,Topologically associating domains
10
+ Platform: any
11
+ License-File: LICENSE
@@ -0,0 +1,11 @@
1
# -*- coding: utf-8 -*-
"""
HiDT package initializer.

Created on June 29 2025

@author: Li Junping
"""
__author__ = 'Li Junping'
__version__ = '0.1.0'

# Absolute path of this module file.
# NOTE(review): purpose unclear — presumably used to locate packaged files
# (e.g. the pretrained model); confirm against callers.
Me = __file__
11
+
@@ -0,0 +1,12 @@
1
+ import torch
2
+ from sklearn import metrics
3
+ import numpy as np
4
+
5
def euclidean_distance(x, y):
    """Return the squared Euclidean distance along the last dimension."""
    diff = x - y
    return torch.sum(diff * diff, dim=-1)
8
+
9
def compute_similarity(x, y):
    """Similarity between x and y: the negated squared Euclidean distance.

    Larger (less negative) values mean the vectors are closer.
    """
    return -euclidean_distance(x, y)
@@ -0,0 +1,72 @@
1
+ import numpy as np
2
+ import hicstraw
3
+ from scipy.sparse import coo_matrix
4
+
5
def load_binsNum(hicfile, res):
    """Return {chromosome name: number of bins} for a .hic file.

    Chromosome lengths are read from the .hic header; chromosomes whose
    name contains 'Y', 'MT', 'All' or 'chrY' are excluded.

    :param hicfile: path to the input .hic file
    :param res: resolution (bin size in bp)
    :return: dict mapping chromosome name to its bin count
    NOTE(review): (length + res) // res yields one extra bin when the
    chromosome length is an exact multiple of res — confirm intended.
    """
    exclude_chroms = {'Y', 'MT', 'All', 'chrY'}
    hic = hicstraw.HiCFile(hicfile)
    chrom_bins = {}
    for chrom in hic.getChromosomes():
        # substring match, so e.g. 'chrY' is caught by 'Y' as well
        if any(tag in chrom.name for tag in exclude_chroms):
            continue
        chrom_bins[chrom.name] = (chrom.length + res) // res
    return chrom_bins
17
+
18
def normalizationMat(matrix, binsNum):
    """Row-normalize a sparse contact matrix.

    Every stored entry is divided by its row sum so that each non-empty
    row sums to 1; entries in rows whose sum is zero are left unchanged.

    Improvements over the original: the per-entry Python loop is replaced
    by a vectorized division over the COO data array, and the result is
    always computed in floating point (the old `np.zeros_like(matrix.data)`
    buffer silently truncated normalized values to 0 for integer inputs).

    :param matrix: scipy sparse matrix of raw contact counts
    :param binsNum: number of bins (the matrix is binsNum x binsNum)
    :return: row-normalized matrix in COO format
    """
    row_sums = np.array(matrix.sum(axis=1)).flatten()
    matrix = matrix.tocoo()
    # Divide by the row sum where it is non-zero; divide by 1 (keep the
    # original value) where the whole row sums to zero.
    per_entry_sums = row_sums[matrix.row]
    divisors = np.where(per_entry_sums != 0, per_entry_sums, 1.0)
    normed_count = matrix.data / divisors
    normed_mat = coo_matrix((normed_count, (matrix.row, matrix.col)),
                            shape=(binsNum, binsNum))
    return normed_mat
33
+
34
def constructSpaMat(result, binsNum, res):
    """Build a symmetric sparse contact matrix from straw records.

    The upper-triangular records from straw are mirrored (A + A.T) and the
    diagonal, which would otherwise be counted twice, is subtracted once.

    :param result: list of straw records with binX, binY, counts attributes
    :param binsNum: number of bins (the matrix is binsNum x binsNum)
    :param res: resolution in bp used to convert coordinates to bin indices
    :return: symmetric COO contact matrix
    """
    bin_x = np.array([int(rec.binX / res) for rec in result])
    bin_y = np.array([int(rec.binY / res) for rec in result])
    vals = np.array([rec.counts for rec in result])
    shape = (binsNum, binsNum)
    upper = coo_matrix((vals, (bin_x, bin_y)), shape=shape)
    lower = coo_matrix((vals, (bin_y, bin_x)), shape=shape)
    on_diag = bin_x == bin_y
    diagonal = coo_matrix((vals[on_diag], (bin_x[on_diag], bin_y[on_diag])), shape=shape)
    return upper + lower - diagonal
52
+
53
def filter_matrix(mat):
    """Return True when the matrix is valid (every column sum >= 0.1)."""
    column_totals = np.sum(mat, axis=0)
    return not np.any(column_totals < 0.1)
59
+
60
def dumpMatrix(chrom, binsNum, res, hicfile):
    """
    Convert one chromosome of a .hic file into a row-normalized sparse matrix.

    :param chrom: chromosome ID
    :param binsNum: number of bins for this chromosome
    :param res: resolution (bin size in bp) of the Hi-C data
    :param hicfile: path to the input .hic file
    :return: row-normalized contact matrix in CSR format
    """
    # KR-balanced observed counts for the intra-chromosomal contact map
    result = hicstraw.straw('observed', 'KR', hicfile, str(chrom), str(chrom), 'BP', res)
    # mirror the upper triangle into a full symmetric matrix, then row-normalize
    sp_mat = constructSpaMat(result, binsNum, res)
    normed_mat = normalizationMat(sp_mat, binsNum)
    return normed_mat.tocsr()
@@ -0,0 +1,9 @@
1
+ import torch
2
+
3
def euclidean_distance(x, y):
    """Squared Euclidean distance along the last axis."""
    return ((x - y) ** 2).sum(dim=-1)
6
+
7
def pairwise_loss(x, y, labels, margin=1.0):
    """Margin-based pairwise loss on the squared Euclidean distance.

    :param x: first batch of embeddings
    :param y: second batch of embeddings
    :param labels: +1 for similar pairs, -1 for dissimilar pairs
    :param margin: hinge margin
    :return: (per-pair loss, per-pair squared distance)
    """
    distance = euclidean_distance(x, y)
    hinge = torch.relu(margin - labels * (1 - distance))
    return hinge, distance
@@ -0,0 +1,192 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
def DSN2(t):
    """Normalize one square matrix toward doubly stochastic form.

    Adds a rank-one correction so row and column sums become equal, then
    rescales; entries that were zero in the input stay exactly zero.
    NOTE(review): assumes t is a square 2-D tensor — confirm with callers.
    """
    a = t.sum(dim=1, keepdim=True)   # row sums, shape (n, 1)
    b = t.sum(dim=0, keepdim=True)   # column sums, shape (1, n)
    # largest row or column sum acts as the normalization target
    lamb = torch.cat([a.squeeze(), b.squeeze()], dim=0).max()
    # total deficit of the matrix relative to n * lamb
    r = t.shape[0] * lamb - t.sum(dim=0).sum(dim=0)
    a = a.expand(-1, t.shape[1])
    b = b.expand(t.shape[0], -1)
    # rank-one additive correction equalizing row/column sums
    tt = t + (lamb ** 2 - lamb * (a + b) + a * b) / r
    # rescale by the first column's sum
    ttmatrix = tt / tt.sum(dim=0)[0]
    # preserve the sparsity pattern: zero entries remain zero
    ttmatrix = torch.where(t > 0, ttmatrix, t)
    return ttmatrix
18
+
19
+
20
def DSN(x):
    """Doubly stochastic normalization applied to each matrix in a batch.

    :param x: (batch, n, n) tensor of matrices
    :return: (batch, n, n) tensor with DSN2 applied to every matrix
    """
    normalized = [DSN2(x[idx]) for idx in range(x.shape[0])]
    return torch.stack(normalized, dim=0)
28
+
29
+
30
def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.

    :param data: A tensor whose segments are to be summed.
    :param segment_ids: The segment indices tensor.
    :param num_segments: The number of segments.
    :return: A tensor of same data type as the data argument.
    """

    assert all([i in data.shape for i in segment_ids.shape]), "segment_ids.shape should be a prefix of data.shape"

    # Encourage to use the below code when a deterministic result is
    # needed (reproducibility). However, the code below is with low efficiency.

    # tensor = torch.zeros(num_segments, data.shape[1], device=data.device)
    # for index in range(num_segments):
    #     tensor[index, :] = torch.sum(data[segment_ids == index, :], dim=0)
    # return tensor

    # Expand 1-D segment ids to match data's shape, as scatter_add requires
    # an index tensor with the same shape as the source.
    if len(segment_ids.shape) == 1:
        s = torch.prod(torch.tensor(data.shape[1:], device=data.device)).long()
        segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:])

    assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"

    # Row i of the result accumulates every data row whose segment id is i.
    # NOTE(review): scatter_add on CUDA is non-deterministic — hence the
    # commented-out loop above for reproducibility.
    shape = [num_segments] + list(data.shape[1:])
    tensor = torch.zeros(*shape, device=data.device).scatter_add(0, segment_ids, data)
    tensor = tensor.type(data.dtype)
    return tensor
60
+
61
+ # reference Beaconet (https://github.com/GaoLabXDU/Beaconet)
62
class BatchSpecificNorm(nn.Module):
    """Learned per-batch affine transform: x * scale[b] + shift[b].

    Each batch index owns its own scale and shift vectors; scales are
    initialized to 1 and shifts to 0, i.e. the identity transform.
    (Adapted from Beaconet, https://github.com/GaoLabXDU/Beaconet)
    """

    def __init__(self, n_batches, feature_dim, eps=1e-8):
        super(BatchSpecificNorm, self).__init__()
        self.scale = nn.Embedding(n_batches, feature_dim)
        self.shift = nn.Embedding(n_batches, feature_dim)
        nn.init.ones_(self.scale.weight)
        nn.init.zeros_(self.shift.weight)
        # eps is stored but currently unused by forward()
        self.eps = eps

    def forward(self, x, batch_idx):
        """Apply the batch-specific affine transform row-wise.

        :param x: (n, feature_dim) features
        :param batch_idx: (n,) long tensor of batch indices
        """
        return x * self.scale(batch_idx) + self.shift(batch_idx)
75
+
76
class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903

    Attention scores are modulated by per-channel edge attributes and then
    re-normalized with DSN (doubly stochastic normalization).
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        # NOTE(review): `dropout` is accepted but never used in this layer.
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha        # LeakyReLU negative slope
        self.concat = concat      # True: concat per-channel outputs + ELU; False: sum them
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        # a stacks both halves of the attention vector: [a_source; a_target]
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, edge_attr):
        # h: (N, in_features) node features; edge_attr: (C, N, N) per-channel
        # edge attribute matrices. Returns (node outputs, raw scores e).
        Wh = torch.mm(h, self.W)  # h.shape: (N, in_features), Wh.shape: (N, out_features)
        e = self._prepare_attentional_mechanism_input(Wh)
        e = e * edge_attr          # weight pairwise scores by edge attributes
        attention = DSN(e)         # doubly stochastic re-normalization
        h_prime = []
        for i in range(edge_attr.shape[0]):
            h_prime.append(torch.matmul(attention[i], Wh))
        if self.concat:
            # concatenate per-channel results -> (N, C * out_features)
            h_prime = torch.cat(h_prime, dim=1)
            return F.elu(h_prime), e
        else:
            # sum per-channel results -> (N, out_features)
            h_prime = torch.stack(h_prime, dim=0)
            h_prime = torch.sum(h_prime, dim=0)
            return h_prime, e

    # compute attention coefficient
    def _prepare_attentional_mechanism_input(self, Wh):
        Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])  # source term
        Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])  # target term
        # broadcast add -> (N, N) pairwise attention scores
        e = Wh1 + Wh2.T
        return self.leakyrelu(e)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
118
+
119
+
120
class GraphAggregator(nn.Module):
    """This module computes graph representations by aggregating from parts."""
    def __init__(self,
                 node_hidden_sizes,
                 input_size):
        """
        :param node_hidden_sizes: list; the last entry sets the graph state dim
        :param input_size: list; input_size[0] is the node feature dimension
        """
        super(GraphAggregator, self).__init__()
        self._node_hidden_sizes = node_hidden_sizes
        self._graph_state_dim = node_hidden_sizes[-1]
        self._graph_transform_sizes = node_hidden_sizes[-1]
        self._input_size = input_size
        self.MLP1, self.MLP2 = self.build_model()

    def build_model(self):
        """Build the gating MLP (MLP1) and the graph-transform MLP (MLP2)."""
        node_hidden_sizes = self._node_hidden_sizes
        # MLP1 outputs 2x the state dim: one half is sigmoid gates, the
        # other half the gated values (see forward).
        # NOTE(review): this mutates the caller-supplied list in place.
        node_hidden_sizes[-1] = self._graph_state_dim * 2
        layer = [nn.Linear(self._input_size[0], 64)]
        layer.append(nn.ReLU())
        layer.append(nn.Linear(64, node_hidden_sizes[0]))
        MLP1 = nn.Sequential(*layer)
        layer = []
        layer.append(nn.Linear(self._graph_state_dim, 32))
        layer.append(nn.ReLU())
        layer.append(nn.Linear(32, 16))
        MLP2 = nn.Sequential(*layer)
        return MLP1, MLP2

    def forward(self, node_states, graph_idx):
        """Compute aggregated graph representations.

        :param node_states: (n_nodes, input_size[0]) node features
        :param graph_idx: (n_nodes,) tensor mapping each node to its graph id
        :return: (n_graphs, 16) graph representations
        """
        node_states_g = self.MLP1(node_states)
        # Gated sum: sigmoid gates select which node features contribute.
        gates = torch.sigmoid(node_states_g[:, :self._graph_state_dim])
        node_states_g = node_states_g[:, self._graph_state_dim:] * gates
        n_graphs = max(graph_idx) + 1
        graph_states = unsorted_segment_sum(node_states_g, graph_idx, n_graphs)
        graph_states = self.MLP2(graph_states)
        return graph_states
156
+
157
+
158
class EdgeGNN(nn.Module):
    """Edge-aware graph attention network producing per-graph embeddings.

    Multi-head GraphAttentionLayers, with attention modulated by edge
    features, followed by batch-specific normalization and a gated
    graph-level aggregator.
    """
    def __init__(self, node_hidden_dims, edge_feature_dim, dropout, alpha, nheads):
        super(EdgeGNN, self).__init__()
        # Plain list is fine: each head is registered via add_module below.
        self.attentions = [GraphAttentionLayer(node_hidden_dims[0], node_hidden_dims[1], dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads[0])]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)
        self.out_att = GraphAttentionLayer(node_hidden_dims[1] * nheads[0] * edge_feature_dim, node_hidden_dims[2], dropout=dropout, alpha=alpha,
                                           concat=False)
        # 8 sequencing-depth buckets (see the depth-estimation helpers)
        self.bs_norm = BatchSpecificNorm(n_batches=8, feature_dim=16)
        self.bn1 = nn.BatchNorm1d(node_hidden_dims[2])
        # NOTE(review): bn2 appears unused in forward().
        self.bn2 = nn.BatchNorm1d(node_hidden_dims[1] * nheads[0] * edge_feature_dim)
        self.aggregator = GraphAggregator(node_hidden_sizes=[node_hidden_dims[3]], input_size=[node_hidden_dims[2]])
        self.batch_norm = nn.BatchNorm1d(node_hidden_dims[2])
    def forward(self, node_features, edge_features, graph_idx, depth_idx):
        """
        :param node_features: (n_nodes, node_hidden_dims[0]) node inputs
        :param edge_features: (n_nodes * n_nodes, edge_feature_dim) edge inputs
        :param graph_idx: node -> graph id mapping
        :param depth_idx: node -> sequencing-depth bucket mapping
        :return: (graph embeddings, final edge attention scores)
        """
        x = node_features
        n_nodes = x.shape[0]
        # Unstack edge features into edge_feature_dim adjacency matrices.
        splits = edge_features.split(1, dim=1)
        edge_attr = [split.view(n_nodes, n_nodes) for split in splits]
        edge_attr = torch.stack(edge_attr, dim=0)
        edge_attr = DSN(edge_attr)

        # Multi-head attention; edge_attr is refined by each head in turn.
        temp_x = []
        for att in self.attentions:
            inn_x, edge_attr = att(x, edge_attr)
            temp_x.append(inn_x)

        x = torch.cat(temp_x, dim=1)
        x, edge_attr = self.out_att(x, edge_attr)
        x = self.bn1(x)
        x = F.elu(x)
        # Depth-specific affine correction, then per-graph aggregation.
        x = self.bs_norm(x, depth_idx)
        x = F.relu(x)
        graph_states = self.aggregator(x, graph_idx)
        graph_states = self.batch_norm(graph_states)
        return graph_states, edge_attr
@@ -0,0 +1,107 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import torch
4
+
5
def load_TAD_region(TADfile1, TADfile2):
    """Combine two tab-separated differential-TAD tables, sorted by position.

    :param TADfile1: first TAD file (all of its columns are kept)
    :param TADfile2: second TAD file (reduced to chrom/start/end)
    :return: concatenated DataFrame sorted by chrom, then start
    NOTE(review): TADfile1 is not reduced to chrom/start/end like TADfile2;
    any extra columns become NaN for TADfile2 rows — confirm intended.
    """
    # load diffTADs
    region1 = pd.read_csv(TADfile1, sep='\t')
    region2 = pd.read_csv(TADfile2, sep='\t')
    region2 = region2[['chrom', 'start', 'end']]
    # combine
    all_TAD_region = pd.concat([region1, region2], ignore_index=True)
    all_TAD_region = all_TAD_region.sort_values(by=['chrom', 'start'])
    return all_TAD_region
14
+
15
def Slice_matrix(mat, start, end):
    """Extract the square sub-matrix of one TAD as a dense array.

    :param mat: sparse contact matrix
    :param start: first bin index (inclusive)
    :param end: last bin index (inclusive; the slice uses end + 1)
    :return: dense numpy array of the TAD sub-matrix
    :raises ValueError: when a boundary lies outside the matrix
    NOTE(review): the guard uses `>` so end == row passes and the slice is
    silently clipped by the sparse indexer — confirm whether `>=` was meant.
    """
    row = mat.shape[0]
    if start > row or end > row:
        print(start, end, row)
        raise ValueError("invalid TAD boundary")
    else:
        cut_mat = mat[start:end + 1, start:end + 1]
        return cut_mat.toarray()
24
+
25
def filter_matrix(mat):
    """Filter invalid TADs: reject matrices containing a near-empty column.

    :param mat: 2-D numpy array (dense normalized TAD sub-matrix)
    :return: True when every column sum is >= 0.1, False otherwise
    """
    # Removed the unused `rows` local and collapsed the if/else into a
    # single boolean expression.
    col_sum = np.sum(mat, axis=0)
    return not np.any(col_sum < 0.1)
33
+
34
def reshape_and_split_tensor(tensor, n_splits):
    """Reshape and split a 2D tensor along the last dimension.

    Args:
        tensor: a [num_examples, feature_dim] tensor. num_examples must be a
            multiple of `n_splits`.
        n_splits: int, number of splits to split the tensor into.

    Returns:
        splits: a list of `n_splits` tensors. The first split is [tensor[0],
            tensor[n_splits], tensor[n_splits * 2], ...], the second split is
            [tensor[1], tensor[n_splits + 1], tensor[n_splits * 2 + 1], ...], etc..
    """
    feature_dim = tensor.shape[-1]
    # Group n_splits consecutive rows into one wide row, then slice the
    # wide rows back into n_splits column blocks.
    grouped = torch.reshape(tensor, [-1, feature_dim * n_splits])
    return [grouped[:, k * feature_dim:(k + 1) * feature_dim] for k in range(n_splits)]
53
+
54
+
55
def get_graphs(graphs, n_features, graph_idx, depth_idx, labels, edge_features):
    """Convert packed numpy batch pieces into torch tensors for the model.

    :param graphs: combined adjacency matrix (numpy)
    :param n_features: node feature matrix (numpy)
    :param graph_idx: per-node graph id array (numpy int)
    :param depth_idx: per-node depth bucket array (numpy int)
    :param labels: per-pair labels
    :param edge_features: int, number of edge feature channels to replicate
    :return: (edge_features, node_features, graph_idx, depth_idx, labels) tensors
    """
    adjacency = torch.FloatTensor(graphs)
    # Each adjacency entry becomes one edge row, replicated across the
    # edge feature channels.
    edge_feat = adjacency.reshape(-1, 1).repeat(1, edge_features)
    node_feat = torch.FloatTensor(n_features)
    graph_ids = torch.from_numpy(graph_idx).long()
    depth_ids = torch.from_numpy(depth_idx).long()
    label_tensor = torch.FloatTensor(labels)
    return edge_feat, node_feat, graph_ids, depth_ids, label_tensor
65
+
66
def pack_batches(graphs, depths):
    """Pack graph pairs into one block-diagonal adjacency matrix.

    :param graphs: list of (adj1, adj2) numpy matrix pairs
    :param depths: list of (depth1, depth2) sequencing-depth bucket pairs
    :return: (combine_adj, node_features, graph_idx, depth_idx)
    """
    n_graph = len(graphs)
    # init adj matrix
    # Total node count over the first graph of each pair.
    # NOTE(review): assumes both graphs of a pair have the same node count
    # (cur_node from graph_1 is reused for graph_2 below) — confirm.
    sum_node = 0
    for i in range(n_graph):
        cur_graph = graphs[i][0]
        cur_node = np.shape(cur_graph)[0]
        sum_node += cur_node
    combine_adj = np.zeros((sum_node*2, sum_node*2))
    # add
    # Place every graph on the diagonal and record, per node, which graph
    # and which depth bucket it belongs to.
    graph_idx = []
    depth_idx = []
    cur_row = 0
    idx = 0
    for i in range(n_graph):
        graph_1 = graphs[i][0]
        graph_2 = graphs[i][1]
        cur_depth1 = depths[i][0]
        cur_depth2 = depths[i][1]
        cur_node = np.shape(graph_1)[0]
        combine_adj[cur_row:cur_row + cur_node, cur_row:cur_row + cur_node] = graph_1
        cur_row += cur_node
        combine_adj[cur_row:cur_row + cur_node, cur_row:cur_row + cur_node] = graph_2
        cur_row += cur_node
        graph_idx.append(np.ones(cur_node, dtype=np.int32) * idx)
        depth_idx.append(np.ones(cur_node, dtype=np.int32) * cur_depth1)
        idx += 1
        graph_idx.append(np.ones(cur_node, dtype=np.int32) * idx)
        depth_idx.append(np.ones(cur_node, dtype=np.int32) * cur_depth2)
        idx += 1

    depth_idx = np.concatenate(depth_idx, axis=0)
    graph_idx = np.concatenate(graph_idx, axis=0)
    # Constant all-ones node features, 8 channels per node.
    node_features = np.ones((sum_node*2, 8), dtype=np.float64)
    return combine_adj, node_features, graph_idx, depth_idx
101
+
102
def generate_valid_batches(idx, graphs, labels, depths, bs):
    """Slice one batch starting at `idx` and pack it for the model.

    :param idx: start offset into the full dataset
    :param graphs: list of (adj1, adj2) pairs
    :param labels: list of per-pair labels
    :param depths: list of (depth1, depth2) pairs
    :param bs: batch size
    :return: (combined adjacency, node features, labels, graph ids, depth ids)
    """
    graph_slice = graphs[idx:idx + bs]
    depth_slice = depths[idx:idx + bs]
    label_slice = labels[idx:idx + bs]
    packed_adj, packed_features, graph_ids, depth_ids = pack_batches(graph_slice, depth_slice)
    return packed_adj, packed_features, label_slice, graph_ids, depth_ids
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.1
2
+ Name: hidt
3
+ Version: 0.1.0
4
+ Summary: A computational pipeline for identifying differential TADs from 3D genome contact maps
5
+ Home-page: https://github.com/GaoLabXDU/HiDT
6
+ Author: Li Junping
7
+ Author-email: lijunping02@qq.com
8
+ License: MIT Licence
9
+ Keywords: 3D genome,Comparative analysis,Topologically associating domains
10
+ Platform: any
11
+ License-File: LICENSE
@@ -0,0 +1,14 @@
1
+ LICENSE
2
+ setup.py
3
+ hidt/__init__.py
4
+ hidt/evaluation.py
5
+ hidt/load_hic_format.py
6
+ hidt/loss.py
7
+ hidt/model.py
8
+ hidt/utils.py
9
+ hidt.egg-info/PKG-INFO
10
+ hidt.egg-info/SOURCES.txt
11
+ hidt.egg-info/dependency_links.txt
12
+ hidt.egg-info/requires.txt
13
+ hidt.egg-info/top_level.txt
14
+ scripts/HiDT
@@ -0,0 +1,4 @@
1
+ numpy
2
+ pandas
3
+ scikit-learn
4
+ hic-straw==1.3.1
@@ -0,0 +1 @@
1
+ hidt
@@ -0,0 +1,267 @@
1
+ #!/usr/bin/env python
2
+
3
+ from hidt.utils import *
4
+ from hidt.model import EdgeGNN
5
+ from hidt.loss import *
6
+ from hidt.evaluation import *
7
+ import pandas as pd
8
+ import numpy as np
9
+ from hidt.load_hic_format import *
10
+ import os, argparse, sys
11
+
12
def load_TAD(TADfile, res):
    """
    Load TADs assuming the first 3 columns are chrom, x1, x2.
    Check that the file has at least 3 columns.

    :param TADfile: path to a whitespace-delimited TAD boundary file
    :param res: Hi-C resolution in bp; coordinates are converted to bin indices
    :return: ({chrom: DataFrame[x1, x2]}, list of chromosome names)
    """
    try:
        # sep=r'\s+' replaces delim_whitespace=True, which is deprecated
        # since pandas 2.1 (identical whitespace-splitting behavior).
        df = pd.read_csv(TADfile, sep=r'\s+')
        if df.shape[1] < 3:
            raise ValueError(f"TAD file '{TADfile}' has fewer than 3 columns.")
        df = df.iloc[:, :3]
        df.columns = ['chrom', 'x1', 'x2']
        # strip the 'chr' prefix and convert bp coordinates to bin indices
        df['chrom'] = df['chrom'].astype(str).str.replace('chr', '', regex=False)
        df['x1'] = df['x1'] // res
        df['x2'] = df['x2'] // res
        TADInfo = {}
        chrom_list = []
        for chrom, group in df.groupby('chrom'):
            TADInfo[chrom] = group[['x1', 'x2']].reset_index(drop=True)
            chrom_list.append(chrom)
        return TADInfo, chrom_list
    except Exception as e:
        # best-effort CLI behavior: report and abort the whole run
        print(f"Error loading TAD file: {e}")
        sys.exit(1)
35
+
36
def scale_total_contacts(total_contacts, resolution):
    """Rescale a contact total to be comparable to the 25 kb base resolution.

    Coarser resolutions multiply the total by the integer resolution ratio;
    finer resolutions divide by it. At 25 kb the total is returned as-is.
    """
    base_resolution = 25000
    if resolution > base_resolution:
        return int(total_contacts * (resolution // base_resolution))
    if resolution < base_resolution:
        return int(total_contacts // (base_resolution // resolution))
    return total_contacts
46
+
47
def load_intra_counts_hic(hicfile):
    """Bucket a Hi-C file's sequencing depth from intra-chromosomal counts.

    Sums all intra-chromosomal raw ('NONE'-normalized) contacts, rescales
    the total to the 25 kb base resolution, and maps it to one of 8 depth
    buckets (0 = shallowest, 7 = deepest).

    NOTE(review): reads the module-global `resolution` assigned in the
    __main__ section instead of taking it as a parameter — fragile if this
    function is called before `resolution` is set.

    :param hicfile: path to the input .hic file
    :return: int depth bucket in [0, 7]
    """
    hic = hicstraw.HiCFile(hicfile)
    chromosomes = hic.getChromosomes()
    chrom_names = [chrom.name for chrom in chromosomes]
    chrom_names = [chrom for chrom in chrom_names if chrom not in ['All', 'Y', 'MT', 'ALL', 'chrY', 'chrM', 'M']]
    # total intra-chromosomal counts
    total_contacts = 0
    for chrom in chrom_names:
        result = hicstraw.straw('observed', 'NONE', hicfile, chrom, chrom, 'BP', resolution)
        for i in range(len(result)):
            total_contacts += result[i].counts
    # scale contacts based on resolution
    total_contacts = scale_total_contacts(total_contacts, resolution)
    print(f"Hi-C file '{hicfile}' - Total intra contacts: {total_contacts}")
    # thresholds (in contacts) separating the 8 depth buckets
    if total_contacts < 50000000:
        depth = 0
    elif 50000000 <= total_contacts < 100000000:
        depth = 1
    elif 100000000 <= total_contacts < 200000000:
        depth = 2
    elif 200000000 <= total_contacts < 250000000:
        depth = 3
    elif 250000000 <= total_contacts < 450000000:
        depth = 4
    elif 450000000 <= total_contacts < 500000000:
        depth = 5
    elif 500000000 <= total_contacts < 600000000:
        depth = 6
    else:
        depth = 7
    return depth
78
+
79
+
80
def load_total_counts_hic(hicfile):
    """Bucket a Hi-C file's sequencing depth from ALL (intra + inter) counts.

    Sums raw ('NONE'-normalized) contacts over every unordered chromosome
    pair, rescales the total to the 25 kb base resolution, and maps it to
    one of 8 depth buckets (0 = shallowest, 7 = deepest). Note the bucket
    thresholds differ from load_intra_counts_hic.

    NOTE(review): reads the module-global `resolution` assigned in the
    __main__ section instead of taking it as a parameter — fragile.

    :param hicfile: path to the input .hic file
    :return: int depth bucket in [0, 7]
    """
    hic = hicstraw.HiCFile(hicfile)
    chromosomes = hic.getChromosomes()
    chrom_names = [chrom.name for chrom in chromosomes]
    chrom_names = [chrom for chrom in chrom_names if chrom not in ['All', 'Y', 'MT', 'ALL', 'chrY', 'chrM', 'M']]
    # total counts
    total_contacts = 0
    for chrom1 in chrom_names:
        for chrom2 in chrom_names:
            # lexicographic skip so each unordered pair is counted once
            if chrom1 < chrom2:
                continue
            result = hicstraw.straw('observed', 'NONE', hicfile, chrom1, chrom2, 'BP', resolution)
            for i in range(len(result)):
                total_contacts += result[i].counts

    total_contacts = scale_total_contacts(total_contacts, resolution)
    print(f"Hi-C file '{hicfile}' - Total contacts: {total_contacts}")
    # thresholds (in contacts) separating the 8 depth buckets
    if total_contacts < 50000000:
        depth = 0
    elif 50000000 <= total_contacts < 100000000:
        depth = 1
    elif 100000000 <= total_contacts < 200000000:
        depth = 2
    elif 200000000 <= total_contacts < 300000000:
        depth = 3
    elif 300000000 <= total_contacts < 400000000:
        depth = 4
    elif 400000000 <= total_contacts < 650000000:
        depth = 5
    elif 650000000 <= total_contacts < 900000000:
        depth = 6
    else:
        depth = 7
    return depth
114
+
115
+
116
def check_file(path, desc, suffix=None):
    """Validate that `path` is an existing file, optionally with a suffix.

    :param path: filesystem path to check
    :param desc: human-readable description used in error messages
    :param suffix: when given, the required file extension (e.g. '.hic')
    :raises FileNotFoundError: when the file does not exist
    :raises ValueError: when the suffix requirement is not met
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"{desc} not found: {path}")
    if suffix and not path.endswith(suffix):
        raise ValueError(f"{desc} must be a {suffix} file: {path}")
121
+
122
def check_resolution_in_hic(hic_path, resolution):
    """Ensure the requested resolution is stored in the .hic file.

    :param hic_path: path to the .hic file
    :param resolution: resolution (bp) requested on the command line
    :raises ValueError: when the file does not contain that resolution
    """
    available_res = hicstraw.HiCFile(hic_path).getResolutions()
    if resolution not in available_res:
        raise ValueError(f"Resolution {resolution} not found in {hic_path}. "
                         f"Available: {available_res}")
128
+
129
def getargs():
    """Build the command-line parser and parse sys.argv.

    Falls back to '-h' (help) when the script is invoked without arguments.

    :return: (parsed args namespace, the raw argument list that was parsed)
    """
    parser = argparse.ArgumentParser(description='Identify differential TADs from Hi-C contact maps.',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--hicfile1', help='Hi-C file (in .hic format) for condition 1.')
    parser.add_argument('--hicfile2', help='Hi-C file (in .hic format) for condition 2.')
    parser.add_argument('--TADfile', help='TAD boundary file used for differential analysis.')
    parser.add_argument('--res', type=int, help='Resolution of the Hi-C contact maps (e.g., 25000 for 25 kb).')
    parser.add_argument('--depth', type=str, default='intra', help='Method to compute sequencing depth: "intra" for intra-chromosomal counts or "total" for all contacts.')
    parser.add_argument('--output', help='Path to the output result file.')

    commands = sys.argv[1:] or ['-h']
    return parser.parse_args(commands), commands
147
+
148
+
149
if __name__ == '__main__':
    # --- parse and validate command-line inputs -------------------------
    args, commands = getargs()
    check_file(args.TADfile, "TAD file")
    check_file(args.hicfile1, "Hi-C file 1")
    check_file(args.hicfile2, "Hi-C file 2")
    check_resolution_in_hic(args.hicfile1, args.res)
    check_resolution_in_hic(args.hicfile2, args.res)

    # `resolution` is read as a module-global by load_*_counts_hic below.
    resolution = args.res
    result_file = args.output
    TADInfo, chroms = load_TAD(args.TADfile, resolution)
    hicfile1 = args.hicfile1
    hicfile2 = args.hicfile2

    # --- estimate sequencing-depth buckets for both conditions ----------
    # NOTE(review): if args.depth is neither 'intra' nor 'total',
    # hic1_depth/hic2_depth stay undefined and a NameError follows later.
    if args.depth == 'intra':
        hic1_depth = load_intra_counts_hic(hicfile1)
        hic2_depth = load_intra_counts_hic(hicfile2)
    elif args.depth == 'total':
        hic1_depth = load_total_counts_hic(hicfile1)
        hic2_depth = load_total_counts_hic(hicfile2)
    chrom_bins = load_binsNum(hicfile1, resolution)
    exclude_chroms = {'All', 'Y', 'MT', 'ALL', 'chrY', 'chrM', 'M'}
    # --- build per-TAD normalized matrix pairs per chromosome -----------
    adjs = []
    labels = []
    depths = []
    chrom_list = []
    start_list = []
    end_list = []
    for i in range(len(chroms)):
        chrom = chroms[i]
        if any(exclude == chrom for exclude in exclude_chroms):
            continue
        print(f"Processing chromosome: {chrom}")
        normed_mat_1 = dumpMatrix(chrom, chrom_bins[chrom], resolution, hicfile1)
        normed_mat_2 = dumpMatrix(chrom, chrom_bins[chrom], resolution, hicfile2)
        positions = TADInfo[chrom]
        for start, end in zip(positions['x1'], positions['x2']):
            # `end` is exclusive here (mat[start:end]); TADs with a
            # near-empty column in either condition are dropped.
            normed_M1 = normed_mat_1[start:end, start:end].toarray()
            normed_M2 = normed_mat_2[start:end, start:end].toarray()
            if filter_matrix(normed_M1) and filter_matrix(normed_M2):
                # save TAD information (converted back to bp coordinates)
                chrom_list.append(chrom)
                start_list.append(start*resolution)
                end_list.append(end*resolution)
                # save graphs; -1 labels are placeholders (no ground truth)
                adjs.append((normed_M1, normed_M2))
                depths.append((hic1_depth, hic2_depth))
                labels.append(-1)
    valid_graphs = adjs
    valid_labels = labels
    valid_depths = depths

    # --- load the trained model and score every TAD pair ----------------
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    edge_feature_dim = 8
    model = EdgeGNN(node_hidden_dims=[8, 64, 16, 64],
                    edge_feature_dim=edge_feature_dim,
                    dropout=0.5,
                    nheads=[1],
                    alpha=0.2)
    model.to(device)
    # NOTE(review): hard-coded absolute checkpoint path — this should load
    # the pretrained model shipped with the package instead.
    model.load_state_dict(torch.load('/home/li/detectTAD/DiffGNN/new_trained_model_25K.pth'))
    model.eval()
    batch_size = 40
    simis = []       # per-TAD distances (negated model similarities)
    all_label = []
    with torch.no_grad():
        accumulated_pair_auc = []
        for k_iter in range(0, len(valid_graphs), batch_size):
            batch_graphs, batch_features, batch_labels, graphs_idx, depth_idx = generate_valid_batches(k_iter,
                                                                                                       valid_graphs,
                                                                                                       valid_labels,
                                                                                                       valid_depths,
                                                                                                       batch_size)
            batch_labels = np.array(batch_labels)
            cur_edge_features, cur_node_features, cur_graphs_idx, cur_depth_idx, cur_batch_labels = get_graphs(batch_graphs,
                                                                                                               batch_features,
                                                                                                               graphs_idx,
                                                                                                               depth_idx,
                                                                                                               batch_labels,
                                                                                                               edge_feature_dim)

            graph_states, edges = model(cur_node_features.to(device),
                                        cur_edge_features.to(device),
                                        cur_graphs_idx.to(device),
                                        cur_depth_idx.to(device))
            # split interleaved embeddings into condition-1 / condition-2
            x, y = reshape_and_split_tensor(graph_states, 2)
            similarity = compute_similarity(x, y)
            for elem in similarity:
                # similarity is -distance, so negate back to a distance
                simis.append(-elem.item())
            for cur_label in cur_batch_labels:
                all_label.append(cur_label.item())

    # --- summarize: a TAD with distance > 1 is called differential ------
    count = 0
    pos = []          # NOTE(review): pos/neg are never used
    neg = []
    pos_count = 0
    for i in range(len(simis)):
        pos_count += 1
        if simis[i] > 1:
            count += 1
    print(
        f"Total TADs: {pos_count}, "
        f"Differential TADs: {count}, "
        f"Percentage: {count / pos_count * 100:.2f}%"
    )
    # Save result: chrom, start(bp), end(bp), distance score per line
    with open(result_file, 'w') as out:
        for i in range(len(simis)):
            out.write(chrom_list[i] + '\t' + str(start_list[i]) + '\t' + str(end_list[i]) + '\t' + str(simis[i]) + '\n')
    print(f"Results have been saved to '{result_file}'.")
hidt-0.1.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
hidt-0.1.0/setup.py ADDED
@@ -0,0 +1,45 @@
1

"""
Setup script for HiDT.

"""
import os, sys, glob
import setuptools

def read(fname):
    """Read a file located next to this setup script."""
    return open(os.path.join(os.path.dirname(__file__), fname)).read()

# Require Python >= 3.7.
if (sys.version_info.major < 3) or (sys.version_info.major == 3 and sys.version_info.minor < 7):
    print(
        f"Python >=3.7 is required. You are currently using Python {sys.version.split()[0]}"
    )
    sys.exit(2)

# Guarantee Unix Format
for src in glob.glob('scripts/*'):
    text = open(src, 'r').read().replace('\r\n', '\n')
    open(src, 'w').write(text)

setuptools.setup(
    name = 'hidt',
    version = "0.1.0",
    author = "Li Junping",
    author_email = 'lijunping02@qq.com',
    url = 'https://github.com/GaoLabXDU/HiDT',
    description = 'A computational pipeline for identifying differential TADs from 3D genome contact maps',
    keywords = ("3D genome", "Comparative analysis", "Topologically associating domains"),
    scripts = glob.glob('scripts/*'),
    packages = setuptools.find_packages(),
    include_package_data = True,
    # Fixed template leftover: the key must be the real package name and the
    # paths are relative to that package directory (was "your_package.model"
    # with a "hidt/..." path, which matched nothing).
    package_data={
        "hidt": ["pretrained_model.pth"],
    },
    platforms = "any",
    license="MIT Licence",
    install_requires = [
        "numpy",
        "pandas",
        "scikit-learn",
        "hic-straw==1.3.1"
    ]
)