hjxdl 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hdl/_version.py +2 -2
- hdl/datasets/city_code.json +2576 -0
- hdl/datasets/defined_BaseFeatures.fdef +236 -0
- hdl/datasets/las.tsv +0 -0
- hdl/datasets/route_template.json +113 -0
- hdl/datasets/vocab.txt +591 -0
- hdl/ju/__init__.py +0 -0
- hdl/ju/setup.py +55 -0
- hdl/jupyfuncs/__init__.py +0 -0
- hdl/jupyfuncs/chem/__init__.py +0 -0
- hdl/jupyfuncs/chem/mol.py +548 -0
- hdl/jupyfuncs/chem/norm.py +268 -0
- hdl/jupyfuncs/chem/pdb_ext.py +94 -0
- hdl/jupyfuncs/chem/scaffold.py +25 -0
- hdl/jupyfuncs/chem/shape.py +241 -0
- hdl/jupyfuncs/chem/tokenizers.py +2 -0
- hdl/jupyfuncs/dbtools/__init__.py +0 -0
- hdl/jupyfuncs/dbtools/pg.py +42 -0
- hdl/jupyfuncs/dbtools/query_info.py +150 -0
- hdl/jupyfuncs/dl/__init__.py +0 -0
- hdl/jupyfuncs/dl/cp.py +54 -0
- hdl/jupyfuncs/dl/dataframe.py +38 -0
- hdl/jupyfuncs/dl/fp.py +49 -0
- hdl/jupyfuncs/dl/list.py +20 -0
- hdl/jupyfuncs/dl/model_utils.py +97 -0
- hdl/jupyfuncs/dl/tensor.py +159 -0
- hdl/jupyfuncs/dl/uncs.py +112 -0
- hdl/jupyfuncs/llm/__init__.py +0 -0
- hdl/jupyfuncs/llm/extract.py +123 -0
- hdl/jupyfuncs/llm/openapi.py +94 -0
- hdl/jupyfuncs/network/__init__.py +0 -0
- hdl/jupyfuncs/network/proxy.py +20 -0
- hdl/jupyfuncs/path/__init__.py +0 -0
- hdl/jupyfuncs/path/glob.py +285 -0
- hdl/jupyfuncs/path/strings.py +65 -0
- hdl/jupyfuncs/show/__init__.py +0 -0
- hdl/jupyfuncs/show/pbar.py +50 -0
- hdl/jupyfuncs/show/plot.py +259 -0
- hdl/jupyfuncs/utils/__init__.py +0 -0
- hdl/jupyfuncs/utils/wrappers.py +8 -0
- hdl/utils/weather/__init__.py +0 -0
- hdl/utils/weather/weather.py +68 -0
- {hjxdl-0.1.13.dist-info → hjxdl-0.1.15.dist-info}/METADATA +1 -1
- {hjxdl-0.1.13.dist-info → hjxdl-0.1.15.dist-info}/RECORD +46 -5
- {hjxdl-0.1.13.dist-info → hjxdl-0.1.15.dist-info}/WHEEL +1 -1
- {hjxdl-0.1.13.dist-info → hjxdl-0.1.15.dist-info}/top_level.txt +0 -0
hdl/jupyfuncs/dbtools/query_info.py
ADDED
@@ -0,0 +1,150 @@

import re

import cirpy
import pubchempy as pcp
import molvs as mv
from psycopg import sql

from .pg import connect_by_infofile


def query_from_cir(query_name: str):
    """Resolve a compound name via the CIR service into
    (smiles, cas_list, name_list); smiles is standardized with molvs
    and may be None if resolution fails."""
    smiles = None
    # cas_list = []
    # name_list = [host=172.20.0.5 dbname=pistachio port=5432 user=postgres password=woshipostgres]

    cas_list = cirpy.resolve(query_name, 'cas')
    if cas_list is None or not cas_list:
        cas_list = []
    if isinstance(cas_list, str):
        cas_list = [cas_list]

    name_list = cirpy.resolve(query_name, 'names')
    if name_list is None or not name_list:
        name_list = []
    if isinstance(name_list, str):
        name_list = [name_list]

    smiles = cirpy.resolve(query_name, 'smiles')
    try:
        smiles = mv.standardize_smiles(smiles)
    except Exception as e:
        print(e)

    return smiles, cas_list, name_list


def query_from_pubchem(query_name: str):
    """Resolve a compound name via PubChem, collecting synonyms and
    any CAS registry numbers found among them."""
    results = pcp.get_compounds(query_name, 'name')
    smiles = None
    name_list = set()
    cas_list = set()

    if any(results):
        try:
            smiles = mv.standardize_smiles(results[0].canonical_smiles)
        except Exception as e:
            smiles = results[0].canonical_smiles
            print(smiles)
            print(e)
        for compound in results:
            name_list.update(set(compound.synonyms))
            for syn in compound.synonyms:
                # CAS registry numbers look like 50-00-0 / 1234567-89-0
                match = re.match(r'(\d{2,7}-\d\d-\d)', syn)
                if match:
                    cas_list.add(match.group(1))

    cas_list = list(cas_list)
    name_list = list(name_list)

    return smiles, cas_list, name_list


def query_a_compound(
    query_name: str,
    connect_info: str,
    by: str = 'name',
    log_file: str = './err.log'
):
    """Look up a compound's fei in the local database, falling back to
    PubChem and then CIR; newly resolved compounds are written back to
    the compounds / cas_maps / name_maps tables."""
    fei = None
    found = False

    if by != 'smiles':
        query_name = query_name.lower()

    by = 'name'  # lookups below always go through the name_maps table
    table = by + '_maps'
    # query_name = 'adipic acid'
    query = sql.SQL(
        "select fei from {table} where {by} = %s"
    ).format(
        table=sql.Identifier(table),
        by=sql.Identifier(by)
    )
    conn = connect_by_infofile(connect_info)

    cur = conn.execute(query, [query_name]).fetchone()

    if cur is not None:
        fei = cur[0]
        found = True
        return fei

    if not found:
        try:
            smiles, cas_list, name_list = query_from_pubchem(query_name)
        except Exception as e:
            print(e)
            smiles, cas_list, name_list = None, [], []
        if smiles is not None:
            found = True
        else:
            try:
                smiles, cas_list, name_list = query_from_cir(query_name)
            except Exception as e:
                print(e)
                smiles, cas_list, name_list = None, [], []
            if smiles is not None:
                found = True

    if not found:
        with open(log_file, 'a') as f:
            f.write(query_name)
            f.write('\n')
        return
        # raise ValueError('Could not resolve this query anywhere!')
    else:
        query_compound = sql.SQL(
            "select fei from compounds where smiles = %s"
        )
        cur = conn.execute(query_compound, [smiles]).fetchone()
        if cur is not None:
            fei = cur[0]
        elif any(cas_list):
            fei = cas_list[0]
        insert_compounds_sql = sql.SQL(
            "INSERT INTO compounds (fei, smiles) VALUES (%s, %s) ON CONFLICT (fei) DO NOTHING"
        )
        conn.execute(insert_compounds_sql, [fei, smiles])
        for cas in cas_list:
            insert_cas_map_sql = sql.SQL(
                "INSERT INTO cas_maps (fei, cas) VALUES (%s, %s) ON CONFLICT (cas) DO NOTHING"
            )
            try:
                conn.execute(insert_cas_map_sql, [fei, cas])
            except Exception as e:
                print(e)
        for name in name_list:
            insert_name_map_sql = sql.SQL(
                "INSERT INTO name_maps (fei, name) VALUES (%s, %s) ON CONFLICT (fei, name) DO NOTHING"
                # "INSERT INTO name_maps (fei, name) VALUES (%s, %s)"
            )
            try:
                conn.execute(insert_name_map_sql, [fei, name.lower()])
            except Exception as e:
                print(e)

    conn.commit()
    conn.close()

    return fei
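The module itself gives no calling example, so here is a minimal usage sketch. It assumes connect_by_infofile (from hdl/jupyfuncs/dbtools/pg.py) takes the path of a text file holding a libpq-style connection string, the format hinted at by the commented-out line in query_from_cir; the file name db.info is hypothetical.

# hypothetical: db.info holds e.g. "host=... dbname=... user=... password=..."
fei = query_a_compound('adipic acid', 'db.info')
print(fei)  # the compound's fei, or None if nothing could be resolved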
File without changes
hdl/jupyfuncs/dl/cp.py
ADDED
@@ -0,0 +1,54 @@
"""
Conformal Prediction for classification tasks.

As we only consider the predicted values as input,
we do not differentiate between transductive and inductive
conformal predictors.
"""

import numpy as np
import pandas as pd


class CpClassfier:
    def __init__(self):
        self.cal_data = None
        self.class_num = 0

    def fit_with_data(self, cal_proba, cal_y, class_num=0):
        """
        :param cal_proba: numpy array of shape [n_samples, n_classes],
            predicted probabilities of the calibration set
        :param cal_y: numpy array of shape [n_samples,],
            true labels of the calibration set
        """
        if class_num <= 0:
            print("Get class number from input data.")
            self.class_num = cal_proba.shape[1]
        else:
            self.class_num = class_num

        cal_df = pd.DataFrame(cal_proba)
        cal_df.columns = ["class_%d" % i for i in range(self.class_num)]
        cal_df["true_label"] = list(cal_y)
        self.cal_data = cal_df

    def fit_with_model(self, func):
        # TODO: besides calibration data, we can also use our model
        pass

    def predict_with_proba(self, X_proba):
        """
        :param X_proba: numpy array of shape [n_samples, n_classes],
            predicted probabilities of the test set
        """
        cp_proba = []
        # per-class sorted calibration scores
        class_lsts = [
            sorted(self.cal_data[self.cal_data["true_label"] == i]["class_%d" % i])
            for i in range(self.class_num)
        ]
        proba_lsts = [X_proba[:, i] for i in range(self.class_num)]
        for c_lst, p_lst in zip(class_lsts, proba_lsts):
            # fraction of calibration scores strictly below each test score
            c_proba = np.searchsorted(c_lst, p_lst, side='left') / len(c_lst)
            cp_proba.append(c_proba)

        self.cp_P = np.array(cp_proba).T
        return self.cp_P
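A minimal usage sketch for the class above, on synthetic data (the Dirichlet draws and labels are made up for illustration): fit_with_data stores the calibration probabilities, and predict_with_proba returns, per class, the fraction of calibration scores below each test score.

import numpy as np

rng = np.random.default_rng(0)
cal_proba = rng.dirichlet((1.0, 1.0), size=100)  # rows sum to 1
cal_y = rng.integers(0, 2, size=100)             # synthetic true labels

clf = CpClassfier()
clf.fit_with_data(cal_proba, cal_y)

X_proba = np.array([[0.9, 0.1], [0.4, 0.6]])
print(clf.predict_with_proba(X_proba))           # shape (2, 2) of p-values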
hdl/jupyfuncs/dl/dataframe.py
ADDED
@@ -0,0 +1,38 @@

def rm_index(df):
    """Remove columns with 'Unnamed' in their names from the DataFrame.

    Args:
        df (pandas.DataFrame): The input DataFrame.

    Returns:
        pandas.DataFrame: DataFrame with columns containing 'Unnamed' removed.
    """
    return df.loc[:, ~df.columns.str.match('Unnamed')]


def rm_col(df, col_name):
    """Remove a column from a DataFrame.

    Args:
        df (pandas.DataFrame): The input DataFrame.
        col_name (str): The name of the column to be removed.

    Returns:
        pandas.DataFrame: A new DataFrame with the specified column removed.
    """
    # str.match treats col_name as a regex anchored at the start, so any
    # column whose name begins with a match is dropped, not only exact names
    return df.loc[:, ~df.columns.str.match(col_name)]


def shuffle_df(df):
    """Shuffle the rows of a DataFrame.

    Args:
        df (pandas.DataFrame): The input DataFrame to shuffle.

    Returns:
        pandas.DataFrame: A new DataFrame with rows shuffled.

    Example:
        shuffled_df = shuffle_df(df)
    """
    return df.sample(frac=1).reset_index(drop=True)
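A quick usage sketch with a toy DataFrame (the column names are illustrative):

import pandas as pd

df = pd.DataFrame({'Unnamed: 0': [0, 1], 'a': [1, 2], 'b': [3, 4]})
df = rm_index(df)           # drops 'Unnamed: 0'
df = rm_col(df, 'b')        # drops 'b' (regex prefix match)
df = shuffle_df(df)         # shuffles rows, resets the index
print(df.columns.tolist())  # ['a']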
hdl/jupyfuncs/dl/fp.py
ADDED
@@ -0,0 +1,49 @@
from rdkit import Chem
import numpy as np
from rdkit.Chem import AllChem
from rdkit.Chem import MACCSkeys


__all__ = [
    'get_fp',
]


def get_rdnorm_fp(smiles):
    """Normalized RDKit 2D descriptors (requires descriptastorus)."""
    from descriptastorus.descriptors import rdNormalizedDescriptors
    generator = rdNormalizedDescriptors.RDKit2DNormalized()
    features = generator.process(smiles)[1:]
    arr = np.array(features)
    return arr


def get_maccs_fp(smiles):
    """167-bit MACCS keys fingerprint; returns all zeros on failure."""
    arr = np.zeros(167)
    try:
        mol = Chem.MolFromSmiles(smiles)
        vec = MACCSkeys.GenMACCSKeys(mol)
        bv = list(vec.GetOnBits())
        arr[bv] = 1
    except Exception as e:
        print(e)
    return arr


def get_morgan_fp(smiles):
    """1024-bit Morgan fingerprint with radius 2."""
    mol = Chem.MolFromSmiles(smiles)
    vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)
    bv = list(vec.GetOnBits())
    arr = np.zeros(1024)
    arr[bv] = 1
    return arr


fp_dict = {
    'rdnorm': get_rdnorm_fp,
    'maccs': get_maccs_fp,
    'morgan': get_morgan_fp,
}


def get_fp(smiles, fp='maccs'):
    """Dispatch to one of the fingerprint functions above."""
    return fp_dict[fp](smiles)
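A short usage sketch; ethanol ('CCO') is just an illustrative SMILES, and the 'rdnorm' variant additionally requires the descriptastorus package to be installed.

arr = get_fp('CCO', fp='maccs')
print(arr.shape)        # (167,)
print(int(arr.sum()))   # number of MACCS bits set

arr = get_fp('CCO', fp='morgan')
print(arr.shape)        # (1024,)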
hdl/jupyfuncs/dl/list.py
ADDED
@@ -0,0 +1,20 @@
def list_diff(listA, listB, mode="intersection"):
    """Calculate the difference between two lists based on the specified mode.

    Args:
        listA (list): The first list.
        listB (list): The second list.
        mode (str, optional): The mode to determine the difference.
            Possible values are "intersection" (default), "union", or "diff".

    Returns:
        list: A list containing the elements based on the specified mode.
    """
    if mode == "intersection":
        ret = list(set(listA).intersection(set(listB)))
    elif mode == "union":
        ret = list(set(listA).union(set(listB)))
    elif mode == "diff":
        ret = list(set(listB).difference(set(listA)))
    else:
        # avoid returning an unbound name on an unknown mode
        raise ValueError(f"unknown mode: {mode!r}")

    return ret
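For example (sorted() is used below only because set order is unspecified):

a, b = [1, 2, 3], [2, 3, 4]
print(sorted(list_diff(a, b)))           # intersection: [2, 3]
print(sorted(list_diff(a, b, 'union')))  # [1, 2, 3, 4]
print(sorted(list_diff(a, b, 'diff')))   # elements of b not in a: [4]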
hdl/jupyfuncs/dl/model_utils.py
ADDED
@@ -0,0 +1,97 @@

import torch
from torch import nn


def save_model(
    model,
    save_dir,
    epoch=0,
    optimizer=None,
    loss=None,
):
    """Save the model and related training information to a checkpoint.

    Args:
        model: The model to be saved; it must expose an ``init_args`` dict
            so that ``load_model`` can re-instantiate it.
        save_dir: The path the checkpoint will be saved to.
        epoch (int): The current epoch number (default is 0).
        optimizer: The optimizer used for training (default is None).
        loss: The loss value (default is None).
    """
    if isinstance(model, nn.DataParallel):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    if optimizer is None:
        optim_params = None
    else:
        optim_params = optimizer.state_dict()
    torch.save(
        {
            'init_args': model.init_args,
            'epoch': epoch,
            'model_state_dict': state_dict,
            'optimizer_state_dict': optim_params,
            'loss': loss,
        },
        save_dir
    )


def load_model(
    save_dir,
    model_class=None,
    model=None,
    optimizer=None,
    train=False,
):
    """Load a saved model from the specified checkpoint path.

    Args:
        save_dir (str): The path the model checkpoint was saved to.
        model_class (torch.nn.Module, optional): The class of the model to be loaded. Defaults to None.
        model (torch.nn.Module, optional): The model to load the state_dict into. Defaults to None.
        optimizer (torch.optim.Optimizer, optional): The optimizer to load the state_dict into. Defaults to None.
        train (bool, optional): Whether to set the model to training mode. Defaults to False.

    Returns:
        tuple: A tuple containing the loaded model, optimizer, epoch, and loss.
    """
    # from .model_dict import MODEL_DICT
    checkpoint = torch.load(save_dir)
    if model is None:
        init_args = checkpoint['init_args']
        assert model_class is not None
        model = model_class(**init_args)
        model.load_state_dict(
            checkpoint['model_state_dict'],
        )

    elif isinstance(model, nn.DataParallel):
        state_dict = checkpoint['model_state_dict']
        from collections import OrderedDict
        new_state_dict = OrderedDict()

        # checkpoints saved from a bare model lack the 'module.' prefix
        for k, v in state_dict.items():
            if 'module' not in k:
                k = 'module.' + k
            else:
                k = k.replace('features.module.', 'module.features.')
            new_state_dict[k] = v
        model.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(
            checkpoint['model_state_dict'],
        )

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    if train:
        model.train()
    else:
        model.eval()

    return model, optimizer, epoch, loss
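A minimal round-trip sketch. TinyNet is a made-up model; the only requirement these helpers impose is that the model carries an init_args dict so load_model can re-instantiate it from the checkpoint.

import torch
from torch import nn

class TinyNet(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.init_args = {'in_dim': in_dim, 'out_dim': out_dim}
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.fc(x)

model = TinyNet(4, 2)
save_model(model, 'ckpt.pt', epoch=3)
model2, _, epoch, loss = load_model('ckpt.pt', model_class=TinyNet)
print(epoch)  # 3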
hdl/jupyfuncs/dl/tensor.py
ADDED
@@ -0,0 +1,159 @@

import typing as t
from collections import defaultdict

import torch
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist


def spmmsp(
    sp1: torch.sparse.Tensor,
    sp2: torch.sparse.Tensor
) -> torch.sparse.Tensor:
    """Sparse-sparse matrix multiplication via torch_sparse.spspmm."""
    from torch_sparse import spspmm
    assert sp1.size(-1) == sp2.size(0) and sp1.is_sparse and sp2.is_sparse
    m = sp1.size(0)
    k = sp2.size(0)
    n = sp2.size(-1)
    indices, values = spspmm(
        sp1.indices(), sp1.values(),
        sp2.indices(), sp2.values(),
        m, k, n
    )
    return torch.sparse_coo_tensor(
        indices,
        values,
        torch.Size([m, n])
    )


def label_to_onehot(ls, class_num, missing_label=-1):
    """One-hot encode labels, zeroing rows whose label equals missing_label.

    Example:
        >>> label_to_onehot([2, 3, -1], 6, -1)
        array([[0., 0., 1., 0.],
               [0., 0., 0., 1.],
               [0., 0., 0., 0.]])

    :param ls: labels as a tensor, list, or scalar
    :param class_num: number of classes
    :param missing_label: label value treated as missing
    :return: one-hot encoded tensor or array
    """
    if isinstance(ls, torch.Tensor):
        bool_t = ls == missing_label
        clamp_t = torch.clamp(ls, min=0)
        full_tensor = torch.zeros(ls.numel(), class_num)
        full_tensor = full_tensor.scatter_(1, clamp_t.reshape(-1, 1), 1)
        full_tensor[bool_t] = 0
        return full_tensor
    elif isinstance(ls, t.List):
        ls = np.array(ls, dtype=np.int32)
        bool_array = ls == missing_label
        # note: the array width here is ls.max() + 1, not class_num
        arr = np.zeros((ls.size, ls.max() + 1))
        arr[np.arange(ls.size), ls] = 1
        arr[bool_array] = 0
        return arr
    elif not isinstance(ls, t.Iterable):
        arr = torch.zeros(class_num)
        if ls != missing_label and not np.isnan(ls) and ls is not None:
            arr[int(ls)] = 1
        return arr


def onehot_to_label(tensor):
    """Inverse of one-hot encoding via argmax over the last axis."""
    if isinstance(tensor, torch.Tensor):
        return torch.argmax(tensor, dim=-1)
    elif isinstance(tensor, np.ndarray):
        return np.argmax(tensor, axis=-1)


def label_to_tensor(
    label,
    num_classes,
    missing_label=-1,
    device=torch.device('cpu')
):
    """Encode a (possibly multi-label, possibly nested) label as a tensor."""
    if isinstance(label, t.List) and not any(label):
        return torch.zeros(num_classes).to(device)
    elif isinstance(label, t.List) and isinstance(label[0], t.Iterable):
        max_length = max([len(_l) for _l in label])
        # pad each index list to max_length by repeating its last entry
        index = [_l + _l[-1:] * (max_length - len(_l)) for _l in label]
        tensor_list = []
        # tensor = torch.zeros(len(label), num_classes, device=device)
        for _idx in index:
            _tensor = torch.zeros(num_classes).to(device)
            _idx = torch.LongTensor(_idx)
            _tensor = _tensor.scatter(0, _idx, 1)
            tensor_list.append(_tensor)

        return torch.vstack(tensor_list).to(device)
    else:
        if label == missing_label or np.isnan(label):
            return torch.zeros(num_classes)
        tensor = torch.zeros(num_classes).to(device)
        tensor[int(label)] = 1
        return tensor


def tensor_to_label(tensor, threshold=0.5):
    """Collect, per row, the indices whose score exceeds threshold."""
    label_list, label_dict = [], defaultdict(list)
    labels = (tensor > threshold).nonzero(as_tuple=False)
    for label in labels:
        label_dict[label[0].item()].append(label[1].item())
    for _, label_value in label_dict.items():
        label_list.append(label_value)
    return label_list


def get_dist_matrix(
    a: np.ndarray, b: np.ndarray
):
    """Pairwise Euclidean distance matrix between the rows of a and b."""
    return cdist(a, b)
    # aSumSquare = np.sum(np.square(a), axis=1)
    # bSumSquare = np.sum(np.square(b), axis=1)
    # mul = np.dot(a, b.T)
    # dists = np.sqrt(aSumSquare[:, np.newaxis] + bSumSquare - 2 * mul)
    # return dists


def get_valid_indices(labels):
    """Indices of the entries in labels that are not NaN."""
    if isinstance(labels, torch.Tensor):
        nan_indices = torch.isnan(labels)
        valid_indices = (
            ~nan_indices
        ).nonzero(as_tuple=True)[0]
    else:
        target_pd = pd.array(labels)
        nan_indices = pd.isna(target_pd)
        valid_indices = torch.LongTensor(
            np.where(~nan_indices)[0]
        )
    return valid_indices


def smooth_max(
    tensor: torch.Tensor,
    inf_k: int = None,
    **kwargs
):
    """Differentiable max via log-sum-exp: log(sum(exp(k * x))) / k."""
    if inf_k is None:
        inf_k = 10
    max_value = torch.log(
        torch.sum(
            torch.exp(tensor * inf_k),
            **kwargs
        )
    ) / inf_k
    return max_value


def list_df(listA, listB):
    """Return intersection, union, B-minus-A, and A-minus-B of two lists."""
    retB = list(set(listA).intersection(set(listB)))

    retC = list(set(listA).union(set(listB)))

    retD = list(set(listB).difference(set(listA)))

    retE = list(set(listA).difference(set(listB)))
    return retB, retC, retD, retE
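Two small usage sketches for the helpers above (values chosen for illustration): label_to_onehot zeroes the row of a missing label, and smooth_max approaches the true maximum as inf_k grows.

import torch

onehot = label_to_onehot(torch.tensor([2, 0, -1]), class_num=4)
print(onehot)  # rows 0 and 1 are one-hot; row 2 is all zeros

x = torch.tensor([0.1, 0.9, 0.3])
print(smooth_max(x, inf_k=50))  # ~0.9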
hdl/jupyfuncs/dl/uncs.py
ADDED
@@ -0,0 +1,112 @@
"""UNCERTAINTY SAMPLING

Uncertainty Sampling examples for Active Learning in PyTorch

It contains four Active Learning strategies:
1. Least Confidence Sampling
2. Margin of Confidence Sampling
3. Ratio of Confidence Sampling
4. Entropy-based Sampling
"""
from copy import deepcopy

import numpy as np

__all__ = [
    "get_prob_unc",
]


def least_conf_unc(prob_array: np.ndarray) -> np.ndarray:
    r"""Least confidence uncertainty

    .. math::
        \phi_{LC}(x) = \left(1 - P_{\theta}\left(y^{*} \mid x\right)\right) \times \frac{n}{n-1}

    Args:
        prob_array (np.ndarray): a 1D or 2D array of probabilities

    Returns:
        np.ndarray: the uncertainty value(s)
    """
    if prob_array.ndim == 1:
        indices = prob_array.argmax()
    else:
        indices = (
            np.arange(prob_array.shape[0]),
            prob_array.argmax(-1)
        )
    num_labels = prob_array.shape[-1]
    uncs = (1 - prob_array[indices]) * (num_labels / (num_labels - 1))
    return uncs


def margin_conf_unc(prob_array: np.ndarray) -> np.ndarray:
    r"""Margin confidence uncertainty

    .. math::
        \phi_{MC}(x) = 1 - \left(P_{\theta}\left(y_{1}^{*} \mid x\right) - P_{\theta}\left(y_{2}^{*} \mid x\right)\right)

    Args:
        prob_array (np.ndarray): a 1D or 2D probability array from an NN.

    Returns:
        np.ndarray: the uncertainty value(s)
    """
    probs = deepcopy(prob_array)
    probs.sort(-1)
    diffs = probs[..., -1] - probs[..., -2]
    return 1 - diffs


def ratio_conf_unc(prob_array: np.ndarray) -> np.ndarray:
    r"""Ratio-based uncertainty

    .. math::
        \phi_{RC}(x) = P_{\theta}\left(y_{2}^{*} \mid x\right) / P_{\theta}\left(y_{1}^{*} \mid x\right)

    Args:
        prob_array (np.ndarray): a 1D or 2D probability array

    Returns:
        np.ndarray: the uncertainty value(s)
    """
    probs = deepcopy(prob_array)
    probs.sort(-1)
    ratio = probs[..., -1] / probs[..., -2]
    return ratio


def entropy_unc(prob_array: np.ndarray) -> np.ndarray:
    r"""Entropy-based uncertainty

    .. math::
        \phi_{ENT}(x) = \frac{-\Sigma_{y} P_{\theta}(y \mid x) \log_{2} P_{\theta}(y \mid x)}{\log_{2}(n)}

    Args:
        prob_array (np.ndarray): a 1D or 2D probability array

    Returns:
        np.ndarray: the uncertainty value(s)
    """
    num_labels = prob_array.shape[-1]
    log_probs = prob_array * np.log2(prob_array)

    raw_entropy = 0 - np.sum(log_probs, -1)

    normalized_entropy = raw_entropy / np.log2(num_labels)

    return normalized_entropy


unc_dict = {
    'least': least_conf_unc,
    'margin': margin_conf_unc,
    'ratio': ratio_conf_unc,
    'entropy': entropy_unc,
}


def get_prob_unc(prob_array: np.ndarray, unc: str) -> np.ndarray:
    """Compute the chosen uncertainty ('least', 'margin', 'ratio', 'entropy')."""
    return unc_dict[unc](prob_array)
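A worked example on a made-up probability matrix. For the first row, the margin uncertainty is 1 - (0.6 - 0.3) = 0.7 and the ratio is 0.6 / 0.3 = 2.0; the other strategies follow the formulas in the docstrings.

import numpy as np

probs = np.array([[0.6, 0.3, 0.1],
                  [0.4, 0.35, 0.25]])
for name in ('least', 'margin', 'ratio', 'entropy'):
    print(name, get_prob_unc(probs, name))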
File without changes