catfishml 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *.pyd
4
+ *.so
5
+ .venv/
6
+ venv/
7
+ .pytest_cache/
8
+ .ruff_cache/
9
+ .mypy_cache/
10
+ .coverage
11
+ htmlcov/
12
+ dist/
13
+ build/
14
+ *.egg-info/
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 CatfishML Contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,133 @@
1
+ Metadata-Version: 2.4
2
+ Name: catfishml
3
+ Version: 0.4.0
4
+ Summary: NeuroSplit Boosting for tabular data with differentiable soft trees and neural gating.
5
+ Project-URL: Homepage, https://github.com/catfishml/catfishml
6
+ Project-URL: Repository, https://github.com/catfishml/catfishml
7
+ Project-URL: Issues, https://github.com/catfishml/catfishml/issues
8
+ Author: CatfishML Contributors
9
+ License: MIT
10
+ License-File: LICENSE
11
+ Keywords: gradient-boosting,machine-learning,pytorch,soft-decision-tree,tabular
12
+ Classifier: Development Status :: 3 - Alpha
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.9
18
+ Classifier: Programming Language :: Python :: 3.10
19
+ Classifier: Programming Language :: Python :: 3.11
20
+ Classifier: Programming Language :: Python :: 3.12
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Requires-Python: >=3.9
23
+ Requires-Dist: numpy>=1.23
24
+ Requires-Dist: pandas>=1.5
25
+ Requires-Dist: scikit-learn>=1.3
26
+ Requires-Dist: scipy>=1.10
27
+ Requires-Dist: torch>=2.1
28
+ Requires-Dist: typing-extensions>=4.8
29
+ Provides-Extra: dev
30
+ Requires-Dist: mypy>=1.11; extra == 'dev'
31
+ Requires-Dist: pytest-cov>=5.0; extra == 'dev'
32
+ Requires-Dist: pytest>=8.0; extra == 'dev'
33
+ Requires-Dist: ruff>=0.6; extra == 'dev'
34
+ Provides-Extra: viz
35
+ Requires-Dist: matplotlib>=3.8; extra == 'viz'
36
+ Description-Content-Type: text/markdown
37
+
38
+ # catfishml
39
+
40
+ `catfishml` is a Python library for **OmniBoost++**: generalized boosting with adaptive routing across heterogeneous weak learners (linear, spline/GAM-like, adaptive-depth MLP, differentiable soft tree).
41
+
42
+ ## Why catfishml
43
+
44
+ - Adaptive residual router with complexity/redundancy penalties.
45
+ - Additive boosting in natural predictor space with Newton-style targets.
46
+ - Numeric + categorical support (categorical embeddings).
47
+ - Missing value handling with `SimpleImputer` or MICE-style `IterativeImputer`.
48
+ - Adaptive behavior:
49
+ - automatic objective/distribution and metric selection,
50
+ - linearity probe (auto linear vs nonlinear mode),
51
+ - adaptive MLP depth + adaptive tree depth.
52
+ - CPU/GPU via PyTorch.
53
+ - Automatic dependency install for missing core libraries at import time (disable with `CATFISHML_DISABLE_AUTO_INSTALL=1`).
54
+
55
+ ## Install
56
+
57
+ ```bash
58
+ pip install catfishml
59
+ ```
60
+
61
+ For development:
62
+
63
+ ```bash
64
+ pip install -e ".[dev]"
65
+ ```
66
+
67
+ ## Quick start
68
+
69
+ ```python
70
+ import pandas as pd
71
+ from catfishml import FishyCatClassifier
72
+
73
+ X = pd.DataFrame(
74
+ {
75
+ "age": [25, 31, 45, None, 39, 22, 55],
76
+ "income": [2200, 3400, 7600, 5100, None, 1900, 8800],
77
+ "city": ["A", "B", "A", "C", "B", None, "A"],
78
+ }
79
+ )
80
+ y = [0, 0, 1, 1, 0, 0, 1]
81
+
82
+ model = FishyCatClassifier(
83
+ n_estimators=40,
84
+ tree_depth=3,
85
+ metrics="auto",
86
+ auto_metric=True,
87
+ impute_strategy="auto",
88
+ candidate_families="auto",
89
+ install_missing_libraries=True,
90
+ n_jobs=4,
91
+ verbose=1,
92
+ )
93
+
94
+ model.fit(X, y)
95
+ print(model.evaluate(X, y))
96
+ print(model.predict_proba(X)[:3])
97
+ fig = model.plot_visualization(kind="overview")  # requires matplotlib: pip install "catfishml[viz]"
98
+ print(model.get_statistics())
99
+ print(model.get_history(as_dataframe=True).head())
100
+ ```
101
+
102
+ ## Main API
103
+
104
+ - `FishyCatBooster`
105
+ - `FishyCatClassifier`
106
+
107
+ For regression, use `FishyCatBooster(task="regression", ...)`.
108
+
109
+ Common parameters and methods:
110
+
111
+ - `metrics`: metric name (`"auto"`, `"accuracy"`, `"auc"`, `"logloss"`, `"rmse"`, `"mae"`, `"r2"`) or callable.
112
+ - `auto_metric`: if `True`, metric and training validation feedback are auto-selected by task/data.
113
+ - `impute_strategy`: `"auto"`, `"simple"`, `"iterative"`, or `"none"`.
114
+ - `structure_mode`: `"auto"`, `"linear"`, `"nonlinear"`.
115
+ - `boosting_order`: `1` (gradient) or `2` (Newton-like weighted residuals).
116
+ - `candidate_families`: `"auto"` or subset of `["linear", "spline", "adaptive_mlp", "soft_tree"]`.
117
+ - `plot_visualization(kind=...)`: loss/routing/depth/overview diagnostics.
118
+ - `get_statistics()`: full training + data summary.
119
+ - `get_history(as_dataframe=True)`: per-iteration history (loss, metric, ETA, routing).
120
+ - `view_data(X, transformed=True/False)`: inspect raw or transformed data.
121
+ - `auto_install_dependencies`: auto-installs missing libs using pip at runtime.
122
+ - `install_plot_dependencies`: if `True`, auto-installs plotting dependencies too.
123
+ - `full_report(X, y)`: one-shot report (statistics + history + evaluation).
124
+ - `available_components()`: list of all integrated learner families/features.
125
+
126
+ ## Notes
127
+
128
+ - This repository provides a practical implementation of OmniBoost++ ideas; it is not a strict reproduction of a specific paper.
129
+ - For larger datasets, run on GPU: `device="cuda"`.
130
+
131
+ ## License
132
+
133
+ MIT
@@ -0,0 +1,96 @@
1
+ # catfishml
2
+
3
+ `catfishml` is a Python library for **OmniBoost++**: generalized boosting with adaptive routing across heterogeneous weak learners (linear, spline/GAM-like, adaptive-depth MLP, differentiable soft tree).
4
+
5
+ ## Why catfishml
6
+
7
+ - Adaptive residual router with complexity/redundancy penalties.
8
+ - Additive boosting in natural predictor space with Newton-style targets.
9
+ - Numeric + categorical support (categorical embeddings).
10
+ - Missing value handling with `SimpleImputer` or MICE-style `IterativeImputer`.
11
+ - Adaptive behavior:
12
+ - automatic objective/distribution and metric selection,
13
+ - linearity probe (auto linear vs nonlinear mode),
14
+ - adaptive MLP depth + adaptive tree depth.
15
+ - CPU/GPU via PyTorch.
16
+ - Automatic dependency install for missing core libraries at import time (disable with `CATFISHML_DISABLE_AUTO_INSTALL=1`).
17
+
18
+ ## Install
19
+
20
+ ```bash
21
+ pip install catfishml
22
+ ```
23
+
24
+ For development:
25
+
26
+ ```bash
27
+ pip install -e ".[dev]"
28
+ ```
29
+
30
+ ## Quick start
31
+
32
+ ```python
33
+ import pandas as pd
34
+ from catfishml import FishyCatClassifier
35
+
36
+ X = pd.DataFrame(
37
+ {
38
+ "age": [25, 31, 45, None, 39, 22, 55],
39
+ "income": [2200, 3400, 7600, 5100, None, 1900, 8800],
40
+ "city": ["A", "B", "A", "C", "B", None, "A"],
41
+ }
42
+ )
43
+ y = [0, 0, 1, 1, 0, 0, 1]
44
+
45
+ model = FishyCatClassifier(
46
+ n_estimators=40,
47
+ tree_depth=3,
48
+ metrics="auto",
49
+ auto_metric=True,
50
+ impute_strategy="auto",
51
+ candidate_families="auto",
52
+ install_missing_libraries=True,
53
+ n_jobs=4,
54
+ verbose=1,
55
+ )
56
+
57
+ model.fit(X, y)
58
+ print(model.evaluate(X, y))
59
+ print(model.predict_proba(X)[:3])
60
+ fig = model.plot_visualization(kind="overview")  # requires matplotlib: pip install "catfishml[viz]"
61
+ print(model.get_statistics())
62
+ print(model.get_history(as_dataframe=True).head())
63
+ ```
64
+
65
+ ## Main API
66
+
67
+ - `FishyCatBooster`
68
+ - `FishyCatClassifier`
69
+
70
+ For regression, use `FishyCatBooster(task="regression", ...)`.
71
+
72
+ Common parameters and methods:
73
+
74
+ - `metrics`: metric name (`"auto"`, `"accuracy"`, `"auc"`, `"logloss"`, `"rmse"`, `"mae"`, `"r2"`) or callable.
75
+ - `auto_metric`: if `True`, metric and training validation feedback are auto-selected by task/data.
76
+ - `impute_strategy`: `"auto"`, `"simple"`, `"iterative"`, or `"none"`.
77
+ - `structure_mode`: `"auto"`, `"linear"`, `"nonlinear"`.
78
+ - `boosting_order`: `1` (gradient) or `2` (Newton-like weighted residuals).
79
+ - `candidate_families`: `"auto"` or subset of `["linear", "spline", "adaptive_mlp", "soft_tree"]`.
80
+ - `plot_visualization(kind=...)`: loss/routing/depth/overview diagnostics.
81
+ - `get_statistics()`: full training + data summary.
82
+ - `get_history(as_dataframe=True)`: per-iteration history (loss, metric, ETA, routing).
83
+ - `view_data(X, transformed=True/False)`: inspect raw or transformed data.
84
+ - `auto_install_dependencies`: auto-installs missing libs using pip at runtime.
85
+ - `install_plot_dependencies`: if `True`, auto-installs plotting dependencies too.
86
+ - `full_report(X, y)`: one-shot report (statistics + history + evaluation).
87
+ - `available_components()`: list of all integrated learner families/features.
88
+
89
+ ## Notes
90
+
91
+ - This repository provides a practical implementation of OmniBoost++ ideas; it is not a strict reproduction of a specific paper.
92
+ - For larger datasets, run on GPU: `device="cuda"`.
93
+
94
+ ## License
95
+
96
+ MIT
@@ -0,0 +1,65 @@
1
+ """Minimal example for catfishml."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from catfishml import FishyCatClassifier
9
+
10
+
11
def main() -> None:
    """Fit a FishyCatClassifier on a synthetic dataset and print diagnostics.

    The dataset has numeric (age, income, click_rate) and categorical (city,
    device) columns. Some income and city values are blanked out so the
    model's missing-value handling is exercised. Binary labels are sampled from
    a logistic model of the features.
    """
    rng = np.random.default_rng(42)
    n_rows = 400

    # Column draws happen in dict order, so the RNG stream is consumed
    # age -> income -> click_rate -> city -> device.
    frame = pd.DataFrame(
        {
            "age": rng.integers(18, 70, size=n_rows),
            "income": rng.normal(5000, 1800, size=n_rows),
            "click_rate": rng.uniform(0, 1, size=n_rows),
            "city": rng.choice(["B", "C", "D", "A"], size=n_rows),
            "device": rng.choice(["ios", "android", "web"], size=n_rows),
        }
    )

    # Inject missing values: every 17th income and every 23rd city.
    frame.loc[::17, "income"] = np.nan
    frame.loc[::23, "city"] = None

    # Ground-truth logit: a linear score over the features, with missing
    # values filled so the score is defined for every row.
    is_ios = (frame["device"] == "ios").astype(float)
    is_city_d = (frame["city"].fillna("A") == "D").astype(float)
    logits = (
        0.04 * (frame["age"].fillna(40) - 35)
        + 0.0003 * (frame["income"].fillna(5000) - 5000)
        + 1.4 * (frame["click_rate"] - 0.4)
        + 0.5 * is_ios
        - 0.3 * is_city_d
    )
    positive_prob = 1.0 / (1.0 + np.exp(-logits.to_numpy()))
    labels = (rng.random(n_rows) < positive_prob).astype(int)

    model = FishyCatClassifier(
        n_estimators=50,
        tree_depth=3,
        boosting_order=2,
        metrics="auto",
        auto_metric=True,
        impute_strategy="auto",
        structure_mode="auto",
        neural_epochs=12,
        batch_size=128,
        candidate_families="auto",
        install_missing_libraries=True,
        n_jobs=4,
        verbose=1,
    )

    model.fit(frame, labels)
    print("Training metrics:", model.evaluate(frame, labels))
    print("Proba preview:\n", model.predict_proba(frame.head(5)))
    print("Statistics:", model.get_statistics())
    print("History preview:\n", model.get_history(as_dataframe=True).head())
    # Plotting needs matplotlib (the "viz" extra):
    # model.plot_visualization(kind="overview")


if __name__ == "__main__":
    main()
@@ -0,0 +1,71 @@
1
+ [build-system]
2
+ requires = ["hatchling>=1.24"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "catfishml"
7
+ version = "0.4.0"
8
+ description = "NeuroSplit Boosting for tabular data with differentiable soft trees and neural gating."
9
+ readme = "README.md"
10
+ license = { text = "MIT" }
11
+ requires-python = ">=3.9"
12
+ authors = [
13
+ { name = "CatfishML Contributors" }
14
+ ]
15
+ keywords = ["machine-learning", "gradient-boosting", "tabular", "pytorch", "soft-decision-tree"]
16
+ classifiers = [
17
+ "Development Status :: 3 - Alpha",
18
+ "Intended Audience :: Developers",
19
+ "Intended Audience :: Science/Research",
20
+ "License :: OSI Approved :: MIT License",
21
+ "Programming Language :: Python :: 3",
22
+ "Programming Language :: Python :: 3.9",
23
+ "Programming Language :: Python :: 3.10",
24
+ "Programming Language :: Python :: 3.11",
25
+ "Programming Language :: Python :: 3.12",
26
+ "Topic :: Scientific/Engineering :: Artificial Intelligence"
27
+ ]
28
+ dependencies = [
29
+ "numpy>=1.23",
30
+ "scipy>=1.10",
31
+ "torch>=2.1",
32
+ "scikit-learn>=1.3",
33
+ "pandas>=1.5",
34
+ "typing-extensions>=4.8"
35
+ ]
36
+
37
+ [project.optional-dependencies]
38
+ dev = [
39
+ "pytest>=8.0",
40
+ "pytest-cov>=5.0",
41
+ "ruff>=0.6",
42
+ "mypy>=1.11"
43
+ ]
44
+ viz = [
45
+ "matplotlib>=3.8"
46
+ ]
47
+
48
+ [project.urls]
49
+ Homepage = "https://github.com/catfishml/catfishml"
50
+ Repository = "https://github.com/catfishml/catfishml"
51
+ Issues = "https://github.com/catfishml/catfishml/issues"
52
+
53
+ [tool.hatch.build.targets.wheel]
54
+ packages = ["src/catfishml"]
55
+
56
+ [tool.pytest.ini_options]
57
+ minversion = "8.0"
58
+ addopts = "-ra"
59
+ testpaths = ["tests"]
60
+
61
+ [tool.ruff]
62
+ line-length = 100
63
+ target-version = "py39"
64
+
65
+ [tool.ruff.lint]
66
+ select = ["E", "F", "I", "B"]
67
+
68
+ [tool.ruff.format]
69
+ quote-style = "double"
70
+ indent-style = "space"
71
+ line-ending = "lf"
@@ -0,0 +1,13 @@
1
+ """catfishml public API."""
2
+
3
from ._autoinstall import ensure_runtime_dependencies

# Check the core runtime dependencies before importing modules that need them.
# Missing ones are pip-installed unless the CATFISHML_DISABLE_AUTO_INSTALL
# environment variable is set (in which case an ImportError is raised).
ensure_runtime_dependencies(auto_install=True, include_plot=False, quiet=True)

from .model import FishyCatBooster, FishyCatClassifier  # noqa: E402

__version__ = "0.4.0"
__all__ = ["FishyCatBooster", "FishyCatClassifier"]
@@ -0,0 +1,126 @@
1
+ """Adaptive-depth MLP weak learner for OmniBoost++."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Sequence
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+
12
class TabularEncoder(nn.Module):
    """Build one dense input tensor from numeric features and categorical embeddings.

    Each categorical column gets its own ``nn.Embedding`` whose width comes from
    ``_default_embedding_dim``. The forward pass concatenates the numeric block
    (if any) with the concatenated, optionally dropped-out, embeddings along
    dim 1. ``output_dim`` is the resulting feature width.
    """

    def __init__(
        self,
        *,
        num_numeric_features: int,
        categorical_cardinalities: Sequence[int],
        embedding_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_numeric_features = num_numeric_features
        self.embed_layers = nn.ModuleList()

        total_embed_width = 0
        for n_levels in categorical_cardinalities:
            width = _default_embedding_dim(n_levels)
            self.embed_layers.append(nn.Embedding(n_levels, width))
            total_embed_width += width

        self.output_dim = num_numeric_features + total_embed_width
        # Dropout is only applied to the embedding block, never to x_num.
        if embedding_dropout > 0:
            self.embedding_dropout = nn.Dropout(embedding_dropout)
        else:
            self.embedding_dropout = nn.Identity()

    def forward(self, x_num: torch.Tensor, x_cat: torch.Tensor | None) -> torch.Tensor:
        """Return the encoded features.

        ``x_cat`` is indexed column-wise (``x_cat[:, i]``), so it is presumably a
        2-D tensor of integer category codes (batch, n_categorical) -- confirm
        against callers.

        Raises:
            ValueError: if embeddings exist but ``x_cat`` is None, or if there
                are neither numeric nor categorical inputs.
        """
        pieces: list[torch.Tensor] = []
        if self.num_numeric_features > 0:
            pieces.append(x_num)

        if len(self.embed_layers) > 0:
            if x_cat is None:
                raise ValueError("x_cat is required because categorical embeddings are enabled.")
            cat_vec = torch.cat(
                [embed(x_cat[:, col]) for col, embed in enumerate(self.embed_layers)],
                dim=1,
            )
            pieces.append(self.embedding_dropout(cat_vec))

        if not pieces:
            raise ValueError("AdaptiveDepthMLP requires at least one input feature.")
        if len(pieces) == 1:
            return pieces[0]
        return torch.cat(pieces, dim=1)
49
+
50
+
51
class AdaptiveDepthMLP(nn.Module):
    """MLP whose effective depth is learned through per-layer sigmoid gates.

    Inputs pass through a ``TabularEncoder``, then a linear projection to
    ``hidden_dim`` with SiLU. After that come ``max_layers`` residual-style
    blocks, each blended with its input as
    ``h = g * block(h) + (1 - g) * h`` where ``g = sigmoid(depth_logits[i])``.
    The logits start at zero, so every gate begins at 0.5.
    """

    def __init__(
        self,
        *,
        num_numeric_features: int,
        categorical_cardinalities: Sequence[int],
        output_dim: int,
        hidden_dim: int = 128,
        max_layers: int = 4,
        dropout: float = 0.1,
        embedding_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if max_layers < 1:
            raise ValueError("max_layers must be >= 1")

        self.max_layers = max_layers
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.encoder = TabularEncoder(
            num_numeric_features=num_numeric_features,
            categorical_cardinalities=categorical_cardinalities,
            embedding_dropout=embedding_dropout,
        )
        encoded_width = self.encoder.output_dim
        if encoded_width == 0:
            raise ValueError("AdaptiveDepthMLP requires at least one numeric or categorical feature.")

        self.input_proj = nn.Linear(encoded_width, hidden_dim)

        # One Linear -> SiLU -> (Dropout) block per potential layer.
        gated_blocks = []
        for _ in range(max_layers):
            regularizer = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
            gated_blocks.append(
                nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.SiLU(), regularizer)
            )
        self.blocks = nn.ModuleList(gated_blocks)

        self.depth_logits = nn.Parameter(torch.zeros(max_layers))
        self.output_head = nn.Linear(hidden_dim, output_dim)

    def forward(
        self,
        x_num: torch.Tensor,
        x_cat: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(output, gates)``.

        ``gates`` holds the ``max_layers`` per-layer gate values.
        """
        encoded = self.encoder(x_num, x_cat)
        hidden = F.silu(self.input_proj(encoded))
        gates = torch.sigmoid(self.depth_logits)

        for gate, block in zip(gates, self.blocks):
            hidden = gate * block(hidden) + (1.0 - gate) * hidden

        return self.output_head(hidden), gates

    def depth_penalty(self) -> torch.Tensor:
        """Sum of gate values (differentiable); larger means a deeper network."""
        return self.depth_logits.sigmoid().sum()

    def sharpness_penalty(self) -> torch.Tensor:
        """Sum of ``g * (1 - g)``; zero when every gate is exactly 0 or 1."""
        probs = self.depth_logits.sigmoid()
        return (probs * (1.0 - probs)).sum()

    def effective_depth(self) -> float:
        """Sum of gate values as a plain Python float (no autograd graph)."""
        probs = self.depth_logits.detach().sigmoid()
        return float(probs.sum().cpu().item())
121
+
122
+
123
+
124
+ def _default_embedding_dim(cardinality: int) -> int:
125
+ cardinality = max(2, int(cardinality))
126
+ return min(32, max(4, int(round(cardinality**0.25 * 4))))
@@ -0,0 +1,93 @@
1
+ """Runtime dependency auto-installer for catfishml."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import importlib.util
6
+ import os
7
+ import subprocess
8
+ import sys
9
+ from typing import Iterable
10
+
11
# Importable module name -> pip requirement specifier for the core runtime
# dependencies. The keys are probed with importlib.util.find_spec; the values
# are what gets passed to `pip install` (the `sklearn` module is published as
# `scikit-learn`). Insertion order determines the order of pip arguments.
_CORE_REQUIREMENTS = dict(
    numpy="numpy>=1.23",
    scipy="scipy>=1.10",
    pandas="pandas>=1.5",
    sklearn="scikit-learn>=1.3",
    torch="torch>=2.1",
)

# Plotting dependencies, only required when include_plot=True.
_OPTIONAL_REQUIREMENTS = dict(matplotlib="matplotlib>=3.8")

# Set to True by ensure_runtime_dependencies after a successful check.
_ENSURED = False
24
+
25
+
26
+ def _is_installed(module_name: str) -> bool:
27
+ return importlib.util.find_spec(module_name) is not None
28
+
29
+
30
def _missing_modules(module_names: Iterable[str]) -> list[str]:
    """Return the names in ``module_names`` that cannot be located, in input order."""
    return [name for name in module_names if not _is_installed(name)]
36
+
37
+
38
def ensure_runtime_dependencies(
    *,
    auto_install: bool = True,
    include_plot: bool = False,
    quiet: bool = True,
) -> None:
    """Ensure runtime dependencies are importable, optionally installing them with pip.

    Core requirements (``_CORE_REQUIREMENTS``) are verified once per process.
    After a successful check, ``_ENSURED`` lets later calls skip them. Plotting
    requirements (``_OPTIONAL_REQUIREMENTS``) are checked on every call with
    ``include_plot=True``, so an earlier core-only call (such as the one made at
    package import) cannot hide a missing plotting library.

    Args:
        auto_install: If False, raise instead of running pip when something is missing.
        include_plot: Also require the optional plotting dependencies.
        quiet: Pass ``-q`` to pip.

    Raises:
        ImportError: If dependencies are missing and ``auto_install`` is False;
            if auto-install is disabled via the ``CATFISHML_DISABLE_AUTO_INSTALL``
            environment variable ("1", "true", "yes" or "on", case-insensitive);
            if pip cannot be run or fails; or if modules are still missing after
            installing.
    """
    global _ENSURED

    # Only verify what has not been verified yet: core deps once, plot deps on request.
    required: dict[str, str] = {} if _ENSURED else dict(_CORE_REQUIREMENTS)
    if include_plot:
        required.update(_OPTIONAL_REQUIREMENTS)
    if not required:
        return

    missing_modules = _missing_modules(required.keys())
    if not missing_modules:
        _ENSURED = True
        return

    missing_pkgs = [required[name] for name in missing_modules]

    if not auto_install:
        raise ImportError(
            "Missing dependencies: "
            + ", ".join(missing_pkgs)
            + ". Enable auto_install_dependencies=True or install manually."
        )

    disable_flag = os.environ.get("CATFISHML_DISABLE_AUTO_INSTALL", "0").strip().lower()
    if disable_flag in {"1", "true", "yes", "on"}:
        raise ImportError(
            "Automatic dependency install is disabled via CATFISHML_DISABLE_AUTO_INSTALL. "
            + "Missing: "
            + ", ".join(missing_pkgs)
        )

    cmd = [sys.executable, "-m", "pip", "install"]
    if quiet:
        cmd.append("-q")
    cmd.extend(missing_pkgs)

    try:
        subprocess.check_call(cmd)
    except (subprocess.CalledProcessError, OSError) as exc:  # pragma: no cover
        # OSError covers an unusable interpreter path (e.g. empty sys.executable).
        raise ImportError(
            "Failed to auto-install dependencies: " + ", ".join(missing_pkgs)
        ) from exc

    # Path-based finders cache directory contents; without this, find_spec may
    # not see packages that pip just installed during this process.
    importlib.invalidate_caches()

    still_missing = _missing_modules(required.keys())
    if still_missing:
        pkgs = [required[name] for name in still_missing]
        raise ImportError("Dependencies still missing after install: " + ", ".join(pkgs))

    _ENSURED = True