mxlpy 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mxlpy/plot.py CHANGED
@@ -23,9 +23,12 @@ import contextlib
23
23
  from cycler import cycler
24
24
 
25
25
  __all__ = [
26
+ "Color",
26
27
  "FigAx",
27
28
  "FigAxs",
28
29
  "Linestyle",
30
+ "RGB",
31
+ "RGBA",
29
32
  "add_grid",
30
33
  "bars",
31
34
  "context",
@@ -71,7 +74,7 @@ from mpl_toolkits.mplot3d import Axes3D
71
74
  from mxlpy.label_map import LabelMapper
72
75
 
73
76
  if TYPE_CHECKING:
74
- from collections.abc import Generator
77
+ from collections.abc import Generator, Iterable
75
78
 
76
79
  from matplotlib.collections import QuadMesh
77
80
 
@@ -89,6 +92,11 @@ type Linestyle = Literal[
89
92
  "dashdot",
90
93
  ]
91
94
 
95
+
96
+ type RGB = tuple[float, float, float]
97
+ type RGBA = tuple[float, float, float, float]
98
+ type Color = str | RGB | RGBA
99
+
92
100
  ##########################################################################
93
101
  # Helpers
94
102
  ##########################################################################
@@ -133,6 +141,13 @@ def _get_norm(vmin: float, vmax: float) -> Normalize:
133
141
  return norm
134
142
 
135
143
 
144
+ def _norm(df: pd.DataFrame) -> Normalize:
145
+ """Get a normalization object for the given data."""
146
+ vmin = df.min().min()
147
+ vmax = df.max().max()
148
+ return _get_norm(vmin, vmax)
149
+
150
+
136
151
  def _norm_with_zero_center(df: pd.DataFrame) -> Normalize:
137
152
  """Get a normalization object with zero-centered values for the given data."""
138
153
  v = max(abs(df.min().min()), abs(df.max().max()))
@@ -166,7 +181,7 @@ def _split_large_groups[T](groups: list[list[T]], max_size: int) -> list[list[T]
166
181
  ) # type: ignore
167
182
 
168
183
 
169
- def _default_color(ax: Axes, color: str | None) -> str:
184
+ def _default_color(ax: Axes, color: Color | None) -> Color:
170
185
  """Get a default color for the given axis."""
171
186
  return f"C{len(ax.lines)}" if color is None else color
172
187
 
@@ -291,16 +306,16 @@ def reset_prop_cycle(ax: Axes) -> None:
291
306
  @contextlib.contextmanager
292
307
  def context(
293
308
  colors: list[str] | None = None,
294
- line_width: float | None = None,
295
- line_style: Linestyle | None = None,
309
+ linewidth: float | None = None,
310
+ linestyle: Linestyle | None = None,
296
311
  rc: dict[str, Any] | None = None,
297
312
  ) -> Generator[None, None, None]:
298
313
  """Context manager to set the defaults for plots.
299
314
 
300
315
  Args:
301
316
  colors: colors to use for the plot.
302
- line_width: line width to use for the plot.
303
- line_style: line style to use for the plot.
317
+ linewidth: line width to use for the plot.
318
+ linestyle: line style to use for the plot.
304
319
  rc: additional keyword arguments to pass to the rc context.
305
320
 
306
321
  """
@@ -309,11 +324,11 @@ def context(
309
324
  if colors is not None:
310
325
  rc["axes.prop_cycle"] = cycler(color=colors)
311
326
 
312
- if line_width is not None:
313
- rc["lines.linewidth"] = line_width
327
+ if linewidth is not None:
328
+ rc["lines.linewidth"] = linewidth
314
329
 
315
- if line_style is not None:
316
- rc["lines.linestyle"] = line_style
330
+ if linestyle is not None:
331
+ rc["lines.linestyle"] = linestyle
317
332
 
318
333
  with plt.rc_context(rc):
319
334
  yield
@@ -477,9 +492,12 @@ def lines(
477
492
  x: pd.DataFrame | pd.Series,
478
493
  *,
479
494
  ax: Axes | None = None,
480
- grid: bool = True,
481
495
  alpha: float = 1.0,
496
+ color: Color | list[Color] | None = None,
497
+ grid: bool = True,
482
498
  legend: bool = True,
499
+ linewidth: float | None = None,
500
+ linestyle: Linestyle | None = None,
483
501
  ) -> FigAx:
484
502
  """Plot multiple lines on the same axis."""
485
503
  fig, ax = _default_fig_ax(ax=ax, grid=grid)
@@ -487,6 +505,9 @@ def lines(
487
505
  x.index,
488
506
  x,
489
507
  alpha=alpha,
508
+ linewidth=linewidth,
509
+ linestyle=linestyle,
510
+ color=color,
490
511
  )
491
512
  _default_labels(ax, xlabel=x.index.name, ylabel=None)
492
513
  if legend:
@@ -497,6 +518,12 @@ def lines(
497
518
  return fig, ax
498
519
 
499
520
 
521
+ def _repeat_color_if_necessary(
522
+ color: list[list[Color]] | Color | None, n: int
523
+ ) -> Iterable[list[Color] | Color | None]:
524
+ return [color] * n if not isinstance(color, list) else color
525
+
526
+
500
527
  def lines_grouped(
501
528
  groups: list[pd.DataFrame] | list[pd.Series],
502
529
  *,
@@ -506,6 +533,9 @@ def lines_grouped(
506
533
  sharex: bool = True,
507
534
  sharey: bool = False,
508
535
  grid: bool = True,
536
+ color: Color | list[list[Color]] | None = None,
537
+ linewidth: float | None = None,
538
+ linestyle: Linestyle | None = None,
509
539
  ) -> FigAxs:
510
540
  """Plot multiple groups of lines on separate axes."""
511
541
  fig, axs = grid_layout(
@@ -518,8 +548,20 @@ def lines_grouped(
518
548
  grid=grid,
519
549
  )
520
550
 
521
- for group, ax in zip(groups, axs, strict=False):
522
- lines(group, ax=ax, grid=grid)
551
+ for group, ax, color_ in zip(
552
+ groups,
553
+ axs,
554
+ _repeat_color_if_necessary(color, n=len(groups)),
555
+ strict=False,
556
+ ):
557
+ lines(
558
+ group,
559
+ ax=ax,
560
+ grid=grid,
561
+ color=color_,
562
+ linewidth=linewidth,
563
+ linestyle=linestyle,
564
+ )
523
565
 
524
566
  for i in range(len(groups), len(axs)):
525
567
  axs[i].set_visible(False)
@@ -535,6 +577,9 @@ def line_autogrouped(
535
577
  row_height: float = 3,
536
578
  max_group_size: int = 6,
537
579
  grid: bool = True,
580
+ color: Color | list[list[Color]] | None = None,
581
+ linewidth: float | None = None,
582
+ linestyle: Linestyle | None = None,
538
583
  ) -> FigAxs:
539
584
  """Plot a series or dataframe with lines grouped by order of magnitude."""
540
585
  group_names = _split_large_groups(
@@ -556,6 +601,9 @@ def line_autogrouped(
556
601
  col_width=col_width,
557
602
  row_height=row_height,
558
603
  grid=grid,
604
+ color=color,
605
+ linestyle=linestyle,
606
+ linewidth=linewidth,
559
607
  )
560
608
 
561
609
 
@@ -564,7 +612,9 @@ def line_mean_std(
564
612
  *,
565
613
  label: str | None = None,
566
614
  ax: Axes | None = None,
567
- color: str | None = None,
615
+ color: Color | None = None,
616
+ linewidth: float | None = None,
617
+ linestyle: Linestyle | None = None,
568
618
  alpha: float = 0.2,
569
619
  grid: bool = True,
570
620
  ) -> FigAx:
@@ -579,6 +629,8 @@ def line_mean_std(
579
629
  mean,
580
630
  color=color,
581
631
  label=label,
632
+ linewidth=linewidth,
633
+ linestyle=linestyle,
582
634
  )
583
635
  ax.fill_between(
584
636
  df.index,
@@ -598,6 +650,9 @@ def lines_mean_std_from_2d_idx(
598
650
  ax: Axes | None = None,
599
651
  alpha: float = 0.2,
600
652
  grid: bool = True,
653
+ color: Color | None = None,
654
+ linewidth: float | None = None,
655
+ linestyle: Linestyle | None = None,
601
656
  ) -> FigAx:
602
657
  """Plot the mean and standard deviation of a 2D indexed dataframe."""
603
658
  if len(cast(pd.MultiIndex, df.index).levels) != 2: # noqa: PLR2004
@@ -612,25 +667,32 @@ def lines_mean_std_from_2d_idx(
612
667
  label=name,
613
668
  alpha=alpha,
614
669
  ax=ax,
670
+ color=color,
671
+ linestyle=linestyle,
672
+ linewidth=linewidth,
615
673
  )
616
674
  ax.legend()
617
675
  return fig, ax
618
676
 
619
677
 
620
- def heatmap(
621
- df: pd.DataFrame,
678
+ def _create_heatmap(
622
679
  *,
680
+ ax: Axes | None,
681
+ df: pd.DataFrame,
682
+ title: str | None = None,
683
+ xlabel: str,
684
+ ylabel: str,
685
+ xticklabels: list[str],
686
+ yticklabels: list[str],
687
+ norm: Normalize,
623
688
  annotate: bool = False,
624
689
  colorbar: bool = True,
625
690
  invert_yaxis: bool = True,
626
691
  cmap: str = "RdBu_r",
627
- norm: Normalize | None = None,
628
- ax: Axes | None = None,
629
692
  cax: Axes | None = None,
630
693
  sci_annotation_bounds: tuple[float, float] = (0.01, 100),
631
694
  annotation_style: str = "2g",
632
695
  ) -> tuple[Figure, Axes, QuadMesh]:
633
- """Plot a heatmap of the given data."""
634
696
  fig, ax = _default_fig_ax(
635
697
  ax=ax,
636
698
  figsize=(
@@ -639,17 +701,23 @@ def heatmap(
639
701
  ),
640
702
  grid=False,
641
703
  )
642
- if norm is None:
643
- norm = _norm_with_zero_center(df)
644
704
 
705
+ # Note: pcolormesh swaps index/columns
645
706
  hm = ax.pcolormesh(df, norm=norm, cmap=cmap)
707
+
708
+ if xlabel is not None:
709
+ ax.set_xlabel(xlabel)
710
+ if ylabel is not None:
711
+ ax.set_ylabel(ylabel)
712
+ if title is not None:
713
+ ax.set_title(title)
646
714
  ax.set_xticks(
647
715
  np.arange(0, len(df.columns), 1) + 0.5,
648
- labels=df.columns,
716
+ labels=xticklabels,
649
717
  )
650
718
  ax.set_yticks(
651
719
  np.arange(0, len(df.index), 1) + 0.5,
652
- labels=df.index,
720
+ labels=yticklabels,
653
721
  )
654
722
 
655
723
  if annotate:
@@ -666,39 +734,77 @@ def heatmap(
666
734
  return fig, ax, hm
667
735
 
668
736
 
737
+ def heatmap(
738
+ df: pd.DataFrame,
739
+ *,
740
+ ax: Axes | None = None,
741
+ title: str | None = None,
742
+ annotate: bool = False,
743
+ colorbar: bool = True,
744
+ invert_yaxis: bool = True,
745
+ cmap: str = "RdBu_r",
746
+ norm: Normalize | None = None,
747
+ cax: Axes | None = None,
748
+ sci_annotation_bounds: tuple[float, float] = (0.01, 100),
749
+ annotation_style: str = "2g",
750
+ ) -> tuple[Figure, Axes, QuadMesh]:
751
+ """Plot a heatmap of the given data."""
752
+ return _create_heatmap(
753
+ ax=ax,
754
+ df=df,
755
+ title=title,
756
+ xlabel=df.index.name,
757
+ ylabel=df.columns.name,
758
+ xticklabels=cast(list, df.columns),
759
+ yticklabels=cast(list, df.index),
760
+ annotate=annotate,
761
+ colorbar=colorbar,
762
+ invert_yaxis=invert_yaxis,
763
+ cmap=cmap,
764
+ norm=_norm_with_zero_center(df) if norm is None else norm,
765
+ cax=cax,
766
+ sci_annotation_bounds=sci_annotation_bounds,
767
+ annotation_style=annotation_style,
768
+ )
769
+
770
+
669
771
  def heatmap_from_2d_idx(
670
772
  df: pd.DataFrame,
671
773
  variable: str,
774
+ *,
672
775
  ax: Axes | None = None,
673
- ) -> FigAx:
776
+ annotate: bool = False,
777
+ colorbar: bool = True,
778
+ invert_yaxis: bool = False,
779
+ cmap: str = "viridis",
780
+ norm: Normalize | None = None,
781
+ cax: Axes | None = None,
782
+ sci_annotation_bounds: tuple[float, float] = (0.01, 100),
783
+ annotation_style: str = "2g",
784
+ ) -> tuple[Figure, Axes, QuadMesh]:
674
785
  """Plot a heatmap of a 2D indexed dataframe."""
675
786
  if len(cast(pd.MultiIndex, df.index).levels) != 2: # noqa: PLR2004
676
787
  msg = "MultiIndex must have exactly two levels"
677
788
  raise ValueError(msg)
678
-
679
- fig, ax = _default_fig_ax(ax=ax, grid=False)
680
- df2d = df[variable].unstack()
681
-
682
- ax.set_title(variable)
683
- # Note: pcolormesh swaps index/columns
684
- hm = ax.pcolormesh(df2d.T)
685
- ax.set_xlabel(df2d.index.name)
686
- ax.set_ylabel(df2d.columns.name)
687
- ax.set_xticks(
688
- np.arange(0, len(df2d.index), 1) + 0.5,
689
- labels=[f"{i:.2f}" for i in df2d.index],
690
- )
691
- ax.set_yticks(
692
- np.arange(0, len(df2d.columns), 1) + 0.5,
693
- labels=[f"{i:.2f}" for i in df2d.columns],
789
+ df2d = df[variable].unstack().T
790
+
791
+ return _create_heatmap(
792
+ df=df2d,
793
+ xlabel=df2d.index.name,
794
+ ylabel=df2d.columns.name,
795
+ xticklabels=[f"{i:.2f}" for i in df2d.columns],
796
+ yticklabels=[f"{i:.2f}" for i in df2d.index],
797
+ ax=ax,
798
+ cax=cax,
799
+ annotate=annotate,
800
+ colorbar=colorbar,
801
+ invert_yaxis=invert_yaxis,
802
+ cmap=cmap,
803
+ norm=_norm(df2d) if norm is None else norm,
804
+ sci_annotation_bounds=sci_annotation_bounds,
805
+ annotation_style=annotation_style,
694
806
  )
695
807
 
696
- rotate_xlabels(ax, rotation=45, ha="right")
697
-
698
- # Add colorbar
699
- fig.colorbar(hm, ax=ax)
700
- return fig, ax
701
-
702
808
 
703
809
  def heatmaps_from_2d_idx(
704
810
  df: pd.DataFrame,
@@ -708,6 +814,13 @@ def heatmaps_from_2d_idx(
708
814
  row_height_factor: float = 0.6,
709
815
  sharex: bool = True,
710
816
  sharey: bool = False,
817
+ annotate: bool = False,
818
+ colorbar: bool = True,
819
+ invert_yaxis: bool = False,
820
+ cmap: str = "viridis",
821
+ norm: Normalize | None = None,
822
+ sci_annotation_bounds: tuple[float, float] = (0.01, 100),
823
+ annotation_style: str = "2g",
711
824
  ) -> FigAxs:
712
825
  """Plot multiple heatmaps of a 2D indexed dataframe."""
713
826
  idx = cast(pd.MultiIndex, df.index)
@@ -722,7 +835,18 @@ def heatmaps_from_2d_idx(
722
835
  grid=False,
723
836
  )
724
837
  for ax, var in zip(axs, df.columns, strict=False):
725
- heatmap_from_2d_idx(df, var, ax=ax)
838
+ heatmap_from_2d_idx(
839
+ df,
840
+ var,
841
+ ax=ax,
842
+ annotate=annotate,
843
+ colorbar=colorbar,
844
+ invert_yaxis=invert_yaxis,
845
+ cmap=cmap,
846
+ norm=norm,
847
+ sci_annotation_bounds=sci_annotation_bounds,
848
+ annotation_style=annotation_style,
849
+ )
726
850
  return fig, axs
727
851
 
728
852
 
@@ -892,6 +1016,9 @@ def relative_label_distribution(
892
1016
  row_height: float = 3,
893
1017
  sharey: bool = False,
894
1018
  grid: bool = True,
1019
+ color: Color | None = None,
1020
+ linewidth: float | None = None,
1021
+ linestyle: Linestyle | None = None,
895
1022
  ) -> FigAxs:
896
1023
  """Plot the relative distribution of labels in the given data."""
897
1024
  variables = list(mapper.label_variables) if subset is None else subset
@@ -903,20 +1030,37 @@ def relative_label_distribution(
903
1030
  sharey=sharey,
904
1031
  grid=grid,
905
1032
  )
1033
+ # FIXME: rewrite as building a dict of dataframes
1034
+ # and passing it to lines_grouped
906
1035
  if isinstance(mapper, LabelMapper):
907
1036
  for ax, name in zip(axs, variables, strict=False):
908
1037
  for i in range(mapper.label_variables[name]):
909
1038
  isos = mapper.get_isotopomers_of_at_position(name, i)
910
1039
  labels = cast(pd.DataFrame, concs.loc[:, isos])
911
1040
  total = concs.loc[:, f"{name}__total"]
912
- ax.plot(labels.index, (labels.sum(axis=1) / total), label=f"C{i + 1}")
1041
+ ax.plot(
1042
+ labels.index,
1043
+ (labels.sum(axis=1) / total),
1044
+ label=f"C{i + 1}",
1045
+ linewidth=linewidth,
1046
+ linestyle=linestyle,
1047
+ color=color,
1048
+ )
913
1049
  ax.set_title(name)
914
1050
  ax.legend()
915
1051
  else:
916
1052
  for ax, (name, isos) in zip(
917
- axs, mapper.get_isotopomers(variables).items(), strict=False
1053
+ axs,
1054
+ mapper.get_isotopomers(variables).items(),
1055
+ strict=False,
918
1056
  ):
919
- ax.plot(concs.index, concs.loc[:, isos])
1057
+ ax.plot(
1058
+ concs.index,
1059
+ concs.loc[:, isos],
1060
+ linewidth=linewidth,
1061
+ linestyle=linestyle,
1062
+ color=color,
1063
+ )
920
1064
  ax.set_title(name)
921
1065
  ax.legend([f"C{i + 1}" for i in range(len(isos))])
922
1066
 
@@ -5,12 +5,11 @@ from typing import Self
5
5
  import numpy as np
6
6
  import pandas as pd
7
7
  import torch
8
- import tqdm
9
8
  from torch import nn
10
9
  from torch.optim.adam import Adam
11
10
  from torch.optim.optimizer import ParamsT
12
11
 
13
- from mxlpy.nn._torch import MLP, DefaultDevice
12
+ from mxlpy.nn._torch import MLP, DefaultDevice, train
14
13
  from mxlpy.types import AbstractSurrogate
15
14
 
16
15
  type LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
@@ -106,28 +105,16 @@ class TorchTrainer:
106
105
  epochs: int,
107
106
  batch_size: int | None = None,
108
107
  ) -> Self:
109
- if batch_size is None:
110
- losses = _train_full(
111
- aprox=self.approximator,
112
- features=self.features,
113
- targets=self.targets,
114
- epochs=epochs,
115
- optimizer=self.optimizer,
116
- device=self.device,
117
- loss_fn=self.loss_fn,
118
- )
119
- else:
120
- losses = _train_batched(
121
- aprox=self.approximator,
122
- features=self.features,
123
- targets=self.targets,
124
- epochs=epochs,
125
- optimizer=self.optimizer,
126
- device=self.device,
127
- batch_size=batch_size,
128
- loss_fn=self.loss_fn,
129
- )
130
-
108
+ losses = train(
109
+ aprox=self.approximator,
110
+ features=self.features.to_numpy(),
111
+ targets=self.targets.to_numpy(),
112
+ epochs=epochs,
113
+ optimizer=self.optimizer,
114
+ batch_size=batch_size,
115
+ device=self.device,
116
+ loss_fn=self.loss_fn,
117
+ )
131
118
  if len(self.losses) > 0:
132
119
  losses.index += self.losses[-1].index[-1]
133
120
  self.losses.append(losses)
@@ -152,83 +139,6 @@ class TorchTrainer:
152
139
  )
153
140
 
154
141
 
155
- def _train_batched(
156
- aprox: nn.Module,
157
- features: pd.DataFrame,
158
- targets: pd.DataFrame,
159
- epochs: int,
160
- optimizer: Adam,
161
- device: torch.device,
162
- batch_size: int,
163
- loss_fn: LossFn,
164
- ) -> pd.Series:
165
- """Train the neural network using mini-batch gradient descent.
166
-
167
- Args:
168
- aprox: Neural network model to train.
169
- features: Input features as a tensor.
170
- targets: Target values as a tensor.
171
- epochs: Number of training epochs.
172
- optimizer: Optimizer for training.
173
- device: torch device
174
- batch_size: Size of mini-batches for training.
175
- loss_fn: Loss function
176
-
177
- Returns:
178
- pd.Series: Series containing the training loss history.
179
-
180
- """
181
- rng = np.random.default_rng()
182
- losses = {}
183
- for i in tqdm.trange(epochs):
184
- idxs = rng.choice(features.index, size=batch_size)
185
- X = torch.Tensor(features.iloc[idxs].to_numpy(), device=device)
186
- Y = torch.Tensor(targets.iloc[idxs].to_numpy(), device=device)
187
- optimizer.zero_grad()
188
- loss = loss_fn(aprox(X), Y)
189
- loss.backward()
190
- optimizer.step()
191
- losses[i] = loss.detach().numpy()
192
- return pd.Series(losses, dtype=float)
193
-
194
-
195
- def _train_full(
196
- aprox: nn.Module,
197
- features: pd.DataFrame,
198
- targets: pd.DataFrame,
199
- epochs: int,
200
- optimizer: Adam,
201
- device: torch.device,
202
- loss_fn: Callable,
203
- ) -> pd.Series:
204
- """Train the neural network using full-batch gradient descent.
205
-
206
- Args:
207
- aprox: Neural network model to train.
208
- features: Input features as a tensor.
209
- targets: Target values as a tensor.
210
- epochs: Number of training epochs.
211
- optimizer: Optimizer for training.
212
- device: Torch device
213
- loss_fn: Loss function
214
-
215
- Returns:
216
- pd.Series: Series containing the training loss history.
217
-
218
- """
219
- X = torch.Tensor(features.to_numpy(), device=device)
220
- Y = torch.Tensor(targets.to_numpy(), device=device)
221
-
222
- losses = {}
223
- for i in tqdm.trange(epochs):
224
- optimizer.zero_grad()
225
- loss = loss_fn(aprox(X), Y)
226
- loss.backward()
227
- optimizer.step()
228
- losses[i] = loss.detach().numpy()
229
- return pd.Series(losses, dtype=float)
230
-
231
-
232
142
  def train_torch(
233
143
  features: pd.DataFrame,
234
144
  targets: pd.DataFrame,
mxlpy/types.py CHANGED
@@ -46,9 +46,6 @@ __all__ = [
46
46
  "unwrap2",
47
47
  ]
48
48
 
49
- # Re-exporting some types here, because their imports have
50
- # changed between Python versions and I have no interest in
51
- # fixing it in every file
52
49
  from collections.abc import Callable, Iterator, Mapping
53
50
  from typing import TYPE_CHECKING, Any, ParamSpec, Protocol, TypeVar, cast
54
51
 
@@ -1,17 +1,17 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mxlpy
3
- Version: 0.17.0
3
+ Version: 0.18.0
4
4
  Summary: A package to build metabolic models
5
5
  Author-email: Marvin van Aalst <marvin.vanaalst@gmail.com>
6
6
  Maintainer-email: Marvin van Aalst <marvin.vanaalst@gmail.com>
7
- License-Expression: GPL-3.0-or-later
7
+ License-Expression: MIT
8
8
  License-File: LICENSE
9
9
  Keywords: metabolic,modelling,ode
10
10
  Classifier: Development Status :: 5 - Production/Stable
11
11
  Classifier: Environment :: Console
12
12
  Classifier: Intended Audience :: Developers
13
13
  Classifier: Intended Audience :: Science/Research
14
- Classifier: License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)
14
+ Classifier: License :: OSI Approved :: MIT License
15
15
  Classifier: Operating System :: MacOS
16
16
  Classifier: Operating System :: Microsoft :: Windows
17
17
  Classifier: Operating System :: OS Independent
@@ -67,7 +67,7 @@ Description-Content-Type: text/markdown
67
67
 
68
68
  [![pypi](https://img.shields.io/pypi/v/mxlpy.svg)](https://pypi.python.org/pypi/mxlpy)
69
69
  [![docs][docs-badge]][docs]
70
- ![License](https://img.shields.io/badge/license-GPL--3.0-blue?style=flat-square)
70
+ ![License](https://img.shields.io/badge/license-MIT-blue?style=flat-square)
71
71
  ![Coverage](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fgist.github.com%2Fmarvinvanaalst%2F98ab3ce1db511de42f9871e91d85e4cd%2Fraw%2Fcoverage.json&query=%24.message&label=Coverage&color=%24.color&suffix=%20%25)
72
72
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
73
73
  [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit)
@@ -90,10 +90,10 @@ pixi add --pypi mxlpy[torch]
90
90
 
91
91
  ## How to cite
92
92
 
93
- If you use this software in your scientific work, please cite [this article](...):
93
+ If you use this software in your scientific work, please cite [this article](https://doi.org/10.1101/2025.05.06.652335):
94
94
 
95
- - [doi](https://doi.org/)
96
- - [bibtex file](https://fillme.out)
95
+ - [doi](https://doi.org/10.1101/2025.05.06.652335)
96
+ - [bibtex file](https://github.com/Computational-Biology-Aachen/MxlPy/citation.bibtex)
97
97
 
98
98
 
99
99
  ## Development setup