LetsANN 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
letsann/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ """LetsANN — A beginner-friendly ANN library on top of TensorFlow.
2
+
3
+ Public API:
4
+ - Model: thin wrapper around ``tf.keras.Sequential`` that builds a network
5
+ from a simple JSON-like list of layer specs.
6
+ - build_model: convenience factory for spec-based model construction.
7
+ - LAYER_REGISTRY: mapping of supported layer types to Keras classes
8
+ together with the metadata used by the web UI.
9
+
10
+ Importing ``letsann`` does *not* start the web server. The web UI has its own
11
+ console script (``letsann-web``) and must be started explicitly.
12
+ """
13
+
14
+ from ._version import __version__
15
+
16
# Names exported by ``from letsann import *``. The heavyweight entries
# (everything backed by TensorFlow) are resolved lazily via the module-level
# ``__getattr__`` so that importing ``letsann`` stays cheap.
__all__ = [
    "Model",
    "build_model",
    "load_dataset",
    "LAYER_REGISTRY",
    "layer_catalog",
    "__version__",
]
24
+
25
+
26
def __getattr__(name):
    """Lazily import the heavy (TensorFlow-backed) public API.

    This keeps ``python -m letsann.cli version`` and similar tooling usable
    even when TensorFlow has not been imported yet (or is not installed).
    """
    if name in ("Model", "build_model"):
        from . import model

        return getattr(model, name)
    if name in ("LAYER_REGISTRY", "layer_catalog"):
        from . import layers

        return getattr(layers, name)
    if name == "load_dataset":
        from .data import load_dataset

        return load_dataset
    raise AttributeError(f"module 'letsann' has no attribute {name!r}")
letsann/_version.py ADDED
@@ -0,0 +1,7 @@
1
+ """Version constant kept in a dependency-free module.
2
+
3
+ Keeping ``__version__`` isolated means ``letsann.cli`` and tooling can read
4
+ it without triggering the TensorFlow import chain.
5
+ """
6
+
7
+ __version__ = "0.1.0"
letsann/cli.py ADDED
@@ -0,0 +1,35 @@
1
+ """LetsANN 命令行工具。
2
+
3
+ 只做一件事:查版本号。
4
+ Web 界面已拆到独立的 ``letsann-web`` 项目,本包不再附带。
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+
11
+ from ._version import __version__
12
+
13
+
14
def build_parser() -> argparse.ArgumentParser:
    """Create the ``letsann`` argument parser with its ``version`` subcommand."""
    parser = argparse.ArgumentParser(
        prog="letsann",
        description="LetsANN —— 基于 TensorFlow 的极简 ANN 库。",
    )
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    version_cmd = subparsers.add_parser("version", help="打印 LetsANN 版本。")

    def _print_version(_args):
        # Exit code 0 on success, matching the original lambda's ``or 0``.
        print(f"LetsANN {__version__}")
        return 0

    version_cmd.set_defaults(func=_print_version)
    return parser
26
+
27
+
28
def main(argv=None) -> int:
    """Console entry point: parse ``argv`` and dispatch to the subcommand."""
    parser = build_parser()
    namespace = parser.parse_args(argv)
    exit_code = namespace.func(namespace)
    return exit_code or 0


if __name__ == "__main__":  # pragma: no cover
    raise SystemExit(main())
letsann/data.py ADDED
@@ -0,0 +1,197 @@
1
+ """Dataset helpers for LetsANN.
2
+
3
+ Users can:
4
+ * point LetsANN at a local CSV/NPZ file via :func:`load_dataset`;
5
+ * load the same way from uploaded files in the web UI.
6
+
7
+ Datasets are kept intentionally simple: tabular data in CSV (with a target
8
+ column) or NPZ archives that contain ``X`` / ``y`` arrays.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import io
14
+ import os
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import pandas as pd
20
+ from sklearn.model_selection import train_test_split
21
+ from sklearn.preprocessing import StandardScaler
22
+
23
+
24
@dataclass
class Dataset:
    """Container for a train/validation split plus optional metadata.

    Produced by :func:`load_dataset` and consumed by the trainer and the
    web UI (via :meth:`summary`).
    """

    X_train: np.ndarray
    X_val: np.ndarray
    y_train: np.ndarray
    y_val: np.ndarray
    feature_names: Optional[list] = None
    target_name: Optional[str] = None
    n_classes: Optional[int] = None

    @property
    def input_shape(self) -> Tuple[int, ...]:
        """Per-sample feature shape (everything after the batch axis)."""
        return self.X_train.shape[1:]

    @property
    def task_type(self) -> str:
        """Heuristic classification-vs-regression detection."""
        return "classification" if (self.n_classes or 0) > 1 else "regression"

    def summary(self) -> Dict[str, Any]:
        """Return a JSON-serialisable overview of the dataset."""
        return {
            "n_train": int(self.X_train.shape[0]),
            "n_val": int(self.X_val.shape[0]),
            "input_shape": list(self.input_shape),
            "task_type": self.task_type,
            "n_classes": self.n_classes,
            "feature_names": self.feature_names,
            "target_name": self.target_name,
        }
55
+
56
+
57
+ def _infer_classes(y: np.ndarray) -> Optional[int]:
58
+ if y.ndim > 1 and y.shape[-1] > 1:
59
+ return int(y.shape[-1])
60
+ if np.issubdtype(y.dtype, np.integer):
61
+ uniq = np.unique(y)
62
+ if uniq.size <= max(50, int(np.sqrt(y.size))):
63
+ return int(uniq.size)
64
+ return None
65
+
66
+
67
def _from_dataframe(
    df: pd.DataFrame,
    target: Optional[str],
    test_size: float,
    normalize: bool,
    random_state: int,
) -> Dataset:
    """Build a :class:`Dataset` from a DataFrame.

    The last column is the default target; string targets are label-encoded,
    numeric ones fall back to the :func:`_infer_classes` heuristic.
    """
    target = df.columns[-1] if target is None else target
    if target not in df.columns:
        raise ValueError(f"Target column {target!r} not in dataset. Columns: {list(df.columns)}")

    features = [col for col in df.columns if col != target]
    X = df[features].to_numpy(dtype=np.float32)
    y_raw = df[target].to_numpy()

    if y_raw.dtype.kind in {"O", "U", "S"}:
        # Non-numeric target: label-encode into 0..n_classes-1.
        classes, y = np.unique(y_raw, return_inverse=True)
        y = y.astype(np.int64)
        n_classes = int(classes.size)
    else:
        y = y_raw
        n_classes = _infer_classes(y)
        if n_classes is None:
            # Regression target: Keras losses expect floats.
            y = y.astype(np.float32)

    if normalize:
        X = StandardScaler().fit_transform(X).astype(np.float32)

    X_train, X_val, y_train, y_val = train_test_split(
        X,
        y,
        test_size=test_size,
        random_state=random_state,
        stratify=y if n_classes else None,
    )

    return Dataset(
        X_train=X_train,
        X_val=X_val,
        y_train=y_train,
        y_val=y_val,
        feature_names=features,
        target_name=target,
        n_classes=n_classes,
    )
112
+
113
+
114
def load_dataset(
    source: Union[str, bytes, io.BytesIO, pd.DataFrame],
    *,
    target: Optional[str] = None,
    test_size: float = 0.2,
    normalize: bool = True,
    random_state: int = 42,
    file_name: Optional[str] = None,
) -> Dataset:
    """Load a dataset from a path, bytes buffer, or DataFrame.

    Parameters
    ----------
    source:
        Path to a CSV/NPZ file, a bytes buffer (e.g. uploaded file), or a
        pandas DataFrame.
    target:
        Target column name. When ``None`` the last column is used.
    test_size:
        Fraction kept for the validation split.
    normalize:
        Whether to StandardScaler-normalise the features.
    random_state:
        Seed forwarded to ``train_test_split`` for a reproducible split.
    file_name:
        Optional hint used when ``source`` is raw bytes and its extension is
        otherwise unknown.

    Raises
    ------
    ValueError
        If the (path or hinted) file extension is not CSV/TSV/TXT/NPZ.
    """

    if isinstance(source, pd.DataFrame):
        return _from_dataframe(source, target, test_size, normalize, random_state)

    if isinstance(source, (bytes, bytearray)):
        # Raw upload: wrap in a buffer; the type comes from the ``file_name`` hint.
        buf = io.BytesIO(source)
        ext = os.path.splitext(file_name or "")[1].lower()
        return _load_from_buffer(buf, ext, target, test_size, normalize, random_state)

    if isinstance(source, io.IOBase):
        # Already file-like; prefer the explicit hint over the object's own name.
        ext = os.path.splitext(file_name or getattr(source, "name", ""))[1].lower()
        return _load_from_buffer(source, ext, target, test_size, normalize, random_state)

    # assume string path (also covers pathlib.Path via str())
    path = str(source)
    ext = os.path.splitext(path)[1].lower()
    if ext in {".csv", ".tsv", ".txt"}:
        sep = "\t" if ext == ".tsv" else ","
        df = pd.read_csv(path, sep=sep)
        return _from_dataframe(df, target, test_size, normalize, random_state)
    if ext in {".npz"}:
        return _load_npz(np.load(path), test_size, normalize, random_state)
    raise ValueError(f"Unsupported file extension: {ext!r}")
163
+
164
+
165
def _load_from_buffer(buf, ext, target, test_size, normalize, random_state) -> Dataset:
    """Dispatch an in-memory buffer to the CSV or NPZ loader by extension."""
    if ext == ".npz":
        return _load_npz(np.load(buf), test_size, normalize, random_state)
    if ext in {".csv", ".tsv", ".txt", ""}:
        # An empty extension (no usable hint) defaults to comma-separated text.
        frame = pd.read_csv(buf, sep="\t" if ext == ".tsv" else ",")
        return _from_dataframe(frame, target, test_size, normalize, random_state)
    raise ValueError(f"Unsupported upload extension: {ext!r}")
173
+
174
+
175
def _load_npz(npz, test_size: float, normalize: bool, random_state: int) -> Dataset:
    """Build a :class:`Dataset` from an NPZ archive containing ``X``/``y``."""
    if "X" not in npz or "y" not in npz:
        raise ValueError("NPZ file must contain 'X' and 'y' arrays.")

    X = np.asarray(npz["X"], dtype=np.float32)
    y = np.asarray(npz["y"])
    n_classes = _infer_classes(y)
    if n_classes is None:
        # Regression target: Keras losses expect floats.
        y = y.astype(np.float32)

    # Only tabular (2-D) features get scaled; image-like arrays are left as-is.
    if normalize and X.ndim == 2:
        X = StandardScaler().fit_transform(X).astype(np.float32)

    X_train, X_val, y_train, y_val = train_test_split(
        X,
        y,
        test_size=test_size,
        random_state=random_state,
        stratify=y if n_classes else None,
    )
    return Dataset(
        X_train=X_train,
        X_val=X_val,
        y_train=y_train,
        y_val=y_val,
        n_classes=n_classes,
    )
letsann/layers.py ADDED
@@ -0,0 +1,217 @@
1
+ """层注册表:同时被 Python API 与 Web UI 使用。
2
+
3
+ 每一项描述了 LetsANN 所支持的一种层。``params`` 列表会驱动 Web 界面
4
+ 的表单生成,因此 UI 展示始终与后端可实例化的参数保持一致。
5
+
6
+ UI 中展示的 ``description`` / ``label`` 使用中文,便于教学场景。
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Any, Dict, List
12
+
13
+ import tensorflow as tf
14
+
15
# ---------------------------------------------------------------------------
# Fields of each registry entry:
#   keras_cls   : the Keras class used to create the layer
#   params      : parameter descriptors that drive the web-UI form
#                 { name, label, type, default, options?, min?, max?, step? }
#   description : one-line description shown in the layer list (Chinese,
#                 because it is user-facing teaching text)
# ---------------------------------------------------------------------------

# Activation/initializer names accepted verbatim by Keras string lookup.
ACTIVATIONS = ["relu", "sigmoid", "tanh", "softmax", "linear", "elu", "selu"]
INITIALIZERS = [
    "glorot_uniform",
    "glorot_normal",
    "he_uniform",
    "he_normal",
    "random_normal",
    "zeros",
]

LAYER_REGISTRY: Dict[str, Dict[str, Any]] = {
    "Input": {
        "keras_cls": tf.keras.layers.InputLayer,
        "description": "输入层:指定单个样本的形状。",
        "params": [
            {
                "name": "shape",
                "label": "输入形状",
                "type": "shape",
                "default": "4",
                "help": "用英文逗号分隔,例如 4 或 28,28,1",
            }
        ],
    },
    "Dense": {
        "keras_cls": tf.keras.layers.Dense,
        "description": "全连接层(Dense)。",
        "params": [
            {"name": "units", "label": "神经元数", "type": "int", "default": 32, "min": 1},
            {
                "name": "activation",
                "label": "激活函数",
                "type": "select",
                "default": "relu",
                "options": ACTIVATIONS,
            },
            {
                "name": "use_bias",
                "label": "是否使用偏置",
                "type": "bool",
                "default": True,
                "advanced": True,
            },
            {
                "name": "kernel_initializer",
                "label": "权重初始化",
                "type": "select",
                "default": "glorot_uniform",
                "options": INITIALIZERS,
                "advanced": True,
            },
        ],
    },
    "Dropout": {
        "keras_cls": tf.keras.layers.Dropout,
        "description": "Dropout:训练时按比例随机丢弃神经元。",
        "params": [
            {
                "name": "rate",
                "label": "丢弃比例",
                "type": "float",
                "default": 0.2,
                "min": 0.0,
                "max": 0.95,
                "step": 0.05,
            }
        ],
    },
    "BatchNormalization": {
        "keras_cls": tf.keras.layers.BatchNormalization,
        "description": "批归一化(BatchNormalization)。",
        "params": [],
    },
    "Flatten": {
        "keras_cls": tf.keras.layers.Flatten,
        "description": "展平:将多维输入拉平成一维向量。",
        "params": [],
    },
    "Activation": {
        "keras_cls": tf.keras.layers.Activation,
        "description": "独立的激活函数层。",
        "params": [
            {
                "name": "activation",
                "label": "激活函数",
                "type": "select",
                "default": "relu",
                "options": ACTIVATIONS,
            }
        ],
    },
    "Conv2D": {
        "keras_cls": tf.keras.layers.Conv2D,
        "description": "二维卷积层(输入需要是 3 维,如图像)。",
        "params": [
            {"name": "filters", "label": "卷积核数量", "type": "int", "default": 16, "min": 1},
            {
                "name": "kernel_size",
                "label": "卷积核大小",
                "type": "shape",
                "default": "3,3",
            },
            {
                "name": "activation",
                "label": "激活函数",
                "type": "select",
                "default": "relu",
                "options": ACTIVATIONS,
            },
            {
                "name": "padding",
                "label": "填充方式",
                "type": "select",
                "default": "valid",
                "options": ["valid", "same"],
                "advanced": True,
            },
        ],
    },
    "MaxPooling2D": {
        "keras_cls": tf.keras.layers.MaxPooling2D,
        "description": "二维最大池化(MaxPooling2D)。",
        "params": [
            {
                "name": "pool_size",
                "label": "池化窗口",
                "type": "shape",
                "default": "2,2",
            }
        ],
    },
}
156
+
157
+
158
+ def _parse_shape(value: Any) -> Any:
159
+ """把诸如 ``"28,28,1"`` 的字符串转成整数元组。"""
160
+ if value is None or value == "":
161
+ return None
162
+ if isinstance(value, (list, tuple)):
163
+ return tuple(int(v) for v in value)
164
+ if isinstance(value, int):
165
+ return (value,)
166
+ parts = [p.strip() for p in str(value).split(",") if p.strip()]
167
+ if len(parts) == 1:
168
+ return (int(parts[0]),)
169
+ return tuple(int(p) for p in parts)
170
+
171
+
172
+ def _coerce(param_spec: Dict[str, Any], value: Any) -> Any:
173
+ ptype = param_spec.get("type")
174
+ if value is None:
175
+ return param_spec.get("default")
176
+ if ptype == "int":
177
+ return int(value)
178
+ if ptype == "float":
179
+ return float(value)
180
+ if ptype == "bool":
181
+ if isinstance(value, bool):
182
+ return value
183
+ return str(value).lower() in {"1", "true", "yes", "on"}
184
+ if ptype == "shape":
185
+ return _parse_shape(value)
186
+ return value
187
+
188
+
189
def build_keras_layer(layer_type: str, params: Dict[str, Any]) -> tf.keras.layers.Layer:
    """Instantiate a Keras layer from ``layer_type`` and raw ``params``.

    Only parameter names declared in the registry entry are read from
    ``params``; missing ones fall back to the registry defaults. Raises
    ``ValueError`` for an unregistered ``layer_type``.
    """
    if layer_type not in LAYER_REGISTRY:
        raise ValueError(f"未知的层类型:{layer_type!r}")

    spec = LAYER_REGISTRY[layer_type]
    coerced: Dict[str, Any] = {}
    for p in spec["params"]:
        name = p["name"]
        coerced[name] = _coerce(p, params.get(name, p.get("default")))

    if layer_type == "Input":
        # InputLayer takes its shape via a dedicated keyword, not generic
        # registry kwargs. NOTE(review): ``input_shape=`` is the legacy
        # spelling; Keras 3 prefers ``shape=`` — confirm against the pinned
        # TensorFlow version before changing.
        return tf.keras.layers.InputLayer(input_shape=coerced["shape"])

    return spec["keras_cls"](**coerced)
204
+
205
+
206
def layer_catalog() -> List[Dict[str, Any]]:
    """Return a JSON-serialisable catalog of supported layers for the web UI."""
    return [
        {
            "type": name,
            "description": spec["description"],
            "params": spec["params"],
        }
        for name, spec in LAYER_REGISTRY.items()
    ]
letsann/model.py ADDED
@@ -0,0 +1,132 @@
1
+ """High level wrapper around ``tf.keras.Sequential``.
2
+
3
+ The :class:`Model` class lets users describe a network with a list of simple
4
+ Python dicts and then train / evaluate / predict with the usual methods.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import os
11
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
12
+
13
+ import numpy as np
14
+ import tensorflow as tf
15
+
16
+ from .layers import build_keras_layer
17
+
18
+ LayerSpec = Dict[str, Any] # {"type": "Dense", "params": {...}}
19
+
20
+
21
class Model:
    """A thin, beginner-friendly wrapper over ``tf.keras.Sequential``.

    Example::

        model = Model([
            {"type": "Input", "params": {"shape": "4"}},
            {"type": "Dense", "params": {"units": 16, "activation": "relu"}},
            {"type": "Dense", "params": {"units": 3, "activation": "softmax"}},
        ])
        model.compile(loss="sparse_categorical_crossentropy",
                      optimizer="adam", metrics=["accuracy"])
        model.fit(X, y, epochs=20, batch_size=16)
    """

    def __init__(self, layers: Sequence[LayerSpec]):
        # Copy each spec dict so later mutation by the caller cannot change us.
        self.layer_specs: List[LayerSpec] = [dict(s) for s in layers]
        self.keras_model: tf.keras.Sequential = self._build()

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------
    def _build(self) -> tf.keras.Sequential:
        """Translate ``self.layer_specs`` into a ``tf.keras.Sequential``.

        Raises ``ValueError`` when the spec list is empty or an entry has
        no ``"type"`` key.
        """
        if not self.layer_specs:
            raise ValueError("At least one layer is required.")
        model = tf.keras.Sequential()
        for idx, spec in enumerate(self.layer_specs):
            ltype = spec.get("type")
            # ``params`` may be absent or explicitly None; both mean "defaults".
            params = spec.get("params", {}) or {}
            if ltype is None:
                raise ValueError(f"Layer {idx} has no 'type'.")
            model.add(build_keras_layer(ltype, params))
        return model

    # ------------------------------------------------------------------
    # Training / inference API — thin pass-through with helpful defaults
    # ------------------------------------------------------------------
    def compile(
        self,
        optimizer: Union[str, tf.keras.optimizers.Optimizer] = "adam",
        loss: str = "mse",
        metrics: Optional[Iterable[str]] = None,
        learning_rate: Optional[float] = None,
    ) -> "Model":
        """Compile the underlying Keras model; returns ``self`` for chaining.

        ``learning_rate`` is honoured only when ``optimizer`` is given by
        name; an optimizer *instance* is passed through unchanged.
        """
        if learning_rate is not None and isinstance(optimizer, str):
            optimizer = tf.keras.optimizers.get(
                {"class_name": optimizer, "config": {"learning_rate": learning_rate}}
            )
        self.keras_model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=list(metrics) if metrics else None,
        )
        return self

    def fit(self, X, y=None, **kwargs):
        """Pass-through to ``tf.keras.Model.fit``; returns the Keras History."""
        return self.keras_model.fit(X, y, **kwargs)

    def evaluate(self, X, y=None, **kwargs):
        """Pass-through to ``tf.keras.Model.evaluate``."""
        return self.keras_model.evaluate(X, y, **kwargs)

    def predict(self, X, **kwargs) -> np.ndarray:
        """Pass-through to ``tf.keras.Model.predict``."""
        return self.keras_model.predict(X, **kwargs)

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------
    def summary(self) -> str:
        """Return the Keras summary as a string instead of printing it."""
        lines: List[str] = []
        self.keras_model.summary(print_fn=lines.append)
        return "\n".join(lines)

    def describe(self) -> List[Dict[str, Any]]:
        """Return a JSON-friendly description of the built Keras layers."""
        info: List[Dict[str, Any]] = []
        for layer in self.keras_model.layers:
            try:
                out_shape = layer.output_shape
            except Exception:
                # ``output_shape`` can raise for unbuilt layers (and is absent
                # in some Keras versions); report None rather than failing.
                out_shape = None
            info.append(
                {
                    "name": layer.name,
                    "class": layer.__class__.__name__,
                    "output_shape": out_shape,
                    "params": int(layer.count_params()),
                }
            )
        return info

    @property
    def total_params(self) -> int:
        """Total trainable + non-trainable parameter count."""
        return int(self.keras_model.count_params())

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def save(self, path: str) -> None:
        """Save the underlying Keras model plus the LetsANN spec.

        The spec lands next to the model: inside the directory for
        SavedModel-style paths, or as ``<path>.letsann.json`` for
        single-file formats.
        """
        self.keras_model.save(path)
        spec_path = os.path.join(path, "letsann_spec.json") if os.path.isdir(path) else path + ".letsann.json"
        with open(spec_path, "w", encoding="utf-8") as f:
            json.dump(self.layer_specs, f, ensure_ascii=False, indent=2)

    @classmethod
    def from_spec(cls, spec: Sequence[LayerSpec]) -> "Model":
        """Alternate constructor mirroring :func:`build_model`."""
        return cls(spec)
128
+
129
+
130
def build_model(spec: Sequence[LayerSpec]) -> Model:
    """Convenience factory: build a :class:`Model` straight from a spec list."""
    return Model(spec)
letsann/trainer.py ADDED
@@ -0,0 +1,215 @@
1
+ """Training orchestration used by both the Python API and the web UI.
2
+
3
+ The :class:`TrainingJob` runs a Keras fit loop in a background thread and
4
+ streams live metrics into an in-memory queue so the web UI can poll or
5
+ stream them via Server-Sent Events.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import threading
11
+ import time
12
+ import uuid
13
+ from collections import deque
14
+ from typing import Any, Deque, Dict, List, Optional
15
+
16
+ import numpy as np
17
+ import tensorflow as tf
18
+
19
+ from .data import Dataset
20
+ from .model import Model
21
+
22
+
23
class _StreamCallback(tf.keras.callbacks.Callback):
    """Keras callback that pushes every epoch/train event into a job.

    The callback only *reports*; job lifecycle state is owned by
    ``TrainingJob._run`` so ``status`` flips to ``"finished"`` only after
    the final evaluation has completed.
    """

    def __init__(self, job: "TrainingJob") -> None:
        super().__init__()
        self.job = job

    def on_train_begin(self, logs=None):
        self.job._push({"event": "train_begin", "time": time.time()})

    def on_epoch_begin(self, epoch, logs=None):
        self.job._push({"event": "epoch_begin", "epoch": int(epoch)})

    def on_epoch_end(self, epoch, logs=None):
        # Record the epoch both in the persistent history and the live queue.
        payload = {"event": "epoch_end", "epoch": int(epoch)}
        payload.update(_clean_logs(logs))
        self.job._record_epoch(payload)
        self.job._push(payload)

    def on_train_end(self, logs=None):
        # Bug fix: do NOT set ``job.status = "finished"`` here. Doing so let
        # consumers (e.g. ``sample_predictions`` / snapshot pollers) observe
        # a "finished" job before ``_run`` had computed ``final_eval``;
        # ``_run`` already sets the terminal status after evaluation.
        self.job._push({"event": "train_end", "time": time.time()})
45
+
46
+
47
+ def _clean_logs(logs: Optional[Dict[str, Any]]) -> Dict[str, float]:
48
+ if not logs:
49
+ return {}
50
+ out: Dict[str, float] = {}
51
+ for k, v in logs.items():
52
+ try:
53
+ out[k] = float(v)
54
+ except (TypeError, ValueError):
55
+ continue
56
+ return out
57
+
58
+
59
class TrainingJob:
    """In-process training job with live event streaming.

    Runs ``Model.fit`` on a background daemon thread (see :meth:`start`)
    and buffers per-epoch events in a bounded deque that the web layer
    drains via :meth:`drain_events`.
    """

    def __init__(
        self,
        model: Model,
        dataset: Dataset,
        *,
        epochs: int = 10,
        batch_size: int = 32,
        optimizer: str = "adam",
        loss: Optional[str] = None,
        learning_rate: float = 1e-3,
        metrics: Optional[List[str]] = None,
    ) -> None:
        # Short random id used in thread names and API routes.
        self.id = uuid.uuid4().hex[:12]
        self.model = model
        self.dataset = dataset
        self.epochs = int(epochs)
        self.batch_size = int(batch_size)
        self.optimizer = optimizer
        # Loss/metrics fall back to task-appropriate defaults when omitted.
        self.loss = loss or _default_loss(dataset)
        self.learning_rate = float(learning_rate)
        self.metrics = metrics or _default_metrics(dataset)

        self.status: str = "pending"  # pending | running | finished | error
        self.error: Optional[str] = None
        self.history: List[Dict[str, Any]] = []
        self.final_eval: Optional[Dict[str, float]] = None

        # Bounded buffer: oldest events are silently dropped past 1024.
        self._events: Deque[Dict[str, Any]] = deque(maxlen=1024)
        self._lock = threading.Lock()
        self._thread: Optional[threading.Thread] = None

    # ------------------------------------------------------------------
    # Public control
    # ------------------------------------------------------------------
    def start(self) -> None:
        """Launch training on a daemon thread; no-op if already running."""
        if self._thread and self._thread.is_alive():
            return
        self.status = "running"
        self._thread = threading.Thread(target=self._run, name=f"letsann-train-{self.id}", daemon=True)
        self._thread.start()

    # ------------------------------------------------------------------
    # Event helpers (used by the callback and API)
    # ------------------------------------------------------------------
    def _push(self, event: Dict[str, Any]) -> None:
        """Append an event to the buffer (thread-safe)."""
        with self._lock:
            self._events.append(event)

    def _record_epoch(self, event: Dict[str, Any]) -> None:
        """Keep the epoch metrics (minus the event tag) in ``history``.

        NOTE(review): appends without taking ``self._lock`` while
        ``snapshot`` copies ``history`` from another thread — relies on
        list.append atomicity; confirm this is acceptable.
        """
        self.history.append({k: v for k, v in event.items() if k != "event"})

    def drain_events(self) -> List[Dict[str, Any]]:
        """Atomically take and clear all buffered events."""
        with self._lock:
            events = list(self._events)
            self._events.clear()
        return events

    def snapshot(self) -> Dict[str, Any]:
        """Return a JSON-serialisable view of the job's full state."""
        return {
            "id": self.id,
            "status": self.status,
            "error": self.error,
            "epochs": self.epochs,
            "batch_size": self.batch_size,
            "optimizer": self.optimizer,
            "loss": self.loss,
            "metrics": self.metrics,
            "history": list(self.history),
            "final_eval": self.final_eval,
            "dataset": self.dataset.summary(),
            "model": self.model.describe(),
            "total_params": self.model.total_params,
        }

    # ------------------------------------------------------------------
    # Internal worker
    # ------------------------------------------------------------------
    def _run(self) -> None:
        """Thread target: compile, fit, evaluate; owns the terminal status."""
        try:
            self.model.compile(
                optimizer=self.optimizer,
                loss=self.loss,
                metrics=self.metrics,
                learning_rate=self.learning_rate,
            )

            self.model.fit(
                self.dataset.X_train,
                self.dataset.y_train,
                validation_data=(self.dataset.X_val, self.dataset.y_val),
                epochs=self.epochs,
                batch_size=self.batch_size,
                verbose=0,
                callbacks=[_StreamCallback(self)],
            )

            results = self.model.evaluate(
                self.dataset.X_val, self.dataset.y_val, verbose=0, return_dict=True
            )
            self.final_eval = {k: float(v) for k, v in results.items()}
            self._push({"event": "final_eval", **self.final_eval})
            self.status = "finished"
        except Exception as exc:  # pragma: no cover
            # Any failure (bad spec, shape mismatch, ...) flips the job to
            # "error" and surfaces the message through the event stream.
            self.status = "error"
            self.error = str(exc)
            self._push({"event": "error", "message": str(exc)})

    # ------------------------------------------------------------------
    # Prediction on validation sample
    # ------------------------------------------------------------------
    def sample_predictions(self, n: int = 10) -> Dict[str, Any]:
        """Predict on the first ``n`` validation rows once training is done.

        Returns ``{"ready": False}`` until the job status is ``finished``.
        """
        if self.status != "finished":
            return {"ready": False}

        n = int(min(n, self.dataset.X_val.shape[0]))
        X = self.dataset.X_val[:n]
        y_true = self.dataset.y_val[:n]
        y_pred = self.model.predict(X, verbose=0)

        if self.dataset.task_type == "classification":
            if y_pred.ndim == 2 and y_pred.shape[1] > 1:
                # Softmax-style output: argmax label, max prob as confidence.
                preds = y_pred.argmax(axis=1).tolist()
                confidences = y_pred.max(axis=1).tolist()
            else:
                # Single sigmoid unit: threshold at 0.5.
                preds = (y_pred.ravel() > 0.5).astype(int).tolist()
                confidences = y_pred.ravel().tolist()
            return {
                "ready": True,
                "task": "classification",
                "y_true": np.asarray(y_true).ravel().tolist(),
                "y_pred": preds,
                "confidence": confidences,
            }

        return {
            "ready": True,
            "task": "regression",
            "y_true": np.asarray(y_true).ravel().tolist(),
            "y_pred": np.asarray(y_pred).ravel().tolist(),
        }
202
+
203
+
204
+ def _default_loss(dataset: Dataset) -> str:
205
+ if dataset.task_type == "classification":
206
+ if dataset.n_classes and dataset.n_classes > 2:
207
+ return "sparse_categorical_crossentropy"
208
+ return "binary_crossentropy"
209
+ return "mse"
210
+
211
+
212
+ def _default_metrics(dataset: Dataset) -> List[str]:
213
+ if dataset.task_type == "classification":
214
+ return ["accuracy"]
215
+ return ["mae"]
@@ -0,0 +1,118 @@
1
+ Metadata-Version: 2.4
2
+ Name: LetsANN
3
+ Version: 0.1.0
4
+ Summary: 基于 TensorFlow 的零基础 ANN 库:用简单的 Python 字典就能描述网络。
5
+ Author: LetsANN Contributors
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/letsann/letsann
8
+ Project-URL: Documentation, https://github.com/letsann/letsann#readme
9
+ Project-URL: Issues, https://github.com/letsann/letsann/issues
10
+ Keywords: tensorflow,keras,neural network,ann,deep learning,education
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Education
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: License :: OSI Approved :: MIT License
15
+ Classifier: Operating System :: OS Independent
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.8
18
+ Classifier: Programming Language :: Python :: 3.9
19
+ Classifier: Programming Language :: Python :: 3.10
20
+ Classifier: Programming Language :: Python :: 3.11
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Requires-Python: >=3.8
23
+ Description-Content-Type: text/markdown
24
+ License-File: LICENSE
25
+ Requires-Dist: tensorflow>=2.8
26
+ Requires-Dist: numpy>=1.19
27
+ Requires-Dist: pandas>=1.2
28
+ Requires-Dist: scikit-learn>=1.0
29
+ Provides-Extra: dev
30
+ Requires-Dist: pytest>=7.0; extra == "dev"
31
+ Requires-Dist: build; extra == "dev"
32
+ Requires-Dist: twine; extra == "dev"
33
+ Dynamic: license-file
34
+
35
+ # LetsANN
36
+
37
+ **LetsANN** 是一个基于 TensorFlow / Keras 的零基础 ANN 库。
38
+ 用最简单的 Python 列表描述网络,像搭积木一样训练模型。
39
+
40
+ > 需要可视化拖拽界面?请看配套的独立项目 [`letsann-web`](https://github.com/letsann/letsann-web)。
41
+
42
+ ## 安装
43
+
44
+ ```bash
45
+ pip install LetsANN
46
+ ```
47
+
48
+ 要求 Python **3.8 及以上**。
49
+
50
+ ## 最小示例
51
+
52
+ ```python
53
+ from letsann import Model, load_dataset
54
+
55
+ # 用 DataFrame 或 CSV 路径都行,最后一列默认为标签
56
+ ds = load_dataset("iris.csv", target="species")
57
+
58
+ # 用列表描述网络
59
+ model = Model([
60
+ {"type": "Input", "params": {"shape": "4"}},
61
+ {"type": "Dense", "params": {"units": 16, "activation": "relu"}},
62
+ {"type": "Dense", "params": {"units": 3, "activation": "softmax"}},
63
+ ])
64
+
65
+ # 和 Keras 一样编译、训练
66
+ model.compile(optimizer="adam",
67
+ loss="sparse_categorical_crossentropy",
68
+ metrics=["accuracy"])
69
+ model.fit(ds.X_train, ds.y_train,
70
+ validation_data=(ds.X_val, ds.y_val),
71
+ epochs=20, batch_size=16)
72
+
73
+ print(model.summary())
74
+ ```
75
+
76
+ 更多示例见 `examples/quickstart.py`。
77
+
78
+ ## 支持的层
79
+
80
+ `Input`、`Dense`、`Dropout`、`BatchNormalization`、`Flatten`、`Activation`、
81
+ `Conv2D`、`MaxPooling2D`。全部在 `letsann/layers.py` 中注册,想扩展就
82
+ 往 `LAYER_REGISTRY` 里加一条即可。
83
+
84
+ ## 数据集格式
85
+
86
+ - **CSV / TSV**:默认最后一列为标签;用 `target="col"` 指定其它列。
87
+ - **NPZ**:需要包含 `X` 和 `y` 两个数组。
88
+
89
+ ## 发布到 PyPI
90
+
91
+ ```bash
92
+ # 1. 安装打包工具
93
+ pip install build twine
94
+
95
+ # 2. 打包(在本目录运行)
96
+ python -m build # 会生成 dist/LetsANN-0.1.0.tar.gz 和 .whl
97
+
98
+ # 3. 先上传到 TestPyPI 验证
99
+ twine upload --repository testpypi dist/*
100
+
101
+ # 4. 确认没问题后,正式上传 PyPI
102
+ twine upload dist/*
103
+ ```
104
+
105
+ 上传需要在 <https://pypi.org> 先创建账号并生成 API Token,放进
106
+ `~/.pypirc` 或设置环境变量 `TWINE_USERNAME=__token__`、
107
+ `TWINE_PASSWORD=<你的 token>`。
108
+
109
+ ## 开发
110
+
111
+ ```bash
112
+ pip install -e ".[dev]"
113
+ pytest
114
+ ```
115
+
116
+ ## License
117
+
118
+ MIT
@@ -0,0 +1,13 @@
1
+ letsann/__init__.py,sha256=wmJiZCoLBsLuGHXGoHkae0w9d_eVkFkHTAMOhi2x_Vs,1538
2
+ letsann/_version.py,sha256=rBxpUCL_QmUlcUTrpvNt4WqmVZZMX4xqlGYZMVWEpZU,216
3
+ letsann/cli.py,sha256=5glGNHWYlf3wvvgKv0x3ecdz0vcwksZZYlVDHjwqqSk,903
4
+ letsann/data.py,sha256=xw4DEL3SfskhBY2LWXhHR2vNeqJm-q1mmAZf2bVhgy8,6424
5
+ letsann/layers.py,sha256=kQtbbfkyk0Wb2FsapVK18T91yoW6ARRxHdgtKbJnZ34,7088
6
+ letsann/model.py,sha256=joDLd26p7KIyjkpO7M4sx26RTgPRI2flphtphxWeLJU,5020
7
+ letsann/trainer.py,sha256=1FWySQxwRgWwLkdxTrw-Y_QdpsEawlvHJ2gZcBXL1qM,7634
8
+ letsann-0.1.0.dist-info/licenses/LICENSE,sha256=w3SbG5R22PYPf3Z3O6bgCXchLfh_TEE8SA64YsHgE8k,1098
9
+ letsann-0.1.0.dist-info/METADATA,sha256=icZlk_IitVkLQ6Gx14Y2aI8j9tuLiQkqik1gIy1zXL8,3584
10
+ letsann-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
11
+ letsann-0.1.0.dist-info/entry_points.txt,sha256=u3P6q5bQ6oJ307hSLwLZcTL2OSBf4Ir7rBItkUVe-8E,45
12
+ letsann-0.1.0.dist-info/top_level.txt,sha256=Tf524tX4l7BvnwZeemrDEUKpY1rNqeVz8GnEe2wLB2E,8
13
+ letsann-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ letsann = letsann.cli:main
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 LetsANN Contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1 @@
1
+ letsann