pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (128)
  1. pg_sui-1.6.16a3.dist-info/METADATA +292 -0
  2. pg_sui-1.6.16a3.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +922 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1436 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1121 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1316 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/METADATA +0 -322
  83. pg_sui-0.2.3.dist-info/RECORD +0 -75
  84. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  85. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  88. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  89. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  90. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  93. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  94. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  95. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  96. pgsui/example_data/trees/test.iqtree +0 -376
  97. pgsui/example_data/trees/test.qmat +0 -5
  98. pgsui/example_data/trees/test.rate +0 -2033
  99. pgsui/example_data/trees/test.tre +0 -1
  100. pgsui/example_data/trees/test_n10.rate +0 -19
  101. pgsui/example_data/trees/test_n100.rate +0 -109
  102. pgsui/example_data/trees/test_n500.rate +0 -509
  103. pgsui/example_data/trees/test_siterates.txt +0 -2024
  104. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  105. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  106. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  107. pgsui/example_data/vcf_files/test.vcf +0 -244
  108. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  109. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  110. pgsui/impute/estimators.py +0 -1268
  111. pgsui/impute/impute.py +0 -1463
  112. pgsui/impute/simple_imputers.py +0 -1431
  113. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  114. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  115. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  116. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  117. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  118. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  119. pgsui/pg_sui.py +0 -261
  120. pgsui/utils/sequence_tools.py +0 -407
  121. simulation/sim_benchmarks.py +0 -333
  122. simulation/sim_treeparams.py +0 -475
  123. test/__init__.py +0 -0
  124. test/pg_sui_simtest.py +0 -215
  125. test/pg_sui_testing.py +0 -523
  126. test/test.py +0 -151
  127. test/test_pgsui.py +0 -374
  128. test/test_tkc.py +0 -185
pgsui/utils/misc.py CHANGED
@@ -1,44 +1,38 @@
-import sys
-import os
-import functools
-import time
-import datetime
-import platform
-import subprocess
-import re
-import logging
+from typing import Literal
 
 import numpy as np
 import pandas as pd
-from tqdm import tqdm
-from tqdm.utils import disp_len, _unicode  # for overriding status_print
+import torch
 
 
-# from skopt import BayesSearchCV
-
-
-def validate_input_type(X, return_type="array"):
+def validate_input_type(
+    X: pd.DataFrame | np.ndarray | list | torch.Tensor,
+    return_type: Literal["array", "df", "list", "tensor"] = "array",
+) -> pd.DataFrame | np.ndarray | list | torch.Tensor:
     """Validate input type and return as numpy array.
 
+    This method validates the input type and returns the input data as a numpy array, pandas DataFrame, 2D list, or torch.Tensor.
+
     Args:
-        X (pandas.DataFrame, numpy.ndarray, or List[List[int]]): Input data.
+        X (pandas.DataFrame | numpy.ndarray | list | torch.Tensor): Input data. Supported types include: pandas.DataFrame, numpy.ndarray, list, and torch.Tensor.
 
-        return_type (str): Type of returned object. Supported options include: "df", "array", and "list". Defaults to "array".
+        return_type (Literal["array", "df", "list", "tensor"]): Type of returned object. Supported options include: "df", "array", "list", and "tensor". "df" corresponds to a pandas DataFrame. "array" corresponds to a numpy array. "list" corresponds to a 2D list. "tensor" corresponds to a torch.Tensor. Defaults to "array".
 
     Returns:
-        pandas.DataFrame, numpy.ndarray, or List[List[int]]: Input data desired return_type.
+        pandas.DataFrame | numpy.ndarray | list | torch.Tensor: Input data as the desired return_type.
 
     Raises:
-        TypeError: X must be of type pandas.DataFrame, numpy.ndarray, or List[List[int]].
-
-        ValueError: Unsupported return_type provided. Supported types are "df", "array", and "list".
+        TypeError: X must be of type pandas.DataFrame, numpy.ndarray, list, or torch.Tensor.
+        ValueError: Unsupported return_type provided. Supported types are "df", "array", "list", and "tensor".
 
     """
-    if not isinstance(X, (pd.DataFrame, np.ndarray, list)):
-        raise TypeError(
-            f"X must be of type pandas.DataFrame, numpy.ndarray, "
-            f"or List[List[int]], but got {type(X)}"
-        )
+    if not isinstance(X, (pd.DataFrame, np.ndarray, list, torch.Tensor)):
+        msg = f"X must be of type pandas.DataFrame, numpy.ndarray, list, or torch.Tensor, but got {type(X)}"
+        raise TypeError(msg)
+
+    if return_type not in {"df", "array", "list", "tensor"}:
+        msg = f"Unsupported return type provided: {return_type}. Supported types are 'df', 'array', 'list', and 'tensor'"
+        raise ValueError(msg)
 
     if return_type == "array":
         if isinstance(X, pd.DataFrame):
@@ -47,12 +41,16 @@ def validate_input_type(X, return_type="array"):
             return np.array(X)
         elif isinstance(X, np.ndarray):
            return X.copy()
+        elif isinstance(X, torch.Tensor):
+            return X.cpu().detach().numpy()
 
    elif return_type == "df":
        if isinstance(X, pd.DataFrame):
            return X.copy()
        elif isinstance(X, (np.ndarray, list)):
            return pd.DataFrame(X)
+        elif isinstance(X, torch.Tensor):
+            return pd.DataFrame(X.cpu().detach().numpy())
 
    elif return_type == "list":
        if isinstance(X, list):
@@ -61,458 +59,15 @@ def validate_input_type(X, return_type="array"):
            return X.tolist()
        elif isinstance(X, pd.DataFrame):
            return X.values.tolist()
+        elif isinstance(X, torch.Tensor):
+            return X.cpu().detach().numpy().tolist()
 
-    else:
-        raise ValueError(
-            f"Unsupported return type provided: {return_type}. Supported types "
-            f"are 'df', 'array', and 'list'"
-        )
-
-
-def generate_random_dataset(
-    min_value=0,
-    max_value=2,
-    nrows=35,
-    ncols=20,
-    min_missing_rate=0.15,
-    max_missing_rate=0.5,
-):
-    """Generate a random integer dataset that can be used for testing.
-
-    Will also add randomly missing values of random proportions between ``min_missing_rate`` and ``max_missing_rate``.
-
-    Args:
-        min_value (int, optional): Minimum value to use. Defaults to 0.
-
-        max_value (int, optional): Maxiumum value to use. Defaults to 2.
-
-        nrows (int, optional): Number of rows to use. Defaults to 35.
-
-        ncols (int, optional): Number of columns to use. Defaults to 20.
-
-        min_missing_rate (float, optional): Minimum proportion of missing data per column. Defaults to 0.15.
-
-        max_missing_rate (float, optional): Maximum proportion of missing data per column.
-
-    Returns:
-        numpy.ndarray: Numpy array with randomly generated dataset.
-    """
-    assert (
-        min_missing_rate >= 0 and min_missing_rate < 1.0
-    ), f"min_missing_rate must be >= 0 and < 1.0, but got {min_missing_rate}"
-
-    assert (
-        max_missing_rate > 0 and max_missing_rate < 1.0
-    ), f"max_missing_rate must be > 0 and < 1.0, but got {max_missing_rate}"
-
-    assert nrows > 1, f"nrows must be > 1, but got {nrows}"
-    assert ncols > 1, f"ncols must be > 1, but got {ncols}"
-
-    try:
-        min_missing_rate = float(min_missing_rate)
-        max_missing_rate = float(max_missing_rate)
-    except TypeError:
-        sys.exit(
-            "min_missing_rate and max_missing_rate must be of type float or "
-            "must be cast-able to type float"
-        )
-
-    X = np.random.randint(
-        min_value, max_value + 1, size=(nrows, ncols)
-    ).astype(float)
-    for i in range(X.shape[1]):
-        drop_rate = int(
-            np.random.choice(
-                np.arange(min_missing_rate, max_missing_rate, 0.02), 1
-            )[0]
-            * X.shape[0]
-        )
-
-        rows = np.random.choice(np.arange(0, X.shape[0]), size=drop_rate)
-        X[rows, i] = np.nan
-
-    return X
-
-
-def generate_012_genotypes(
-    nrows=35,
-    ncols=20,
-    max_missing_rate=0.5,
-    min_het_rate=0.001,
-    max_het_rate=0.3,
-    min_alt_rate=0.001,
-    max_alt_rate=0.3,
-):
-    """Generate random 012-encoded genotypes.
-
-    Allows users to control the rate of reference, heterozygote, and alternate alleles. Will insert a random proportion between ``min_het_rate`` and ``max_het_rate`` and ``min_alt_rate`` and ``max_alt_rate`` and from no misssing data to a proportion of ``max_missing_rate``.
-
-    Args:
-        nrows (int, optional): Number of rows to generate. Defaults to 35.
-
-        ncols (int, optional): Number of columns to generate. Defaults to 20.
-
-        max_missing_rate (float, optional): Maximum proportion of missing data to use. Defaults to 0.5.
-
-        min_het_rate (float, optional): Minimum proportion of heterozygotes (1's) to insert. Defaults to 0.001.
-
-        max_het_rate (float, optional): Maximum proportion of heterozygotes (1's) to insert. Defaults to 0.3.
-
-        min_alt_rate (float, optional): Minimum proportion of alternate alleles (2's) to insert. Defaults to 0.001.
-
-        max_alt_rate (float, optional): Maximum proportion of alternate alleles (2's) to insert. Defaults to 0.3.
-    """
-    assert (
-        min_het_rate > 0 and min_het_rate <= 1.0
-    ), f"min_het_rate must be > 0 and <= 1.0, but got {min_het_rate}"
-
-    assert (
-        max_het_rate > 0 and max_het_rate <= 1.0
-    ), f"max_het_rate must be > 0 and <= 1.0, but got {max_het_rate}"
-
-    assert (
-        min_alt_rate > 0 and min_alt_rate <= 1.0
-    ), f"min_alt_rate must be > 0 and <= 1.0, but got {min_alt_rate}"
-
-    assert (
-        max_alt_rate > 0 and max_alt_rate <= 1.0
-    ), f"max_alt_rate must be > 0 and <= 1.0, but got {max_alt_rate}"
-
-    assert nrows > 1, f"The number of rows must be > 1, but got {nrows}"
-
-    assert ncols > 1, f"The number of columns must be > 1, but got {ncols}"
-
-    assert (
-        max_missing_rate > 0 and max_missing_rate < 1.0
-    ), f"max_missing rate must be > 0 and < 1.0, but got {max_missing_rate}"
-
-    try:
-        min_het_rate = float(min_het_rate)
-        max_het_rate = float(max_het_rate)
-        min_alt_rate = float(min_alt_rate)
-        max_alt_rate = float(max_alt_rate)
-        max_missing_rate = float(max_missing_rate)
-    except TypeError:
-        sys.exit(
-            "max_missing_rate, min_het_rate, max_het_rate, min_alt_rate, and "
-            "max_alt_rate must be of type float, or must be cast-able to type "
-            "float"
-        )
-
-    X = np.zeros((nrows, ncols))
-    for i in range(X.shape[1]):
-        het_rate = int(
-            np.ceil(
-                np.random.choice(
-                    np.arange(min_het_rate, max_het_rate, 0.02), 1
-                )[0]
-                * X.shape[0]
-            )
-        )
-
-        alt_rate = int(
-            np.ceil(
-                np.random.choice(
-                    np.arange(min_alt_rate, max_alt_rate, 0.02), 1
-                )[0]
-                * X.shape[0]
-            )
-        )
-
-        het = np.sort(
-            np.random.choice(
-                np.arange(0, X.shape[0]), size=het_rate, replace=False
-            )
-        )
-
-        alt = np.sort(
-            np.random.choice(
-                np.arange(0, X.shape[0]), size=alt_rate, replace=False
-            )
-        )
-
-        sidx = alt.argsort()
-        idx = np.searchsorted(alt, het, sorter=sidx)
-        idx[idx == len(alt)] = 0
-        het_unique = het[alt[sidx[idx]] != het]
-
-        X[alt, i] = 2
-        X[het_unique, i] = 1
-
-        drop_rate = int(
-            np.random.choice(np.arange(0.15, max_missing_rate, 0.02), 1)[0]
-            * X.shape[0]
-        )
-
-        missing = np.random.choice(np.arange(0, X.shape[0]), size=drop_rate)
-
-        X[missing, i] = np.nan
-
-    print(
-        f"Created a dataset of shape {X.shape} with {np.isnan(X).sum()} total missing values"
-    )
-
-    return X
-
-
-def unique2D_subarray(a):
-    """Get unique subarrays for each column from numpy array."""
-    dtype1 = np.dtype((np.void, a.dtype.itemsize * np.prod(a.shape[1:])))
-    b = np.ascontiguousarray(a.reshape(a.shape[0], -1)).view(dtype1)
-    return a[np.unique(b, return_index=1, axis=-1)[1]]
-
-
-def _remove_nonbiallelic(df, cv=5):
-    """Remove non-biallelic sites from pandas.DataFrame.
-
-    Remove sites that do not have both 0 and 2 encoded values in a column and if any of the allele counts is less than the number of cross-validation folds.
-
-    Args:
-        df (pandas.DataFrame): DataFrame with 012-encoded genotypes.
-
-    Returns:
-        pandas.DataFrame: DataFrame with non-biallelic sites dropped.
-    """
-    df_cp = df.copy()
-    bad_cols = list()
-    if pd.__version__[0] == 0:
-        for col in df_cp.columns:
-            if (
-                not df_cp[col].isin([0.0]).any()
-                or not df_cp[col].isin([2.0]).any()
-            ):
-                bad_cols.append(col)
-
-            elif len(df_cp[df_cp[col] == 0.0]) < cv:
-                bad_cols.append(col)
-
-            elif df_cp[col].isin([1.0]).any():
-                if len(df_cp[df_cp[col] == 1]) < cv:
-                    bad_cols.append(col)
-
-            elif len(df_cp[df_cp[col] == 2.0]) < cv:
-                bad_cols.append(col)
-
-    # pandas 1.X.X
-    else:
-        for col in df_cp.columns:
-            if 0.0 not in df[col].unique() and 2.0 not in df[col].unique():
-                bad_cols.append(col)
-
-            elif len(df_cp[df_cp[col] == 0.0]) < cv:
-                bad_cols.append(col)
-
-            elif 1.0 in df_cp[col].unique():
-                if len(df_cp[df_cp[col] == 1.0]) < cv:
-                    bad_cols.append(col)
-
-            elif len(df_cp[df_cp[col] == 2.0]) < cv:
-                bad_cols.append(col)
-
-    if bad_cols:
-        df_cp.drop(bad_cols, axis=1, inplace=True)
-
-    print(
-        f"{len(bad_cols)} columns removed for being non-biallelic or "
-        f"having genotype counts < number of cross-validation "
-        f"folds\nSubsetting from {len(df_cp.columns)} remaining columns\n"
-    )
-
-    return df_cp
-
-
-def get_indices(l):
-    """Takes a list and returns dict giving indices matching each possible
-    list member.
-
-    Example:
-        Input [0, 1, 1, 0, 0]
-        Output {0:[0,3,4], 1:[1,2]}
-    """
-    ret = dict()
-    for member in set(l):
-        ret[member] = list()
-    i = 0
-    for el in l:
-        ret[el].append(i)
-        i += 1
-    return ret
-
-
-def all_zero(l):
-    """Check whether list consists of all zeros.
-
-    Returns TRUE if supplied list contains all zeros
-    Returns FALSE if list contains ANY non-zero values
-    Returns FALSE if list is empty.
-
-    Args:
-        l (List[int]): List to check.
-
-    Returns:
-        bool: True if all zeros, False otherwise.
-    """
-    values = set(l)
-    if len(values) > 1:
-        return False
-    elif len(values) == 1 and l[0] in [0, 0.0, "0", "0.0"]:
-        return True
-    else:
-        return False
-
-
-def weighted_draw(d, num_samples=1):
-    choices = list(d.keys())
-    weights = list(d.values())
-    return np.random.choice(choices, num_samples, p=weights)
-
-
-def timer(func):
-    """print the runtime of the decorated function in the format HH:MM:SS."""
-
-    @functools.wraps(func)
-    def wrapper_timer(*args, **kwargs):
-        start_time = time.perf_counter()
-        value = func(*args, **kwargs)
-        end_time = time.perf_counter()
-        run_time = end_time - start_time
-        final_runtime = str(datetime.timedelta(seconds=run_time))
-        print(f"Finshed {func.__name__!r} in {final_runtime}\n")
-        return value
-
-    return wrapper_timer
-
-
-def progressbar(it, prefix="", size=60, file=sys.stdout):
-    count = len(it)
-
-    def show(j):
-        x = int(size * j / count)
-        file.write(
-            "%s[%s%s] %i/%i\r" % (prefix, "#" * x, "." * (size - x), j, count)
-        )
-        file.flush()
-
-    show(0)
-    for i, item in enumerate(it):
-        yield item
-        show(i + 1)
-    file.write("\n")
-    file.flush()
-
-
-def isnotebook():
-    """Checks whether in Jupyter notebook.
-
-    Returns:
-        bool: True if in Jupyter notebook, False otherwise.
-    """
-    try:
-        shell = get_ipython().__class__.__name__
-        if shell == "ZMQInteractiveShell":
-            # Jupyter notebook or qtconsole
-            return True
-        elif shell == "TerminalInteractiveShell":
-            # Terminal running IPython
-            return False
-        else:
-            # Other type (?)
-            return False
-    except NameError:
-        # Probably standard Python interpreter
-        return False
-
-
-def get_processor_name():
-    if platform.system() == "Windows":
-        return platform.processor()
-    elif platform.system() == "Darwin":
-        # os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
-        arch = platform.processor()
-        if arch[0] == "i":
-            return "Intel"
-        else:
-            return arch
-    elif platform.system() == "Linux":
-        command = "cat /proc/cpuinfo"
-        all_info = subprocess.check_output(command, shell=True).strip()
-        all_info = all_info.decode("utf-8")
-        for line in all_info.split("\n"):
-            if "model name" in line:
-                return re.sub(".*model name.*:", "", line, 1)
-    return ""
-
-
-class tqdm_linux(tqdm):
-    """Adds a dynamically updating progress bar.
-
-    Decorate an iterable object, with a dynamically updating progressbar every time a value is requested.
-    """
-
-    @staticmethod
-    def status_printer(self, file):
-        """Manage the printing and in-place updating of a line of characters.
-
-        NOTE: If the string is longer than a line, then in-place updating may not work (it will print a new line at each refresh).
-
-        Overridden to work with linux HPC clusters. Replaced carriage return with linux newline in fp_write function.
-
-        Args:
-            file (str): Path of file to print status to.
-        """
-
-        fp = file
-        fp_flush = getattr(fp, "flush", lambda: None)
-
-        def fp_write(s):
-            fp.write(_unicode(s))
-            fp_flush()
-
-        last_len = [0]
-
-        def print_status(s):
-            len_s = disp_len(s)
-            fp_write("\n" + s + (" " * max(last_len[0] - len_s, 0)))
-            last_len[0] = len_s
-
-        return print_status
-
-
-class HiddenPrints:
-    """Class to supress printing within a with statement."""
-
-    def __enter__(self):
-        self._original_stdout = sys.stdout
-        sys.stdout = open(os.devnull, "w")
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        sys.stdout.close()
-        sys.stdout = self._original_stdout
-
-
-class StreamToLogger(object):
-    """Fake file-like stream object that redirects writes to a logger instance."""
-
-    def __init__(self, logger, log_level=logging.INFO):
-        self.logger = logger
-        self.log_level = log_level
-        self.linebuf = ""
-
-    def write(self, buf):
-        temp_linebuf = self.linebuf + buf
-        self.linebuf = ""
-        for line in temp_linebuf.splitlines(True):
-            # From the io.TextIOWrapper docs:
-            # On output, if newline is None, any '\n' characters written
-            # are translated to the system default line separator.
-            # By default sys.stdout.write() expects '\n' newlines and then
-            # translates them so this is still cross platform.
-            if line[-1] == "\n":
-                self.logger.log(self.log_level, line.rstrip())
-            else:
-                self.linebuf += line
-
-    def flush(self):
-        if self.linebuf != "":
-            self.logger.log(self.log_level, self.linebuf.rstrip())
-            self.linebuf = ""
+    elif return_type == "tensor":
+        if isinstance(X, torch.Tensor):
+            return X
+        elif isinstance(X, np.ndarray):
+            return torch.tensor(X, dtype=torch.float32)
+        elif isinstance(X, pd.DataFrame):
+            return torch.tensor(X.to_numpy(), dtype=torch.float32)
+        elif isinstance(X, list):
+            return torch.tensor(X, dtype=torch.float32)