hjxdl 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. hdl/_version.py +2 -2
  2. hdl/datasets/city_code.json +2576 -0
  3. hdl/datasets/defined_BaseFeatures.fdef +236 -0
  4. hdl/datasets/las.tsv +0 -0
  5. hdl/datasets/route_template.json +113 -0
  6. hdl/datasets/vocab.txt +591 -0
  7. hdl/ju/__init__.py +0 -0
  8. hdl/ju/setup.py +55 -0
  9. hdl/jupyfuncs/__init__.py +0 -0
  10. hdl/jupyfuncs/chem/__init__.py +0 -0
  11. hdl/jupyfuncs/chem/mol.py +548 -0
  12. hdl/jupyfuncs/chem/norm.py +268 -0
  13. hdl/jupyfuncs/chem/pdb_ext.py +94 -0
  14. hdl/jupyfuncs/chem/scaffold.py +25 -0
  15. hdl/jupyfuncs/chem/shape.py +241 -0
  16. hdl/jupyfuncs/chem/tokenizers.py +2 -0
  17. hdl/jupyfuncs/dbtools/__init__.py +0 -0
  18. hdl/jupyfuncs/dbtools/pg.py +42 -0
  19. hdl/jupyfuncs/dbtools/query_info.py +150 -0
  20. hdl/jupyfuncs/dl/__init__.py +0 -0
  21. hdl/jupyfuncs/dl/cp.py +54 -0
  22. hdl/jupyfuncs/dl/dataframe.py +38 -0
  23. hdl/jupyfuncs/dl/fp.py +49 -0
  24. hdl/jupyfuncs/dl/list.py +20 -0
  25. hdl/jupyfuncs/dl/model_utils.py +97 -0
  26. hdl/jupyfuncs/dl/tensor.py +159 -0
  27. hdl/jupyfuncs/dl/uncs.py +112 -0
  28. hdl/jupyfuncs/llm/__init__.py +0 -0
  29. hdl/jupyfuncs/llm/extract.py +123 -0
  30. hdl/jupyfuncs/llm/openapi.py +94 -0
  31. hdl/jupyfuncs/network/__init__.py +0 -0
  32. hdl/jupyfuncs/network/proxy.py +20 -0
  33. hdl/jupyfuncs/path/__init__.py +0 -0
  34. hdl/jupyfuncs/path/glob.py +285 -0
  35. hdl/jupyfuncs/path/strings.py +65 -0
  36. hdl/jupyfuncs/show/__init__.py +0 -0
  37. hdl/jupyfuncs/show/pbar.py +50 -0
  38. hdl/jupyfuncs/show/plot.py +259 -0
  39. hdl/jupyfuncs/utils/__init__.py +0 -0
  40. hdl/jupyfuncs/utils/wrappers.py +8 -0
  41. hdl/utils/llm/chat.py +4 -0
  42. {hjxdl-0.1.12.dist-info → hjxdl-0.1.14.dist-info}/METADATA +1 -1
  43. {hjxdl-0.1.12.dist-info → hjxdl-0.1.14.dist-info}/RECORD +45 -6
  44. {hjxdl-0.1.12.dist-info → hjxdl-0.1.14.dist-info}/WHEEL +1 -1
  45. {hjxdl-0.1.12.dist-info → hjxdl-0.1.14.dist-info}/top_level.txt +0 -0
hdl/jupyfuncs/dbtools/query_info.py ADDED
@@ -0,0 +1,150 @@
+ import re
+
+ import cirpy
+ import pubchempy as pcp
+ import molvs as mv
+ from psycopg import sql
+
+ from .pg import connect_by_infofile
+
+
+ def query_from_cir(query_name: str):
+     smiles = None
+     # cas_list = []
+     # name_list = [host=172.20.0.5 dbname=pistachio port=5432 user=postgres password=woshipostgres]
+
+     cas_list = cirpy.resolve(query_name, 'cas')
+     if cas_list is None or not cas_list:
+         cas_list = []
+     if isinstance(cas_list, str):
+         cas_list = [cas_list]
+
+     name_list = cirpy.resolve(query_name, 'names')
+     if name_list is None or not name_list:
+         name_list = []
+     if isinstance(name_list, str):
+         name_list = [name_list]
+
+     smiles = cirpy.resolve(query_name, 'smiles')
+     try:
+         smiles = mv.standardize_smiles(smiles)
+     except Exception as e:
+         print(e)
+
+     return smiles, cas_list, name_list
+
+
+ def query_from_pubchem(query_name: str):
+     results = pcp.get_compounds(query_name, 'name')
+     smiles = None
+     name_list = set()
+     cas_list = set()
+
+     if any(results):
+         try:
+             smiles = mv.standardize_smiles(results[0].canonical_smiles)
+         except Exception as e:
+             smiles = results[0].canonical_smiles
+             print(smiles)
+             print(e)
+         for compound in results:
+             name_list.update(set(compound.synonyms))
+             for syn in compound.synonyms:
+                 match = re.match(r'(\d{2,7}-\d\d-\d)', syn)
+                 if match:
+                     cas_list.add(match.group(1))
+
+     cas_list = list(cas_list)
+     name_list = list(name_list)
+
+     return smiles, cas_list, name_list
+
+
+ def query_a_compound(
+     query_name: str,
+     connect_info: str,
+     by: str = 'name',
+     log_file: str = './err.log'
+ ):
+     fei = None
+     found = False
+
+     if by != 'smiles':
+         query_name = query_name.lower()
+
+     by = 'name'
+     table = by + '_maps'
+     # query_name = 'adipic acid'
+     query = sql.SQL(
+         "select fei from {table} where {by} = %s"
+     ).format(
+         table=sql.Identifier(table),
+         by=sql.Identifier(by)
+     )
+     conn = connect_by_infofile(connect_info)
+
+     cur = conn.execute(query, [query_name]).fetchone()
+
+     if cur is not None:
+         fei = cur[0]
+         found = True
+         return fei
+
+     if not found:
+         try:
+             smiles, cas_list, name_list = query_from_pubchem(query_name)
+         except Exception as e:
+             print(e)
+             smiles, cas_list, name_list = None, [], []
+         if smiles is not None:
+             found = True
+         else:
+             try:
+                 smiles, cas_list, name_list = query_from_cir(query_name)
+             except Exception as e:
+                 print(e)
+                 smiles, cas_list, name_list = None, [], []
+             if smiles is not None:
+                 found = True
+
+     if not found:
+         with open(log_file, 'a') as f:
+             f.write(query_name)
+             f.write('\n')
+         return
+         # raise ValueError("What kind of junk input is this? Nothing can be found for it!")
+     else:
+         query_compound = sql.SQL(
+             "select fei from compounds where smiles = %s"
+         )
+         cur = conn.execute(query_compound, [smiles]).fetchone()
+         if cur is not None:
+             fei = cur[0]
+         elif any(cas_list):
+             fei = cas_list[0]
+             insert_compounds_sql = sql.SQL(
+                 "INSERT INTO compounds (fei, smiles) VALUES (%s, %s) ON CONFLICT (fei) DO NOTHING"
+             )
+             conn.execute(insert_compounds_sql, [fei, smiles])
+             for cas in cas_list:
+                 insert_cas_map_sql = sql.SQL(
+                     "INSERT INTO cas_maps (fei, cas) VALUES (%s, %s) ON CONFLICT (cas) DO NOTHING"
+                 )
+                 try:
+                     conn.execute(insert_cas_map_sql, [fei, cas])
+                 except Exception as e:
+                     print(e)
+             for name in name_list:
+                 insert_name_map_sql = sql.SQL(
+                     "INSERT INTO name_maps (fei, name) VALUES (%s, %s) ON CONFLICT (fei, name) DO NOTHING"
+                     # "INSERT INTO name_maps (fei, name) VALUES (%s, %s)"
+                 )
+                 try:
+                     conn.execute(insert_name_map_sql, [fei, name.lower()])
+                 except Exception as e:
+                     print(e)
+
+         conn.commit()
+         conn.close()
+
+     return fei
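
For orientation, a minimal usage sketch of the new lookup helper. The connection-info file path and the compound name below are hypothetical, and the sketch assumes the name_maps/cas_maps/compounds tables already exist in the target PostgreSQL database:

# Hypothetical usage; 'pg.info' is a placeholder for a file that
# connect_by_infofile() can read, not something shipped in this diff.
from hdl.jupyfuncs.dbtools.query_info import query_a_compound

fei = query_a_compound(
    'adipic acid',          # looked up case-insensitively in name_maps
    'pg.info',              # connection-info file for connect_by_infofile()
    by='name',
    log_file='./err.log',   # unresolved names get appended here
)
print(fei)                  # internal compound id, or None when nothing resolves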
hdl/jupyfuncs/dl/cp.py ADDED
@@ -0,0 +1,54 @@
1
+ """
2
+ Conformer Prediction for classification Task
3
+
4
+ As we only consider the predicted values as input,
5
+ we do not differentiate between transductive and inductive conformers
6
+ """
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+
11
+
12
+ class CpClassfier:
13
+ def __init__(self):
14
+ self.cal_data = None
15
+ self.class_num = 0
16
+
17
+ def fit_with_data(self, cal_proba, cal_y, class_num=0):
18
+ """
19
+ :parm cal_proba: numpy array of shape [n_samples, n_classes]
20
+ predicted probability of calibration set
21
+ :parm cal_y: numpy array of shape [n_samples,]
22
+ true label of calibration set
23
+ """
24
+
25
+ if class_num <= 0:
26
+ print("Get class number for input data.")
27
+ self.class_num = cal_proba.shape[1]
28
+ else:
29
+ self.class_num = class_num
30
+
31
+ cal_df = pd.DataFrame(cal_proba)
32
+ cal_df.columns = ["class_%d"%i for i in range(self.class_num)]
33
+ cal_df["true_label"] = list(cal_y)
34
+ self.cal_data = cal_df
35
+
36
+ def fit_with_model(self, func):
37
+ #TODO: besides calibration data, we can also use our model
38
+ pass
39
+
40
+ def predict_with_proba(self, X_proba):
41
+ """
42
+ :parm X_proba: numpy array of shape [n_samples, n_classes]
43
+ predicted probabilities of calibration set
44
+ """
45
+ cp_proba = []
46
+ class_lsts = [sorted(self.cal_data[self.cal_data["true_label"] == i]["class_%d"%i]) \
47
+ for i in range(self.class_num)]
48
+ proba_lsts = [X_proba[:, i] for i in range(self.class_num)]
49
+ for c_lst, p_lst in zip(class_lsts, proba_lsts):
50
+ c_proba = np.searchsorted(c_lst, p_lst, side='left')/len(c_lst)
51
+ cp_proba.append(c_proba)
52
+
53
+ self.cp_P = np.array(cp_proba).T
54
+ return self.cp_P
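
A minimal sketch of driving CpClassfier; the probability arrays below are fabricated, and any upstream classifier that produces per-class probabilities would do:

import numpy as np
from hdl.jupyfuncs.dl.cp import CpClassfier

# Fabricated calibration probabilities for a 2-class problem.
cal_proba = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
cal_y = np.array([0, 1, 0, 1])

clf = CpClassfier()
clf.fit_with_data(cal_proba, cal_y)

# Per-class conformal p-values for two test samples.
test_proba = np.array([[0.75, 0.25], [0.4, 0.6]])
print(clf.predict_with_proba(test_proba))  # shape [2, 2]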
hdl/jupyfuncs/dl/dataframe.py ADDED
@@ -0,0 +1,38 @@
+ def rm_index(df):
+     """Remove columns whose names start with 'Unnamed' from the DataFrame.
+
+     Args:
+         df (pandas.DataFrame): The input DataFrame.
+
+     Returns:
+         pandas.DataFrame: DataFrame with the 'Unnamed' index leftovers removed.
+     """
+     return df.loc[:, ~df.columns.str.match('Unnamed')]
+
+
+ def rm_col(df, col_name):
+     """Remove columns whose names match a pattern from a DataFrame.
+
+     Args:
+         df (pandas.DataFrame): The input DataFrame.
+         col_name (str): Column name (treated as a regex prefix match).
+
+     Returns:
+         pandas.DataFrame: A new DataFrame with the matching columns removed.
+     """
+     return df.loc[:, ~df.columns.str.match(col_name)]
+
+
+ def shuffle_df(df):
+     """Shuffle the rows of a DataFrame.
+
+     Args:
+         df (pandas.DataFrame): The input DataFrame to shuffle.
+
+     Returns:
+         pandas.DataFrame: A new DataFrame with rows shuffled.
+
+     Example:
+         shuffled_df = shuffle_df(df)
+     """
+     return df.sample(frac=1).reset_index(drop=True)
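
A short sketch chaining the three helpers; 'data.csv' and the 'notes' column are invented for the example:

import pandas as pd
from hdl.jupyfuncs.dl.dataframe import rm_index, rm_col, shuffle_df

# Hypothetical input; a CSV written with df.to_csv() and re-read usually
# carries an 'Unnamed: 0' index column, which rm_index() strips.
df = pd.read_csv('data.csv')
df = rm_index(df)            # drop 'Unnamed: *' index leftovers
df = rm_col(df, 'notes')     # drop columns whose names match the pattern
df = shuffle_df(df)          # shuffled rows with a fresh RangeIndex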
hdl/jupyfuncs/dl/fp.py ADDED
@@ -0,0 +1,49 @@
1
+ from rdkit import Chem
2
+ import numpy as np
3
+ from rdkit.Chem import AllChem
4
+ from rdkit.Chem import MACCSkeys
5
+
6
+
7
+ __all__ = [
8
+ 'get_fp',
9
+ ]
10
+
11
+
12
+ def get_rdnorm_fp(smiles):
13
+ from descriptastorus.descriptors import rdNormalizedDescriptors
14
+ generator = rdNormalizedDescriptors.RDKit2DNormalized()
15
+ features = generator.process(smiles)[1:]
16
+ arr = np.array(features)
17
+ return arr
18
+
19
+
20
+ def get_maccs_fp(smiles):
21
+ arr = np.zeros(167)
22
+ try:
23
+ mol = Chem.MolFromSmiles(smiles)
24
+ vec = MACCSkeys.GenMACCSKeys(mol)
25
+ bv = list(vec.GetOnBits())
26
+ arr[bv] = 1
27
+ except Exception as e:
28
+ print(e)
29
+ return arr
30
+
31
+
32
+ def get_morgan_fp(smiles):
33
+ mol = Chem.MolFromSmiles(smiles)
34
+ vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)
35
+ bv = list(vec.GetOnBits())
36
+ arr = np.zeros(1024)
37
+ arr[bv] = 1
38
+ return arr
39
+
40
+
41
+ fp_dict = {
42
+ 'rdnorm': get_rdnorm_fp,
43
+ 'maccs': get_maccs_fp,
44
+ 'morgan': get_morgan_fp
45
+ }
46
+
47
+
48
+ def get_fp(smiles, fp='maccs'):
49
+ return fp_dict[fp](smiles)
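
A quick sketch of get_fp on a single SMILES string (ethanol here); note that the 'rdnorm' option additionally requires the descriptastorus package:

from hdl.jupyfuncs.dl.fp import get_fp

smiles = 'CCO'                         # ethanol
maccs = get_fp(smiles)                 # default: 167-bit MACCS keys
morgan = get_fp(smiles, fp='morgan')   # 1024-bit Morgan fingerprint, radius 2
print(maccs.shape, morgan.shape)       # (167,) (1024,)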
hdl/jupyfuncs/dl/list.py ADDED
@@ -0,0 +1,20 @@
+ def list_diff(listA, listB, mode="intersection"):
+     """Combine or difference two lists based on the specified mode.
+
+     Args:
+         listA (list): The first list.
+         listB (list): The second list.
+         mode (str, optional): One of "intersection" (default), "union",
+             or "diff" (elements of listB that are not in listA).
+
+     Returns:
+         list: The resulting elements; any other mode raises NameError.
+     """
+     if mode == "intersection":
+         ret = list(set(listA).intersection(set(listB)))
+     elif mode == "union":
+         ret = list(set(listA).union(set(listB)))
+     elif mode == "diff":
+         ret = list(set(listB).difference(set(listA)))
+
+     return ret
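
For illustration, the three supported modes on a pair of toy lists; the operations are set-based, so output order is not guaranteed:

from hdl.jupyfuncs.dl.list import list_diff

a, b = [1, 2, 3], [2, 3, 4]
print(list_diff(a, b))                  # intersection, e.g. [2, 3]
print(list_diff(a, b, mode='union'))    # e.g. [1, 2, 3, 4]
print(list_diff(a, b, mode='diff'))     # elements of b not in a: [4]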
hdl/jupyfuncs/dl/model_utils.py ADDED
@@ -0,0 +1,97 @@
+ import torch
+ from torch import nn
+
+
+ def save_model(
+     model,
+     save_dir,
+     epoch=0,
+     optimizer=None,
+     loss=None,
+ ):
+     """Save the model and related training information to a checkpoint file.
+
+     Args:
+         model: The model to be saved; it must expose an `init_args` attribute.
+         save_dir: Path of the checkpoint file to write.
+         epoch (int): The current epoch number (default is 0).
+         optimizer: The optimizer used for training (default is None).
+         loss: The loss value (default is None).
+     """
+     if isinstance(model, nn.DataParallel):
+         state_dict = model.module.state_dict()
+     else:
+         state_dict = model.state_dict()
+     if optimizer is None:
+         optim_params = None
+     else:
+         optim_params = optimizer.state_dict()
+     torch.save(
+         {
+             'init_args': model.init_args,
+             'epoch': epoch,
+             'model_state_dict': state_dict,
+             'optimizer_state_dict': optim_params,
+             'loss': loss,
+         },
+         save_dir
+     )
+
+
+ def load_model(
+     save_dir,
+     model_class=None,
+     model=None,
+     optimizer=None,
+     train=False,
+ ):
+     """Load a saved model from the specified checkpoint file.
+
+     Args:
+         save_dir (str): Path of the saved checkpoint file.
+         model_class (torch.nn.Module, optional): The class of the model to be loaded. Defaults to None.
+         model (torch.nn.Module, optional): The model to load the state_dict into. Defaults to None.
+         optimizer (torch.optim.Optimizer, optional): The optimizer to load the state_dict into. Defaults to None.
+         train (bool, optional): Whether to set the model to training mode. Defaults to False.
+
+     Returns:
+         tuple: A tuple containing the loaded model, optimizer, epoch, and loss.
+     """
+     # from .model_dict import MODEL_DICT
+     checkpoint = torch.load(save_dir)
+     if model is None:
+         init_args = checkpoint['init_args']
+         assert model_class is not None
+         model = model_class(**init_args)
+         model.load_state_dict(
+             checkpoint['model_state_dict'],
+         )
+
+     elif isinstance(model, nn.DataParallel):
+         state_dict = checkpoint['model_state_dict']
+         from collections import OrderedDict
+         new_state_dict = OrderedDict()  # remap keys to match the DataParallel wrapper
+
+         for k, v in state_dict.items():
+             if 'module' not in k:
+                 k = 'module.' + k
+             else:
+                 k = k.replace('features.module.', 'module.features.')
+             new_state_dict[k] = v
+         model.load_state_dict(new_state_dict)
+     else:
+         model.load_state_dict(
+             checkpoint['model_state_dict'],
+         )
+
+     if optimizer is not None:
+         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+     epoch = checkpoint['epoch']
+     loss = checkpoint['loss']
+
+     if train:
+         model.train()
+     else:
+         model.eval()
+
+     return model, optimizer, epoch, loss
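
A round-trip sketch, with one assumption worth flagging: save_model() reads model.init_args, so the module must store its constructor arguments itself. TinyNet and 'ckpt.pt' below are invented for the example:

import torch
from torch import nn
from hdl.jupyfuncs.dl.model_utils import save_model, load_model

class TinyNet(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # save_model() persists this dict so load_model() can rebuild the net.
        self.init_args = {'in_dim': in_dim, 'out_dim': out_dim}
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.fc(x)

model = TinyNet(8, 2)
opt = torch.optim.Adam(model.parameters())
save_model(model, 'ckpt.pt', epoch=3, optimizer=opt, loss=0.12)

# Rebuild from the checkpoint alone; the model comes back in eval mode.
model2, _, epoch, loss = load_model('ckpt.pt', model_class=TinyNet)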
hdl/jupyfuncs/dl/tensor.py ADDED
@@ -0,0 +1,159 @@
+ import typing as t
+ from collections import defaultdict
+
+ import torch
+ import numpy as np
+ import pandas as pd
+ from scipy.spatial.distance import cdist
+
+ def spmmsp(
+     sp1: torch.sparse.Tensor,
+     sp2: torch.sparse.Tensor
+ ) -> torch.sparse.Tensor:
+     from torch_sparse import spspmm
+     assert sp1.size(-1) == sp2.size(0) and sp1.is_sparse and sp2.is_sparse
+     m = sp1.size(0)
+     k = sp2.size(0)
+     n = sp2.size(-1)
+     indices, values = spspmm(
+         sp1.indices(), sp1.values(),
+         sp2.indices(), sp2.values(),
+         m, k, n
+     )
+     return torch.sparse_coo_tensor(
+         indices,
+         values,
+         torch.Size([m, n])
+     )
+
+
+ def label_to_onehot(ls, class_num, missing_label=-1):
+     """
+     example:
+     >>> label_to_onehot([2, 3, -1], 6, -1)
+     array([[0., 0., 1., 0., 0., 0.],
+            [0., 0., 0., 1., 0., 0.],
+            [0., 0., 0., 0., 0., 0.]])
+     :param ls: label(s) to encode
+     :param class_num: number of classes
+     :param missing_label: value marking a missing label
+     :return: one-hot encoding with missing entries zeroed out
+     """
+     if isinstance(ls, torch.Tensor):
+         bool_t = ls == missing_label
+         clamp_t = torch.clamp(ls, min=0)
+         full_tensor = torch.zeros(ls.numel(), class_num)
+         full_tensor = full_tensor.scatter_(1, clamp_t.reshape(-1, 1), 1)
+         full_tensor[bool_t] = 0
+         return full_tensor
+     elif isinstance(ls, t.List):
+         ls = np.array(ls, dtype=np.int32)
+         bool_array = ls == missing_label
+         arr = np.zeros((ls.size, class_num))  # honor class_num, not ls.max() + 1
+         arr[np.arange(ls.size), ls] = 1
+         arr[bool_array] = 0
+         return arr
+     elif not isinstance(ls, t.Iterable):
+         arr = torch.zeros(class_num)
+         if ls != missing_label and not np.isnan(ls) and ls is not None:
+             arr[int(ls)] = 1
+         return arr
+
+
+ def onehot_to_label(tensor):
+     if isinstance(tensor, torch.Tensor):
+         return torch.argmax(tensor, dim=-1)
+     elif isinstance(tensor, np.ndarray):
+         return np.argmax(tensor, axis=-1)
+
+
+ def label_to_tensor(
+     label,
+     num_classes,
+     missing_label=-1,
+     device=torch.device('cpu')
+ ):
+     if isinstance(label, t.List) and not any(label):
+         return torch.zeros(num_classes).to(device)
+     elif isinstance(label, t.List) and isinstance(label[0], t.Iterable):
+         max_length = max([len(_l) for _l in label])
+         index = [_l + _l[-1:] * (max_length - len(_l)) for _l in label]
+         tensor_list = []
+         # tensor = torch.zeros(len(label), num_classes, device=device)
+         for _idx in index:
+             _tensor = torch.zeros(num_classes).to(device)
+             _idx = torch.LongTensor(_idx)
+             _tensor = _tensor.scatter(0, _idx, 1)
+             tensor_list.append(_tensor)
+
+         return torch.vstack(tensor_list).to(device)
+     else:
+         if label == missing_label or np.isnan(label):
+             return torch.zeros(num_classes)
+         tensor = torch.zeros(num_classes).to(device)
+         tensor[int(label)] = 1
+         return tensor
+
+
+ def tensor_to_label(tensor, threshold=0.5):
+     label_list, label_dict = [], defaultdict(list)
+     labels = (tensor > threshold).nonzero(as_tuple=False)
+     for label in labels:
+         label_dict[label[0].item()].append(label[1].item())
+     for _, label_value in label_dict.items():
+         label_list.append(label_value)
+     return label_list
+
+
+ def get_dist_matrix(
+     a: np.ndarray, b: np.ndarray
+ ):
+     return cdist(a, b)
+     # aSumSquare = np.sum(np.square(a), axis=1)
+     # bSumSquare = np.sum(np.square(b), axis=1)
+     # mul = np.dot(a, b.T)
+     # dists = np.sqrt(aSumSquare[:, np.newaxis] + bSumSquare - 2 * mul)
+     # return dists
+
+
+ def get_valid_indices(labels):
+     if isinstance(labels, torch.Tensor):
+         nan_indices = torch.isnan(labels)
+         valid_indices = (
+             ~nan_indices
+         ).nonzero(as_tuple=True)[0]
+     else:
+         target_pd = pd.array(labels)
+         nan_indices = pd.isna(target_pd)
+         valid_indices = torch.LongTensor(
+             np.where(~nan_indices)[0]
+         )
+     return valid_indices
+
+
+ def smooth_max(
+     tensor: torch.Tensor,
+     inf_k: int = None,
+     **kwargs
+ ):
+     if inf_k is None:
+         inf_k = 10
+     max_value = torch.log(
+         torch.sum(
+             torch.exp(tensor * inf_k),
+             **kwargs
+         )
+     ) / inf_k
+     return max_value
+
+
+ def list_df(listA, listB):
+     retB = list(set(listA).intersection(set(listB)))
+
+     retC = list(set(listA).union(set(listB)))
+
+     retD = list(set(listB).difference(set(listA)))
+
+     retE = list(set(listA).difference(set(listB)))
+     return retB, retC, retD, retE
+
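
A small sketch of the label round-trip and the log-sum-exp approximation of max; the values are invented:

import torch
from hdl.jupyfuncs.dl.tensor import label_to_onehot, onehot_to_label, smooth_max

labels = torch.tensor([2, 0, -1])
onehot = label_to_onehot(labels, class_num=4)  # the -1 entry becomes an all-zero row
print(onehot_to_label(onehot))                 # tensor([2, 0, 0]); zero rows argmax to 0

x = torch.tensor([0.1, 0.9, 0.5])
print(smooth_max(x, inf_k=10, dim=0))          # smooth approximation of x.max()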
hdl/jupyfuncs/dl/uncs.py ADDED
@@ -0,0 +1,112 @@
+ """UNCERTAINTY SAMPLING
+
+ Uncertainty Sampling examples for Active Learning, implemented with NumPy
+
+ It contains four Active Learning strategies:
+ 1. Least Confidence Sampling
+ 2. Margin of Confidence Sampling
+ 3. Ratio of Confidence Sampling
+ 4. Entropy-based Sampling
+
+ """
+ from copy import deepcopy
+
+ import numpy as np
+
+ __all__ = [
+     "get_prob_unc",
+ ]
+
+
+ def least_conf_unc(prob_array: np.ndarray) -> np.ndarray:
+     r"""Least confidence uncertainty
+
+     .. math::
+         \phi_{LC}(x) = \left(1 - P_{\theta}\left(y^{*} \mid x\right)\right) \times \frac{n}{n - 1}
+
+     Args:
+         prob_array (np.ndarray): a 1D or 2D array of probabilities
+
+     Returns:
+         np.ndarray: the uncertainty value(s)
+     """
+     if prob_array.ndim == 1:
+         indices = prob_array.argmax()
+     else:
+         indices = (
+             np.arange(prob_array.shape[0]),
+             prob_array.argmax(-1)
+         )
+     num_labels = prob_array.shape[-1]
+     uncs = (1 - prob_array[indices]) * (num_labels / (num_labels - 1))
+     return uncs
+
+
+ def margin_conf_unc(prob_array: np.ndarray) -> np.ndarray:
+     r"""The margin confidence uncertainty
+
+     .. math::
+         \phi_{MC}(x) = 1 - \left(P_{\theta}\left(y_{1}^{*} \mid x\right) - P_{\theta}\left(y_{2}^{*} \mid x\right)\right)
+
+     Args:
+         prob_array (np.ndarray): a 1D or 2D probability array from an NN.
+
+     Returns:
+         np.ndarray: the uncertainty value(s)
+     """
+     probs = deepcopy(prob_array)
+     probs.sort(-1)
+     diffs = probs[..., -1] - probs[..., -2]
+     return 1 - diffs
+
+
+ def ratio_conf_unc(prob_array: np.ndarray) -> np.ndarray:
+     r"""Ratio based uncertainties
+
+     .. math::
+         \phi_{RC}(x) = P_{\theta}\left(y_{2}^{*} \mid x\right) / P_{\theta}\left(y_{1}^{*} \mid x\right)
+
+     Args:
+         prob_array (np.ndarray): a 1D or 2D probability array
+
+     Returns:
+         np.ndarray: the uncertainty value(s)
+     """
+     probs = deepcopy(prob_array)
+     probs.sort(-1)
+     ratio = probs[..., -2] / probs[..., -1]  # second-largest over largest, per the formula above
+     return ratio
+
+
+ def entropy_unc(prob_array: np.ndarray) -> np.ndarray:
+     r"""Entropy based uncertainty
+
+     .. math::
+         \phi_{ENT}(x) = \frac{-\Sigma_{y} P_{\theta}(y \mid x) \log_{2} P_{\theta}(y \mid x)}{\log_{2}(n)}
+
+     Args:
+         prob_array (np.ndarray): a 1D or 2D probability array
+
+     Returns:
+         np.ndarray: the uncertainty value(s)
+     """
+     num_labels = prob_array.shape[-1]
+     log_probs = prob_array * np.log2(prob_array)
+
+     raw_entropy = 0 - np.sum(log_probs, -1)
+
+     normalized_entropy = raw_entropy / np.log2(num_labels)
+
+     return normalized_entropy
+
+
+ unc_dict = {
+     'least': least_conf_unc,
+     'margin': margin_conf_unc,
+     'ratio': ratio_conf_unc,
+     'entropy': entropy_unc,
+ }
+
+
+ def get_prob_unc(prob_array: np.ndarray, unc: str) -> np.ndarray:
+     return unc_dict[unc](prob_array)
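
A closing sketch over fabricated softmax outputs; note that entropy_unc computes p * log2(p), so inputs should avoid exact zeros (otherwise NaNs propagate):

import numpy as np
from hdl.jupyfuncs.dl.uncs import get_prob_unc

probs = np.array([
    [0.98, 0.01, 0.01],   # confident prediction -> low uncertainty
    [0.40, 0.35, 0.25],   # near-uniform -> high uncertainty
    [0.70, 0.20, 0.10],
])
for kind in ('least', 'margin', 'ratio', 'entropy'):
    print(kind, get_prob_unc(probs, kind))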