pg-sui 0.2.3__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (127) hide show
  1. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +99 -77
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/RECORD +0 -75
  83. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
pgsui/utils/misc.py CHANGED
@@ -1,44 +1,38 @@
1
- import sys
2
- import os
3
- import functools
4
- import time
5
- import datetime
6
- import platform
7
- import subprocess
8
- import re
9
- import logging
1
+ from typing import Literal
10
2
 
11
3
  import numpy as np
12
4
  import pandas as pd
13
- from tqdm import tqdm
14
- from tqdm.utils import disp_len, _unicode # for overriding status_print
5
+ import torch
15
6
 
16
7
 
17
- # from skopt import BayesSearchCV
18
-
19
-
20
- def validate_input_type(X, return_type="array"):
8
+ def validate_input_type(
9
+ X: pd.DataFrame | np.ndarray | list | torch.Tensor,
10
+ return_type: Literal["array", "df", "list", "tensor"] = "array",
11
+ ) -> pd.DataFrame | np.ndarray | list | torch.Tensor:
21
12
  """Validate input type and return as numpy array.
22
13
 
14
+ This method validates the input type and returns the input data as a numpy array, pandas DataFrame, 2D list, or torch.Tensor.
15
+
23
16
  Args:
24
- X (pandas.DataFrame, numpy.ndarray, or List[List[int]]): Input data.
17
+ X (pandas.DataFrame | numpy.ndarray | list | torch.Tensor): Input data. Supported types include: pandas.DataFrame, numpy.ndarray, list, and torch.Tensor.
25
18
 
26
- return_type (str): Type of returned object. Supported options include: "df", "array", and "list". Defaults to "array".
19
+ return_type (Literal["array", "df", "list", "tensor"]): Type of returned object. Supported options include: "df", "array", "list", and "tensor". "df" corresponds to a pandas DataFrame. "array" corresponds to a numpy array. "list" corresponds to a 2D list. "tensor" corresponds to a torch.Tensor. Defaults to "array".
27
20
 
28
21
  Returns:
29
- pandas.DataFrame, numpy.ndarray, or List[List[int]]: Input data desired return_type.
22
+ pandas.DataFrame | numpy.ndarray | list | torch.Tensor: Input data as the desired return_type.
30
23
 
31
24
  Raises:
32
- TypeError: X must be of type pandas.DataFrame, numpy.ndarray, or List[List[int]].
33
-
34
- ValueError: Unsupported return_type provided. Supported types are "df", "array", and "list".
25
+ TypeError: X must be of type pandas.DataFrame, numpy.ndarray, list, or torch.Tensor.
26
+ ValueError: Unsupported return_type provided. Supported types are "df", "array", "list", and "tensor".
35
27
 
36
28
  """
37
- if not isinstance(X, (pd.DataFrame, np.ndarray, list)):
38
- raise TypeError(
39
- f"X must be of type pandas.DataFrame, numpy.ndarray, "
40
- f"or List[List[int]], but got {type(X)}"
41
- )
29
+ if not isinstance(X, (pd.DataFrame, np.ndarray, list, torch.Tensor)):
30
+ msg = f"X must be of type pandas.DataFrame, numpy.ndarray, list, or torch.Tensor, but got {type(X)}"
31
+ raise TypeError(msg)
32
+
33
+ if return_type not in {"df", "array", "list", "tensor"}:
34
+ msg = f"Unsupported return type provided: {return_type}. Supported types are 'df', 'array', 'list', and 'tensor'"
35
+ raise ValueError(msg)
42
36
 
43
37
  if return_type == "array":
44
38
  if isinstance(X, pd.DataFrame):
@@ -47,12 +41,16 @@ def validate_input_type(X, return_type="array"):
47
41
  return np.array(X)
48
42
  elif isinstance(X, np.ndarray):
49
43
  return X.copy()
44
+ elif isinstance(X, torch.Tensor):
45
+ return X.cpu().detach().numpy()
50
46
 
51
47
  elif return_type == "df":
52
48
  if isinstance(X, pd.DataFrame):
53
49
  return X.copy()
54
50
  elif isinstance(X, (np.ndarray, list)):
55
51
  return pd.DataFrame(X)
52
+ elif isinstance(X, torch.Tensor):
53
+ return pd.DataFrame(X.cpu().detach().numpy())
56
54
 
57
55
  elif return_type == "list":
58
56
  if isinstance(X, list):
@@ -61,458 +59,15 @@ def validate_input_type(X, return_type="array"):
61
59
  return X.tolist()
62
60
  elif isinstance(X, pd.DataFrame):
63
61
  return X.values.tolist()
62
+ elif isinstance(X, torch.Tensor):
63
+ return X.cpu().detach().numpy().tolist()
64
64
 
65
- else:
66
- raise ValueError(
67
- f"Unsupported return type provided: {return_type}. Supported types "
68
- f"are 'df', 'array', and 'list'"
69
- )
70
-
71
-
72
- def generate_random_dataset(
73
- min_value=0,
74
- max_value=2,
75
- nrows=35,
76
- ncols=20,
77
- min_missing_rate=0.15,
78
- max_missing_rate=0.5,
79
- ):
80
- """Generate a random integer dataset that can be used for testing.
81
-
82
- Will also add randomly missing values of random proportions between ``min_missing_rate`` and ``max_missing_rate``.
83
-
84
- Args:
85
- min_value (int, optional): Minimum value to use. Defaults to 0.
86
-
87
- max_value (int, optional): Maxiumum value to use. Defaults to 2.
88
-
89
- nrows (int, optional): Number of rows to use. Defaults to 35.
90
-
91
- ncols (int, optional): Number of columns to use. Defaults to 20.
92
-
93
- min_missing_rate (float, optional): Minimum proportion of missing data per column. Defaults to 0.15.
94
-
95
- max_missing_rate (float, optional): Maximum proportion of missing data per column.
96
-
97
- Returns:
98
- numpy.ndarray: Numpy array with randomly generated dataset.
99
- """
100
- assert (
101
- min_missing_rate >= 0 and min_missing_rate < 1.0
102
- ), f"min_missing_rate must be >= 0 and < 1.0, but got {min_missing_rate}"
103
-
104
- assert (
105
- max_missing_rate > 0 and max_missing_rate < 1.0
106
- ), f"max_missing_rate must be > 0 and < 1.0, but got {max_missing_rate}"
107
-
108
- assert nrows > 1, f"nrows must be > 1, but got {nrows}"
109
- assert ncols > 1, f"ncols must be > 1, but got {ncols}"
110
-
111
- try:
112
- min_missing_rate = float(min_missing_rate)
113
- max_missing_rate = float(max_missing_rate)
114
- except TypeError:
115
- sys.exit(
116
- "min_missing_rate and max_missing_rate must be of type float or "
117
- "must be cast-able to type float"
118
- )
119
-
120
- X = np.random.randint(
121
- min_value, max_value + 1, size=(nrows, ncols)
122
- ).astype(float)
123
- for i in range(X.shape[1]):
124
- drop_rate = int(
125
- np.random.choice(
126
- np.arange(min_missing_rate, max_missing_rate, 0.02), 1
127
- )[0]
128
- * X.shape[0]
129
- )
130
-
131
- rows = np.random.choice(np.arange(0, X.shape[0]), size=drop_rate)
132
- X[rows, i] = np.nan
133
-
134
- return X
135
-
136
-
137
- def generate_012_genotypes(
138
- nrows=35,
139
- ncols=20,
140
- max_missing_rate=0.5,
141
- min_het_rate=0.001,
142
- max_het_rate=0.3,
143
- min_alt_rate=0.001,
144
- max_alt_rate=0.3,
145
- ):
146
- """Generate random 012-encoded genotypes.
147
-
148
- Allows users to control the rate of reference, heterozygote, and alternate alleles. Will insert a random proportion between ``min_het_rate`` and ``max_het_rate`` and ``min_alt_rate`` and ``max_alt_rate`` and from no misssing data to a proportion of ``max_missing_rate``.
149
-
150
- Args:
151
- nrows (int, optional): Number of rows to generate. Defaults to 35.
152
-
153
- ncols (int, optional): Number of columns to generate. Defaults to 20.
154
-
155
- max_missing_rate (float, optional): Maximum proportion of missing data to use. Defaults to 0.5.
156
-
157
- min_het_rate (float, optional): Minimum proportion of heterozygotes (1's) to insert. Defaults to 0.001.
158
-
159
- max_het_rate (float, optional): Maximum proportion of heterozygotes (1's) to insert. Defaults to 0.3.
160
-
161
- min_alt_rate (float, optional): Minimum proportion of alternate alleles (2's) to insert. Defaults to 0.001.
162
-
163
- max_alt_rate (float, optional): Maximum proportion of alternate alleles (2's) to insert. Defaults to 0.3.
164
- """
165
- assert (
166
- min_het_rate > 0 and min_het_rate <= 1.0
167
- ), f"min_het_rate must be > 0 and <= 1.0, but got {min_het_rate}"
168
-
169
- assert (
170
- max_het_rate > 0 and max_het_rate <= 1.0
171
- ), f"max_het_rate must be > 0 and <= 1.0, but got {max_het_rate}"
172
-
173
- assert (
174
- min_alt_rate > 0 and min_alt_rate <= 1.0
175
- ), f"min_alt_rate must be > 0 and <= 1.0, but got {min_alt_rate}"
176
-
177
- assert (
178
- max_alt_rate > 0 and max_alt_rate <= 1.0
179
- ), f"max_alt_rate must be > 0 and <= 1.0, but got {max_alt_rate}"
180
-
181
- assert nrows > 1, f"The number of rows must be > 1, but got {nrows}"
182
-
183
- assert ncols > 1, f"The number of columns must be > 1, but got {ncols}"
184
-
185
- assert (
186
- max_missing_rate > 0 and max_missing_rate < 1.0
187
- ), f"max_missing rate must be > 0 and < 1.0, but got {max_missing_rate}"
188
-
189
- try:
190
- min_het_rate = float(min_het_rate)
191
- max_het_rate = float(max_het_rate)
192
- min_alt_rate = float(min_alt_rate)
193
- max_alt_rate = float(max_alt_rate)
194
- max_missing_rate = float(max_missing_rate)
195
- except TypeError:
196
- sys.exit(
197
- "max_missing_rate, min_het_rate, max_het_rate, min_alt_rate, and "
198
- "max_alt_rate must be of type float, or must be cast-able to type "
199
- "float"
200
- )
201
-
202
- X = np.zeros((nrows, ncols))
203
- for i in range(X.shape[1]):
204
- het_rate = int(
205
- np.ceil(
206
- np.random.choice(
207
- np.arange(min_het_rate, max_het_rate, 0.02), 1
208
- )[0]
209
- * X.shape[0]
210
- )
211
- )
212
-
213
- alt_rate = int(
214
- np.ceil(
215
- np.random.choice(
216
- np.arange(min_alt_rate, max_alt_rate, 0.02), 1
217
- )[0]
218
- * X.shape[0]
219
- )
220
- )
221
-
222
- het = np.sort(
223
- np.random.choice(
224
- np.arange(0, X.shape[0]), size=het_rate, replace=False
225
- )
226
- )
227
-
228
- alt = np.sort(
229
- np.random.choice(
230
- np.arange(0, X.shape[0]), size=alt_rate, replace=False
231
- )
232
- )
233
-
234
- sidx = alt.argsort()
235
- idx = np.searchsorted(alt, het, sorter=sidx)
236
- idx[idx == len(alt)] = 0
237
- het_unique = het[alt[sidx[idx]] != het]
238
-
239
- X[alt, i] = 2
240
- X[het_unique, i] = 1
241
-
242
- drop_rate = int(
243
- np.random.choice(np.arange(0.15, max_missing_rate, 0.02), 1)[0]
244
- * X.shape[0]
245
- )
246
-
247
- missing = np.random.choice(np.arange(0, X.shape[0]), size=drop_rate)
248
-
249
- X[missing, i] = np.nan
250
-
251
- print(
252
- f"Created a dataset of shape {X.shape} with {np.isnan(X).sum()} total missing values"
253
- )
254
-
255
- return X
256
-
257
-
258
- def unique2D_subarray(a):
259
- """Get unique subarrays for each column from numpy array."""
260
- dtype1 = np.dtype((np.void, a.dtype.itemsize * np.prod(a.shape[1:])))
261
- b = np.ascontiguousarray(a.reshape(a.shape[0], -1)).view(dtype1)
262
- return a[np.unique(b, return_index=1, axis=-1)[1]]
263
-
264
-
265
- def _remove_nonbiallelic(df, cv=5):
266
- """Remove non-biallelic sites from pandas.DataFrame.
267
-
268
- Remove sites that do not have both 0 and 2 encoded values in a column and if any of the allele counts is less than the number of cross-validation folds.
269
-
270
- Args:
271
- df (pandas.DataFrame): DataFrame with 012-encoded genotypes.
272
-
273
- Returns:
274
- pandas.DataFrame: DataFrame with non-biallelic sites dropped.
275
- """
276
- df_cp = df.copy()
277
- bad_cols = list()
278
- if pd.__version__[0] == 0:
279
- for col in df_cp.columns:
280
- if (
281
- not df_cp[col].isin([0.0]).any()
282
- or not df_cp[col].isin([2.0]).any()
283
- ):
284
- bad_cols.append(col)
285
-
286
- elif len(df_cp[df_cp[col] == 0.0]) < cv:
287
- bad_cols.append(col)
288
-
289
- elif df_cp[col].isin([1.0]).any():
290
- if len(df_cp[df_cp[col] == 1]) < cv:
291
- bad_cols.append(col)
292
-
293
- elif len(df_cp[df_cp[col] == 2.0]) < cv:
294
- bad_cols.append(col)
295
-
296
- # pandas 1.X.X
297
- else:
298
- for col in df_cp.columns:
299
- if 0.0 not in df[col].unique() and 2.0 not in df[col].unique():
300
- bad_cols.append(col)
301
-
302
- elif len(df_cp[df_cp[col] == 0.0]) < cv:
303
- bad_cols.append(col)
304
-
305
- elif 1.0 in df_cp[col].unique():
306
- if len(df_cp[df_cp[col] == 1.0]) < cv:
307
- bad_cols.append(col)
308
-
309
- elif len(df_cp[df_cp[col] == 2.0]) < cv:
310
- bad_cols.append(col)
311
-
312
- if bad_cols:
313
- df_cp.drop(bad_cols, axis=1, inplace=True)
314
-
315
- print(
316
- f"{len(bad_cols)} columns removed for being non-biallelic or "
317
- f"having genotype counts < number of cross-validation "
318
- f"folds\nSubsetting from {len(df_cp.columns)} remaining columns\n"
319
- )
320
-
321
- return df_cp
322
-
323
-
324
- def get_indices(l):
325
- """Takes a list and returns dict giving indices matching each possible
326
- list member.
327
-
328
- Example:
329
- Input [0, 1, 1, 0, 0]
330
- Output {0:[0,3,4], 1:[1,2]}
331
- """
332
- ret = dict()
333
- for member in set(l):
334
- ret[member] = list()
335
- i = 0
336
- for el in l:
337
- ret[el].append(i)
338
- i += 1
339
- return ret
340
-
341
-
342
- def all_zero(l):
343
- """Check whether list consists of all zeros.
344
-
345
- Returns TRUE if supplied list contains all zeros
346
- Returns FALSE if list contains ANY non-zero values
347
- Returns FALSE if list is empty.
348
-
349
- Args:
350
- l (List[int]): List to check.
351
-
352
- Returns:
353
- bool: True if all zeros, False otherwise.
354
- """
355
- values = set(l)
356
- if len(values) > 1:
357
- return False
358
- elif len(values) == 1 and l[0] in [0, 0.0, "0", "0.0"]:
359
- return True
360
- else:
361
- return False
362
-
363
-
364
- def weighted_draw(d, num_samples=1):
365
- choices = list(d.keys())
366
- weights = list(d.values())
367
- return np.random.choice(choices, num_samples, p=weights)
368
-
369
-
370
- def timer(func):
371
- """print the runtime of the decorated function in the format HH:MM:SS."""
372
-
373
- @functools.wraps(func)
374
- def wrapper_timer(*args, **kwargs):
375
- start_time = time.perf_counter()
376
- value = func(*args, **kwargs)
377
- end_time = time.perf_counter()
378
- run_time = end_time - start_time
379
- final_runtime = str(datetime.timedelta(seconds=run_time))
380
- print(f"Finshed {func.__name__!r} in {final_runtime}\n")
381
- return value
382
-
383
- return wrapper_timer
384
-
385
-
386
- def progressbar(it, prefix="", size=60, file=sys.stdout):
387
- count = len(it)
388
-
389
- def show(j):
390
- x = int(size * j / count)
391
- file.write(
392
- "%s[%s%s] %i/%i\r" % (prefix, "#" * x, "." * (size - x), j, count)
393
- )
394
- file.flush()
395
-
396
- show(0)
397
- for i, item in enumerate(it):
398
- yield item
399
- show(i + 1)
400
- file.write("\n")
401
- file.flush()
402
-
403
-
404
- def isnotebook():
405
- """Checks whether in Jupyter notebook.
406
-
407
- Returns:
408
- bool: True if in Jupyter notebook, False otherwise.
409
- """
410
- try:
411
- shell = get_ipython().__class__.__name__
412
- if shell == "ZMQInteractiveShell":
413
- # Jupyter notebook or qtconsole
414
- return True
415
- elif shell == "TerminalInteractiveShell":
416
- # Terminal running IPython
417
- return False
418
- else:
419
- # Other type (?)
420
- return False
421
- except NameError:
422
- # Probably standard Python interpreter
423
- return False
424
-
425
-
426
- def get_processor_name():
427
- if platform.system() == "Windows":
428
- return platform.processor()
429
- elif platform.system() == "Darwin":
430
- # os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
431
- arch = platform.processor()
432
- if arch[0] == "i":
433
- return "Intel"
434
- else:
435
- return arch
436
- elif platform.system() == "Linux":
437
- command = "cat /proc/cpuinfo"
438
- all_info = subprocess.check_output(command, shell=True).strip()
439
- all_info = all_info.decode("utf-8")
440
- for line in all_info.split("\n"):
441
- if "model name" in line:
442
- return re.sub(".*model name.*:", "", line, 1)
443
- return ""
444
-
445
-
446
- class tqdm_linux(tqdm):
447
- """Adds a dynamically updating progress bar.
448
-
449
- Decorate an iterable object, with a dynamically updating progressbar every time a value is requested.
450
- """
451
-
452
- @staticmethod
453
- def status_printer(self, file):
454
- """Manage the printing and in-place updating of a line of characters.
455
-
456
- NOTE: If the string is longer than a line, then in-place updating may not work (it will print a new line at each refresh).
457
-
458
- Overridden to work with linux HPC clusters. Replaced carriage return with linux newline in fp_write function.
459
-
460
- Args:
461
- file (str): Path of file to print status to.
462
- """
463
-
464
- fp = file
465
- fp_flush = getattr(fp, "flush", lambda: None)
466
-
467
- def fp_write(s):
468
- fp.write(_unicode(s))
469
- fp_flush()
470
-
471
- last_len = [0]
472
-
473
- def print_status(s):
474
- len_s = disp_len(s)
475
- fp_write("\n" + s + (" " * max(last_len[0] - len_s, 0)))
476
- last_len[0] = len_s
477
-
478
- return print_status
479
-
480
-
481
- class HiddenPrints:
482
- """Class to supress printing within a with statement."""
483
-
484
- def __enter__(self):
485
- self._original_stdout = sys.stdout
486
- sys.stdout = open(os.devnull, "w")
487
-
488
- def __exit__(self, exc_type, exc_val, exc_tb):
489
- sys.stdout.close()
490
- sys.stdout = self._original_stdout
491
-
492
-
493
- class StreamToLogger(object):
494
- """Fake file-like stream object that redirects writes to a logger instance."""
495
-
496
- def __init__(self, logger, log_level=logging.INFO):
497
- self.logger = logger
498
- self.log_level = log_level
499
- self.linebuf = ""
500
-
501
- def write(self, buf):
502
- temp_linebuf = self.linebuf + buf
503
- self.linebuf = ""
504
- for line in temp_linebuf.splitlines(True):
505
- # From the io.TextIOWrapper docs:
506
- # On output, if newline is None, any '\n' characters written
507
- # are translated to the system default line separator.
508
- # By default sys.stdout.write() expects '\n' newlines and then
509
- # translates them so this is still cross platform.
510
- if line[-1] == "\n":
511
- self.logger.log(self.log_level, line.rstrip())
512
- else:
513
- self.linebuf += line
514
-
515
- def flush(self):
516
- if self.linebuf != "":
517
- self.logger.log(self.log_level, self.linebuf.rstrip())
518
- self.linebuf = ""
65
+ elif return_type == "tensor":
66
+ if isinstance(X, torch.Tensor):
67
+ return X
68
+ elif isinstance(X, np.ndarray):
69
+ return torch.tensor(X, dtype=torch.float32)
70
+ elif isinstance(X, pd.DataFrame):
71
+ return torch.tensor(X.to_numpy(), dtype=torch.float32)
72
+ elif isinstance(X, list):
73
+ return torch.tensor(X, dtype=torch.float32)