maradoner 0.24.1__tar.gz

@@ -0,0 +1,100 @@
1
+ Metadata-Version: 2.4
2
+ Name: maradoner
3
+ Version: 0.24.1
4
+ Summary: Variance-adjusted estimation of motif activities.
5
+ Home-page: https://github.com/autosome-ru/maradoner
6
+ Author: Georgy Meshcheryakov
7
+ Author-email: iam@georgy.top
8
+ Classifier: Programming Language :: Python :: 3.11
9
+ Classifier: Topic :: Scientific/Engineering
10
+ Classifier: Operating System :: OS Independent
11
+ Requires-Python: >=3.11
12
+ Description-Content-Type: text/markdown
13
+ Requires-Dist: pip>=24.0
14
+ Requires-Dist: typer>=0.13
15
+ Requires-Dist: numpy>=2.1
16
+ Requires-Dist: jax<0.5
17
+ Requires-Dist: jaxlib<0.5
18
+ Requires-Dist: matplotlib>=3.5
19
+ Requires-Dist: pandas>=2.2
20
+ Requires-Dist: scipy>=1.14
21
+ Requires-Dist: statsmodels>=0.14
22
+ Requires-Dist: datatable>=1.0.0
23
+ Requires-Dist: dill>=0.3.9
24
+ Requires-Dist: rich>=12.6.0
25
+ Requires-Dist: tqdm>=4.0
26
+ Requires-Dist: scikit-learn>=1.6
27
+ Requires-Dist: tables>=3.10
28
+ Requires-Dist: sympy>=1.12
29
+ Requires-Dist: seaborn>=0.12
30
+ Dynamic: author
31
+ Dynamic: author-email
32
+ Dynamic: classifier
33
+ Dynamic: description
34
+ Dynamic: description-content-type
35
+ Dynamic: home-page
36
+ Dynamic: requires-dist
37
+ Dynamic: requires-python
38
+ Dynamic: summary
39
+
40
+
41
+ **MARADONER**
42
+
43
+ # MARADONER: Motif Activity Response Analysis Done Right
44
+
45
+ MARADONER is a tool for analyzing motif activities using promoter expression data. It provides a streamlined workflow to estimate parameters, predict deviations, and export results in a tabular form.
46
+
47
+ ## Basic Workflow
48
+
49
+
50
+ A typical MARADONER analysis session involves running commands sequentially for a given project:
51
+
52
+ 1. **`create`**: Initialize the project. This step parses your input files (promoter expression, motif loadings, optional motif expression, and sample groupings), performs initial filtering, and sets up the project's internal data structures.
53
+ ```bash
54
+ # Example: Initialize a project named 'my_project'
55
+ maradoner create my_project path/to/expression.tsv path/to/loadings.tsv --sample-groups path/to/groups.json [other options...]
56
+ ```
57
+ * Input files are typically tabular (.tsv, .csv), potentially compressed.
58
+ * You only need to provide input data files at this stage.
59
+
60
+ 2. **`fit`**: Estimate the model's variance parameters and mean motif activities using the data prepared by `create`.
61
+ ```bash
62
+ maradoner fit my_project [options...]
63
+ ```
64
+
65
+ 3. **`predict`**: Estimate the *deviations* of motif activities from their means for each sample or group, based on the parameters estimated by `fit`.
66
+ ```bash
67
+ maradoner predict my_project [options...]
68
+ ```
69
+
70
+ 4. **`export`**: Save the final results, including estimated motif activities (mean + deviations), parameter estimates, goodness-of-fit statistics, and potentially statistical test results (like ANOVA) to a specified output folder.
71
+ ```bash
72
+ maradoner export my_project path/to/output_folder [options...]
73
+ ```
74
+
75
+ ## Other Useful Commands
76
+
77
+ * **`gof`**: After `fit`, calculate Goodness-of-Fit statistics (like Fraction of Variance Explained or Correlation) to evaluate how well the model components explain the observed expression data.
78
+ ```bash
79
+ maradoner gof my_project [options...]
80
+ ```
81
+ * **`select-motifs`**: If you provided multiple loading matrices in `create` (e.g., from different databases) with unique postfixes, this command helps select the single "best" variant for each motif based on statistical criteria. The output is a list of motif names intended to be used with the `--motif-filename` option in a subsequent `create` run.
82
+ ```bash
83
+ maradoner select-motifs my_project best_motifs.txt
84
+ # Then, potentially re-run create using the generated list:
85
+ # maradoner create my_project_filtered ... --motif-filename best_motifs.txt
86
+ ```
87
+ * **`generate`**: Create a synthetic dataset with known properties for testing or demonstration purposes.
88
+ ```bash
89
+ maradoner generate path/to/synthetic_data_output [options...]
90
+ ```
91
+
92
+ ## Getting Help
93
+
94
+ Each command has various options for customization. To see the full list of commands and their detailed options, use the `--help` flag:
95
+
96
+ ```bash
97
+ maradoner --help
98
+ maradoner create --help
99
+ maradoner fit --help
100
+ # and so on for each command
@@ -0,0 +1,61 @@
1
+
2
+ **MARADONER**
3
+
4
+ # MARADONER: Motif Activity Response Analysis Done Right
5
+
6
+ MARADONER is a tool for analyzing motif activities using promoter expression data. It provides a streamlined workflow to estimate parameters, predict deviations, and export results in a tabular form.
7
+
8
+ ## Basic Workflow
9
+
10
+
11
+ A typical MARADONER analysis session involves running commands sequentially for a given project:
12
+
13
+ 1. **`create`**: Initialize the project. This step parses your input files (promoter expression, motif loadings, optional motif expression, and sample groupings), performs initial filtering, and sets up the project's internal data structures.
14
+ ```bash
15
+ # Example: Initialize a project named 'my_project'
16
+ maradoner create my_project path/to/expression.tsv path/to/loadings.tsv --sample-groups path/to/groups.json [other options...]
17
+ ```
18
+ * Input files are typically tabular (.tsv, .csv), potentially compressed.
19
+ * You only need to provide input data files at this stage.
20
+
21
+ 2. **`fit`**: Estimate the model's variance parameters and mean motif activities using the data prepared by `create`.
22
+ ```bash
23
+ maradoner fit my_project [options...]
24
+ ```
25
+
26
+ 3. **`predict`**: Estimate the *deviations* of motif activities from their means for each sample or group, based on the parameters estimated by `fit`.
27
+ ```bash
28
+ maradoner predict my_project [options...]
29
+ ```
30
+
31
+ 4. **`export`**: Save the final results, including estimated motif activities (mean + deviations), parameter estimates, goodness-of-fit statistics, and potentially statistical test results (like ANOVA) to a specified output folder.
32
+ ```bash
33
+ maradoner export my_project path/to/output_folder [options...]
34
+ ```
35
+
36
+ ## Other Useful Commands
37
+
38
+ * **`gof`**: After `fit`, calculate Goodness-of-Fit statistics (like Fraction of Variance Explained or Correlation) to evaluate how well the model components explain the observed expression data.
39
+ ```bash
40
+ maradoner gof my_project [options...]
41
+ ```
42
+ * **`select-motifs`**: If you provided multiple loading matrices in `create` (e.g., from different databases) with unique postfixes, this command helps select the single "best" variant for each motif based on statistical criteria. The output is a list of motif names intended to be used with the `--motif-filename` option in a subsequent `create` run.
43
+ ```bash
44
+ maradoner select-motifs my_project best_motifs.txt
45
+ # Then, potentially re-run create using the generated list:
46
+ # maradoner create my_project_filtered ... --motif-filename best_motifs.txt
47
+ ```
48
+ * **`generate`**: Create a synthetic dataset with known properties for testing or demonstration purposes.
49
+ ```bash
50
+ maradoner generate path/to/synthetic_data_output [options...]
51
+ ```
52
+
53
+ ## Getting Help
54
+
55
+ Each command has various options for customization. To see the full list of commands and their detailed options, use the `--help` flag:
56
+
57
+ ```bash
58
+ maradoner --help
59
+ maradoner create --help
60
+ maradoner fit --help
61
+ # and so on for each command
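The four-step workflow described in the README can also be scripted. The sketch below only assembles the command lines shown above (`create`, `fit`, `predict`, `export`) and can optionally run them with `subprocess`. The helper names `workflow_commands` and `run_workflow` and all file paths are hypothetical placeholders, and no option other than `--sample-groups` is assumed.

```python
import subprocess


def workflow_commands(project, expression, loadings, output_dir, sample_groups=None):
    """Return the README's create/fit/predict/export commands as argument lists."""
    create = ["maradoner", "create", project, expression, loadings]
    if sample_groups is not None:
        create += ["--sample-groups", sample_groups]
    return [
        create,
        ["maradoner", "fit", project],
        ["maradoner", "predict", project],
        ["maradoner", "export", project, output_dir],
    ]


def run_workflow(commands):
    """Run the commands in order; check=True stops at the first failing step."""
    for cmd in commands:
        subprocess.run(cmd, check=True)


# Hypothetical paths, mirroring the placeholders used in the README.
commands = workflow_commands(
    "my_project",
    "path/to/expression.tsv",
    "path/to/loadings.tsv",
    "path/to/output_folder",
    sample_groups="path/to/groups.json",
)
for cmd in commands:
    print(" ".join(cmd))
```

Calling `run_workflow(commands)` would execute the steps sequentially. Each later step depends on the project files written by the earlier one, which is why the sketch stops on the first error instead of continuing.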
@@ -0,0 +1,38 @@
1
+ # -*- coding: utf-8 -*-
2
+ __version__ = '0.24.1'
3
+ import importlib
4
+
5
+
6
+ __min_reqs__ = [
7
+ 'pip>=24.0',
8
+ 'typer>=0.13',
9
+ 'numpy>=2.1',
10
+ 'jax<0.5',
11
+ 'jaxlib<0.5',
12
+ 'matplotlib>=3.5',
13
+ 'pandas>=2.2',
14
+ 'scipy>=1.14',
15
+ 'statsmodels>=0.14',
16
+ 'datatable>=1.0.0' ,
17
+ 'dill>=0.3.9',
18
+ 'rich>=12.6.0',
19
+ 'tqdm>=4.0',
20
+ 'scikit-learn>=1.6',
21
+ 'tables>=3.10',
22
+ 'sympy>=1.12',
23
+ 'seaborn>=0.12'
24
+ ]
25
+
26
+ def versiontuple(v):
27
+ return tuple(map(int, (v.split("."))))
28
+
29
+ def check_packages():
30
+ for req in __min_reqs__:
31
+ try:
32
+             module, ver = req.split(' @')[0].split('>=')
33
+             ver = versiontuple(ver)
34
+             v = versiontuple(importlib.import_module(module).__version__)
35
+         except (AttributeError, ValueError, ImportError):
36
+ continue
37
+ if v < ver:
38
+ raise ImportError(f'Version of the {module} package should be at least {ver} (found: {v}).')
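As an illustration of the requirement check above, here is a minimal standard-library sketch that looks up installed versions by distribution name through `importlib.metadata`, so names such as `scikit-learn` that are not importable module names can still be resolved. It is not the package's own implementation: the helper names (`version_tuple`, `parse_requirement`, `satisfies`, `check_requirements`) are hypothetical, and only the `>=` and `<` operators that appear in `__min_reqs__` are handled.

```python
import re
from importlib import metadata

# Matches specs such as 'numpy>=2.1' or 'jax<0.5' (only these two operators).
_SPEC = re.compile(r"^\s*([A-Za-z0-9_.\-]+)\s*(>=|<)\s*([0-9][0-9.]*)\s*$")


def version_tuple(v):
    """Leading numeric components of a version string: '2.1.3rc1' -> (2, 1, 3)."""
    parts = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if not m:
            break
        parts.append(int(m.group()))
    return tuple(parts)


def parse_requirement(req):
    """Split a requirement string into (name, operator, version tuple)."""
    m = _SPEC.match(req)
    if m is None:
        raise ValueError(f"Unsupported requirement: {req!r}")
    name, op, ver = m.groups()
    return name, op, version_tuple(ver)


def satisfies(installed, op, required):
    """Compare two version tuples after padding the shorter one with zeros."""
    n = max(len(installed), len(required))
    a = installed + (0,) * (n - len(installed))
    b = required + (0,) * (n - len(required))
    return a >= b if op == ">=" else a < b


def check_requirements(reqs):
    """Return a list of problems; distributions that are not installed are skipped."""
    problems = []
    for req in reqs:
        name, op, required = parse_requirement(req)
        try:
            installed = version_tuple(metadata.version(name))
        except metadata.PackageNotFoundError:
            continue
        if not satisfies(installed, op, required):
            problems.append(f"{name}: found {installed}, need {op}{required}")
    return problems
```

Returning a list of problems rather than raising on the first mismatch is a design choice of this sketch; a caller could still raise `ImportError` if the list is non-empty.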
@@ -0,0 +1,199 @@
1
+ from .utils import logger_print, openers
2
+ from .dataset_filter import filter_lowexp
3
+ from .drist import DRIST
4
+ import multiprocessing
5
+ import scipy.stats as st
6
+ import datatable as dt
7
+ import pandas as pd
8
+ import numpy as np
9
+ import dill
10
+ import json
11
+ import os
12
+ import re
13
+
14
+
15
+ def drist_it(B: pd.DataFrame, Y: pd.DataFrame, test_chromosomes: list[str] = None,
16
+ share_function: bool = False, optimizer='jacobi'):
17
+ if test_chromosomes:
18
+ pattern = re.compile(r'chr([0-9XYM]+|\d+)')
19
+
20
+ test_chromosomes = set(test_chromosomes)
21
+ mask = [pattern.search(p).group() in test_chromosomes for i, p in enumerate(Y.index)]
22
+ mask = ~np.array(mask, dtype=bool)
23
+ else:
24
+ mask = np.ones(len(B), dtype=bool)
25
+ Y = Y.values
26
+ Y = Y - Y.mean(axis=1, keepdims=True)
27
+ Bt = B.values[mask, :]
28
+ Y = Y[mask, :]
29
+ drist = DRIST(max_iter=1000, verbose=True, share_function=share_function,
30
+ optimizer=optimizer)
31
+ B.values[mask, :] = drist.fit_transform(Bt, Y)
32
+ if not np.all(mask):
33
+ B.values[~mask, :] = drist.transform(B.values[~mask, :])
34
+
35
+ B = B - B.min()
36
+ return B
37
+
38
+
39
+ def transform_loadings(df, mode: str, zero_cutoff=1e-9, prom_inds=None, Y=None,
40
+ test_chromosomes: list[str] = None):
41
+ stds = df.std()
42
+ drop_inds = (stds == 0) | np.isnan(stds)
43
+ if prom_inds is not None:
44
+ df = df.loc[prom_inds, ~drop_inds]
45
+ else:
46
+ df = df.loc[:, ~drop_inds]
47
+ # if not mode or mode == 'none':
48
+ # df[df < zero_cutoff] = 0
49
+ # df = (df - df.min(axis=None)) / (df.max(axis=None) - df.min(axis=None))
50
+ if mode == 'ecdf':
51
+ for j in range(len(df.columns)):
52
+ v = df.iloc[:, j]
53
+ df.iloc[:, j] = st.ecdf(v).cdf.evaluate(v)
54
+ elif mode in ('esf',):
55
+ for j in range(len(df.columns)):
56
+ v = df.iloc[:, j]
57
+ v = st.ecdf(v).sf.evaluate(v)
58
+ t = np.unique(v)[1]
59
+ v[v < t] = t
60
+ df.iloc[:, j] = -np.log(v)
61
+ # if mode == 'drist':
62
+ # df = drist_it(df, Y, test_chromosomes=test_chromosomes)
63
+     elif mode and mode.startswith('drist'):
64
+ df = drist_it(df, Y, test_chromosomes=test_chromosomes,
65
+ share_function=mode.endswith('un'))
66
+ elif mode == 'none':
67
+ pass
68
+ elif mode:
69
+ raise Exception('Unknown transformation mode ' + str(mode))
70
+ return df
71
+
72
+ def create_project(project_name: str, promoter_expression_filename: str, loading_matrix_filenames: list[str],
73
+ motif_expression_filenames=None, loading_matrix_transformations=None, sample_groups=None, motif_postfixes=None,
74
+ promoter_filter_lowexp_cutoff=0.95, promoter_filter_plot_filename=None, promoter_filter_max=True,
75
+ motif_names_filename=None, n_jobs:float = 0.5, compression='raw', dump=True, verbose=True):
76
+ if not os.path.isfile(promoter_expression_filename):
77
+ raise FileNotFoundError(f'Promoter expression file {promoter_expression_filename} not found.')
78
+ if type(loading_matrix_filenames) is str:
79
+ loading_matrix_filenames = [loading_matrix_filenames]
80
+ for mx_name in loading_matrix_filenames:
81
+ if not os.path.isfile(mx_name):
82
+ raise FileNotFoundError(f'Loading matrix file {mx_name} not found.')
83
+ if motif_expression_filenames:
84
+ if type(motif_expression_filenames) is str:
85
+ motif_expression_filenames = [motif_expression_filenames]
86
+ for exp_name in motif_expression_filenames:
87
+ if not os.path.isfile(exp_name):
88
+                 raise FileNotFoundError(f'Motif expression file {exp_name} not found.')
89
+ if type(sample_groups) is str:
90
+ with open(sample_groups, 'r') as f:
91
+ if sample_groups.endswith('.json'):
92
+ sample_groups = json.load(f)
93
+ else:
94
+ sample_groups = dict()
95
+ for line in f:
96
+ items = line.split()
97
+ sample_groups[items[0]] = items[1:]
98
+ if motif_names_filename is not None:
99
+ with open(motif_names_filename, 'r') as f:
100
+ motif_names = list()
101
+ for line in f:
102
+ line = line.strip().split()
103
+ for item in line:
104
+ if item:
105
+ motif_names.append(item)
106
+ else:
107
+ motif_names = None
108
+ cpu_count = multiprocessing.cpu_count()
109
+ if n_jobs < 1 and n_jobs > 0:
110
+         n_jobs = max(1, int(n_jobs * cpu_count))
111
+ elif n_jobs <= 0:
112
+ n_jobs = cpu_count
113
+ logger_print('Reading dataset...', verbose)
114
+ promoter_expression = dt.fread(promoter_expression_filename, nthreads=n_jobs).to_pandas()
115
+ promoter_expression = promoter_expression.set_index(promoter_expression.columns[0])
116
+
117
+ if sample_groups:
118
+ cols = set()
119
+ for vals in sample_groups.values():
120
+ cols.update(vals)
121
+ cols = list(cols)
122
+ promoter_expression = promoter_expression[cols]
123
+
124
+ proms = promoter_expression.index
125
+ sample_names = promoter_expression.columns
126
+ loading_matrices = [dt.fread(f, nthreads=n_jobs).to_pandas() for f in loading_matrix_filenames]
127
+ loading_matrices = [df.set_index(df.columns[0]).loc[proms] for df in loading_matrices]
128
+ if loading_matrix_transformations is None or type(loading_matrix_transformations) is str:
129
+ loading_matrix_transformations = [loading_matrix_transformations] * len(loading_matrices)
130
+ else:
131
+ if len(loading_matrix_transformations) == 1:
132
+ loading_matrix_transformations = [loading_matrix_transformations[0]] * len(loading_matrices)
133
+ elif len(loading_matrix_transformations) != len(loading_matrices):
134
+ raise Exception(f'Total number of loading matrices is {len(loading_matrices)}, but the number of transformations is '
135
+ f'{len(loading_matrix_transformations)}.')
136
+
137
+ logger_print('Filtering promoters of low expression...', verbose)
138
+ inds, weights = filter_lowexp(promoter_expression, cutoff=promoter_filter_lowexp_cutoff, fit_plot_filename=promoter_filter_plot_filename,
139
+ max_mode=promoter_filter_max)
140
+ promoter_expression = promoter_expression.loc[inds]
141
+ proms = promoter_expression.index
142
+ test_chromosomes = list() # ['chr2', 'chr15']
143
+ loading_matrices = [transform_loadings(df, mode, prom_inds=inds, test_chromosomes=test_chromosomes,
144
+ Y=promoter_expression) for df, mode in zip(loading_matrices, loading_matrix_transformations)]
145
+ if motif_postfixes is not None:
146
+ for mx, postfix in zip(loading_matrices, motif_postfixes):
147
+ mx.columns = [f'{c}_{postfix}' for c in mx.columns]
148
+ if motif_expression_filenames:
149
+ motif_expression = [dt.fread(f, nthreads=n_jobs).to_pandas() for f in motif_expression_filenames]
150
+ motif_expression = [df.set_index(df.columns[0]) for df in motif_expression]
151
+ if motif_postfixes is not None:
152
+ for mx, postfix in zip(motif_expression, motif_postfixes):
153
+ mx.index = [f'{c}_{postfix}' for c in mx.index]
154
+ if sample_groups:
155
+ if len(set(motif_expression[0].columns) & set(sample_groups)) == len(sample_groups):
156
+ for i in range(len(motif_expression)):
157
+ mx = motif_expression[i]
158
+ for group, cols in sample_groups.items():
159
+ for col in cols:
160
+ mx[col] = mx[group]
161
+ mx = mx.drop(sorted(sample_groups), axis=1)
162
+ motif_expression = [df.loc[mx.columns, sample_names] for df, mx in zip(motif_expression, loading_matrices)]
163
+ motif_expression = pd.concat(motif_expression, axis=0)
164
+ else:
165
+ motif_expression = None
166
+ loading_matrices = pd.concat(loading_matrices, axis=1)
167
+ if motif_names is not None:
168
+ motif_names = list(set(motif_names) & set(loading_matrices.columns))
169
+ loading_matrices = loading_matrices[motif_names]
170
+ proms = list(promoter_expression.index)
171
+ sample_names = list(promoter_expression.columns)
172
+ motif_names = list(loading_matrices.columns)
173
+ loading_matrices = loading_matrices.values
174
+ promoter_expression = promoter_expression.values
175
+ if motif_expression is not None:
176
+ motif_expression = motif_expression.values
177
+ if not sample_groups:
178
+ sample_groups = {n: [i] for i, n in enumerate(sample_names)}
179
+ else:
180
+ sample_groups = {n: sorted([sample_names.index(i) for i in inds]) for n, inds in sample_groups.items()}
181
+ res = {'expression': promoter_expression,
182
+ 'loadings': loading_matrices,
183
+ 'motif_expression': motif_expression,
184
+ 'motif_postfixes': motif_postfixes,
185
+ 'promoter_names': proms,
186
+ 'sample_names': sample_names,
187
+ 'motif_names': motif_names,
188
+ 'weights': weights,
189
+ 'groups': sample_groups}
190
+ if dump:
191
+ folder = os.path.split(project_name)[0]
192
+ name = os.path.split(project_name)[-1]
193
+ for file in os.listdir(folder if folder else None):
194
+ if file.startswith(f'{name}.') and file.endswith(tuple(openers.keys())):
195
+ os.remove(os.path.join(folder, file))
196
+ logger_print('Saving project...', verbose)
197
+ with openers[compression](f'{project_name}.init.{compression}', 'wb') as f:
198
+ dill.dump(res, f)
199
+ return res
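The `ecdf` and `esf` modes of `transform_loadings` above replace each loading column by a rank-based score. The sketch below reproduces the same arithmetic with NumPy only, as a stand-in for the `scipy.stats.ecdf` calls used in the source. The function names `ecdf_values` and `esf_transform` are hypothetical, and the input vector is made-up example data.

```python
import numpy as np


def ecdf_values(v):
    """Empirical CDF evaluated at the sample itself: fraction of entries <= v[i]."""
    v = np.asarray(v, dtype=float)
    s = np.sort(v)
    return np.searchsorted(s, v, side="right") / len(v)


def esf_transform(v):
    """Negative log of the empirical survival function, floored away from zero.

    The largest entry has survival 0, so values below the smallest non-zero
    survival level are raised to that level before taking the logarithm.
    Assumes the column has at least two distinct values.
    """
    sf = 1.0 - ecdf_values(v)
    floor = np.unique(sf)[1]  # second-smallest distinct survival value
    return -np.log(np.maximum(sf, floor))


# Made-up loading column for illustration.
column = np.array([0.1, 0.4, 0.4, 0.9])
print(ecdf_values(column))
print(esf_transform(column))
```

Both transforms are monotone in the original loadings, so the ranking of promoters within a column is preserved while the scale becomes comparable across columns.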
@@ -0,0 +1,152 @@
1
+ import jax.numpy as jnp
2
+ from jax import jit, grad
3
+ from jax.scipy.stats import norm
4
+ from jax.scipy.special import logsumexp, logit, expit
5
+ import pandas as pd
6
+ import numpy as np
7
+ from scipy.optimize import minimize
8
+ from functools import partial
9
+ from sklearn.mixture import GaussianMixture
10
+
11
+ def compute_leftmost_probability(Y):
12
+ Y = Y.reshape(-1, 1)
13
+ gmm = GaussianMixture(n_components=2, random_state=0)
14
+ gmm.fit(Y)
15
+
16
+ means = gmm.means_.flatten()
17
+ leftmost_component_index = np.argmin(means)
18
+ probas = gmm.predict_proba(Y)
19
+ leftmost_probs = probas[:, leftmost_component_index]
20
+
21
+ return leftmost_probs, gmm
22
+
23
+ def normax_logpdf(x: jnp.ndarray, mu: float, sigma: float, n: int):
24
+ x = (x - mu) / sigma
25
+ return jnp.log(n) - jnp.log(sigma) + norm.logpdf(x) + (n - 1) * norm.logcdf(x)
26
+
27
+ def logmixture(x: jnp.ndarray, mus: jnp.ndarray, sigmas: jnp.ndarray, w: float, n: int):
28
+ logpdf1 = normax_logpdf(x, mus[0], sigmas[0], n)
29
+ logpdf2 = normax_logpdf(x, mus[1], sigmas[1], n)
30
+ w = jnp.array([w, 1 - w]).reshape(-1,1)
31
+ logpdf = jnp.array([logpdf1, logpdf2])
32
+ return logsumexp(logpdf, b=w, axis=0)
33
+
34
+
35
+ def transform(params, forward=True):
36
+ mu = params[:2]
37
+ sigma = params[2:4]
38
+ w = params[-1:]
39
+ if forward:
40
+ sigma = sigma ** 2
41
+ w = expit(w)
42
+ else:
43
+ sigma = sigma ** 0.5
44
+ w = logit(w)
45
+ return jnp.concatenate([mu, sigma, w])
46
+
47
+ def loglik(params: jnp.ndarray, x: jnp.ndarray, n: int):
48
+ params = transform(params)
49
+ mu = params[:2]
50
+ sigma = params[2:4]
51
+ w = params[-1]
52
+ return -logmixture(x, mu, sigma, w, n).sum()
53
+
54
+ def filter_lowexp(expression: pd.DataFrame, cutoff=0.95, component_limit=0.6, max_mode=True,
55
+ fit_plot_filename=None, plot_dpi=200):
56
+ expression = (expression - expression.mean()) / expression.std()
57
+ if not max_mode:
58
+ expression = expression.mean(axis=1).values
59
+ probs, gmm = compute_leftmost_probability(expression)
60
+ inds = probs < (1-cutoff)
61
+ if fit_plot_filename:
62
+ import matplotlib.pyplot as plt
63
+ from matplotlib.collections import LineCollection
64
+ import seaborn as sns
65
+             x = np.sort(expression)  # sorted x so the density curve is drawn left to right
66
+             pdf = np.exp(gmm.score_samples(x[:, None]))  # evaluate the density at the sorted points
67
+             points = np.array([x, pdf]).T.reshape(-1, 1, 2)
68
+             segments = np.concatenate([points[:-1], points[1:]], axis=1)
69
+             plt.figure(dpi=plot_dpi)
70
+             sns.histplot(expression, stat='density', color='grey')
71
+             lc = LineCollection(segments, cmap='winter')
72
+             lc.set_array(probs[np.argsort(expression)])  # colors follow the same sorted order as x
73
+ lc.set_linewidth(3)
74
+ line = plt.gca().add_collection(lc)
75
+ plt.colorbar(line)
76
+ plt.xlabel('Standardized expression')
77
+ plt.tight_layout()
78
+ plt.savefig(fit_plot_filename)
79
+ return inds, probs
80
+
81
+ expression_max = expression.max(axis=1).values
82
+
83
+ mu = [-1.0, 0.0]
84
+ sigmas = [1.0, 1.0]
85
+ w = [0.5]
86
+ x0 = jnp.array(mu + sigmas + w)
87
+ x0 = transform(x0, False)
88
+ fun = jit(partial(loglik, x=expression_max, n=expression.shape[1]))
89
+ jac = jit(grad(fun))
90
+ res = minimize(fun, x0, jac=jac)
91
+
92
+ params = transform(res.x)
93
+ mu = params[:2]
94
+ sigma = params[2:4]
95
+ w = params[-1]
96
+
97
+ mode1 = minimize(lambda x: -normax_logpdf(x, mu[0], sigma[0], n=expression.shape[1]), x0=[0.0]).x
98
+ mode2 = minimize(lambda x: -normax_logpdf(x, mu[1], sigma[1], n=expression.shape[1]), x0=[0.0]).x
99
+ if mode1 > mode2:
100
+ mu = mu[::-1]
101
+ sigma = sigma[::-1]
102
+ w = 1 - w
103
+
104
+ inds = np.argsort(expression_max)
105
+ inds_inv = np.empty_like(inds, dtype=int)
106
+ inds_inv[inds] = np.arange(len(inds))
107
+ x = expression_max[inds]
108
+ logpdf1 = normax_logpdf(x, mu[0], sigma[0], n=expression.shape[1])
109
+ logpdf2 = normax_logpdf(x, mu[1], sigma[1], n=expression.shape[1])
110
+ pdf1 = jnp.exp(logpdf1)
111
+ pdf2 = jnp.exp(logpdf2)
112
+ ws = np.array(pdf1 / ((w * pdf1 + (1-w)*pdf2)) * w)
113
+
114
+ if float(w) > component_limit:
115
+ ws[:] = 1.0
116
+ else:
117
+ ws = 1 - ws
118
+ if x[ws >= 0.5].mean() < x[ws < 0.5].mean():
119
+ ws = 1 - ws
120
+ j = np.argmax(ws)
121
+ l = np.argmin(ws)
122
+ ws[j:] = 1.0
123
+ ws[:l] = 0.0
124
+
125
+ k = 0
126
+ for k in range(len(ws)):
127
+ if ws[k] >= 1.0-cutoff:
128
+ break
129
+ if fit_plot_filename:
130
+ import matplotlib.pyplot as plt
131
+ from matplotlib.collections import LineCollection
132
+ import seaborn as sns
133
+ pdf = jnp.exp(logmixture(x, mu, sigma, w, n=expression.shape[1]))
134
+ points = np.array([x, pdf]).T.reshape(-1, 1, 2)
135
+ segments = np.concatenate([points[:-1], points[1:]], axis=1)
136
+ plt.figure(dpi=plot_dpi, )
137
+ sns.histplot(expression_max, stat='density', color='grey')
138
+ lc = LineCollection(segments, cmap='winter')
139
+ lc.set_array(ws)
140
+ lc.set_linewidth(3)
141
+ line = plt.gca().add_collection(lc)
142
+ plt.colorbar(line)
143
+ plt.xlabel('Standardized expression')
144
+ plt.tight_layout()
145
+ plt.savefig(fit_plot_filename)
146
+ ws = ws[inds_inv]
147
+ inds = np.ones(len(expression), dtype=bool)
148
+ inds[:k] = False
149
+ # print(inds)
150
+ # inds[:] = 1
151
+ inds = inds[inds_inv]
152
+ return inds, ws
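The `normax_logpdf` function above is the log-density of the maximum of `n` independent normal variables. With $z = (x - \mu)/\sigma$, the density is $f(x) = \frac{n}{\sigma}\,\varphi(z)\,\Phi(z)^{n-1}$, so $\log f(x) = \log n - \log\sigma + \log\varphi(z) + (n-1)\log\Phi(z)$. The sketch below re-derives this with the standard library only (the source uses `jax.scipy.stats.norm`) and checks numerically that the density integrates to one. The helper `integrate_density` is hypothetical, and the plain `erfc`-based log-CDF is not robust in the far left tail.

```python
import math


def normax_logpdf(x, mu, sigma, n):
    """Log-density of the maximum of n i.i.d. N(mu, sigma^2) variables."""
    z = (x - mu) / sigma
    log_phi = -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)    # log of standard normal pdf
    log_Phi = math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))  # log of standard normal cdf
    return math.log(n) - math.log(sigma) + log_phi + (n - 1) * log_Phi


def integrate_density(mu, sigma, n, half_width=10.0, steps=20000):
    """Trapezoidal integral of the density over mu +/- half_width * sigma."""
    a = mu - half_width * sigma
    h = 2.0 * half_width * sigma / steps
    total = 0.0
    for i in range(steps + 1):
        weight = 0.5 if i in (0, steps) else 1.0
        total += weight * math.exp(normax_logpdf(a + i * h, mu, sigma, n))
    return total * h


# For n = 1 the formula reduces to the ordinary normal log-density.
print(normax_logpdf(0.3, 0.0, 1.0, 1))
print(integrate_density(0.0, 1.0, 5))
```

In `filter_lowexp` this density is used for each of the two mixture components fitted to the per-promoter maximum of standardized expression, with `n` equal to the number of samples.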