umbrellm 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- umbrellm-0.1.0/PKG-INFO +90 -0
- umbrellm-0.1.0/README.md +60 -0
- umbrellm-0.1.0/pyproject.toml +44 -0
- umbrellm-0.1.0/setup.cfg +4 -0
- umbrellm-0.1.0/setup.py +32 -0
- umbrellm-0.1.0/umbrellm/__init__.py +19 -0
- umbrellm-0.1.0/umbrellm/__main__.py +75 -0
- umbrellm-0.1.0/umbrellm/_generate.py +30 -0
- umbrellm-0.1.0/umbrellm/_runtime.py +179 -0
- umbrellm-0.1.0/umbrellm/api.py +105 -0
- umbrellm-0.1.0/umbrellm/config.py +66 -0
- umbrellm-0.1.0/umbrellm/device.py +30 -0
- umbrellm-0.1.0/umbrellm/interaction.py +60 -0
- umbrellm-0.1.0/umbrellm/model.py +268 -0
- umbrellm-0.1.0/umbrellm/models.py +83 -0
- umbrellm-0.1.0/umbrellm/tokenizer.py +273 -0
- umbrellm-0.1.0/umbrellm/training.py +395 -0
- umbrellm-0.1.0/umbrellm/utllm.py +159 -0
- umbrellm-0.1.0/umbrellm.egg-info/PKG-INFO +90 -0
- umbrellm-0.1.0/umbrellm.egg-info/SOURCES.txt +22 -0
- umbrellm-0.1.0/umbrellm.egg-info/dependency_links.txt +1 -0
- umbrellm-0.1.0/umbrellm.egg-info/entry_points.txt +2 -0
- umbrellm-0.1.0/umbrellm.egg-info/requires.txt +3 -0
- umbrellm-0.1.0/umbrellm.egg-info/top_level.txt +1 -0
umbrellm-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: umbrellm
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A lightweight, trainable language model SDK
|
|
5
|
+
Home-page: https://github.com/The-Umbrellm-Project/UmbreLLM
|
|
6
|
+
Author: Fries
|
|
7
|
+
Author-email: Your Name <your.email@example.com>
|
|
8
|
+
License: MIT
|
|
9
|
+
Project-URL: Homepage, https://github.com/YOUR_USERNAME/umbrellm
|
|
10
|
+
Project-URL: Repository, https://github.com/YOUR_USERNAME/umbrellm
|
|
11
|
+
Project-URL: Issues, https://github.com/YOUR_USERNAME/umbrellm/issues
|
|
12
|
+
Classifier: Development Status :: 3 - Alpha
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
22
|
+
Requires-Python: >=3.9
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
Requires-Dist: torch>=2.0.0
|
|
25
|
+
Requires-Dist: fastapi>=0.100.0
|
|
26
|
+
Requires-Dist: uvicorn>=0.22.0
|
|
27
|
+
Dynamic: author
|
|
28
|
+
Dynamic: home-page
|
|
29
|
+
Dynamic: requires-python
|
|
30
|
+
|
|
31
|
+
# Umbrellm
|
|
32
|
+
|
|
33
|
+
A small language model framework that lets you train and run your own decoder-only transformer locally.
|
|
34
|
+
|
|
35
|
+
## Overview
|
|
36
|
+
|
|
37
|
+
Umbrellm is a language model implementation written in Python and PyTorch. The default model configuration uses 12 transformer layers, 8 attention heads, and a hidden size of 512 dimensions, resulting in roughly 20 million parameters.
|
|
38
|
+
|
|
39
|
+
The goal of the project is to provide an implementation that is easy to understand and modify. It includes the components needed to train a model from scratch, prepare datasets, run inference, and expose the model through a simple API.
|
|
40
|
+
|
|
41
|
+
The architecture uses:
|
|
42
|
+
|
|
43
|
+
* Rotary Position Embeddings (RoPE)
|
|
44
|
+
* SwiGLU feed-forward activations
|
|
45
|
+
* RMSNorm
|
|
46
|
+
* Causal multi-head self-attention
|
|
47
|
+
|
|
48
|
+
## Features
|
|
49
|
+
|
|
50
|
+
* Train models from scratch using your own datasets
|
|
51
|
+
* Save and resume training from checkpoints
|
|
52
|
+
* Learning rate scheduling and gradient clipping
|
|
53
|
+
* Optional mixed-precision training
|
|
54
|
+
* BPE tokenizer training on custom corpora
|
|
55
|
+
* Dataset compilation using the `.ullm` format
|
|
56
|
+
* Text generation with temperature, top-k, top-p, and repetition penalty controls
|
|
57
|
+
* Token-by-token streaming generation
|
|
58
|
+
* FastAPI-based REST API server
|
|
59
|
+
|
|
60
|
+
## Quick start
|
|
61
|
+
|
|
62
|
+
```bash
|
|
63
|
+
pip install torch fastapi uvicorn
|
|
64
|
+
|
|
65
|
+
python scripts/compile_dataset.py
|
|
66
|
+
python scripts/train.py --epochs 3
|
|
67
|
+
python scripts/inference.py --interactive
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
## Python SDK
|
|
71
|
+
|
|
72
|
+
```python
|
|
73
|
+
import umbrellm
|
|
74
|
+
|
|
75
|
+
umbrellm.utllm.compile()
|
|
76
|
+
umbrellm.training.train({"epochs": 3})
|
|
77
|
+
|
|
78
|
+
response = umbrellm.generate("Explain black holes simply.")
|
|
79
|
+
print(response)
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
## Requirements
|
|
83
|
+
|
|
84
|
+
* Python 3.9 or newer
|
|
85
|
+
* PyTorch 2.0 or newer
|
|
86
|
+
* FastAPI and uvicorn (only required for the API server)
|
|
87
|
+
|
|
88
|
+
## License
|
|
89
|
+
|
|
90
|
+
Released under the MIT License.
|
umbrellm-0.1.0/README.md
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Umbrellm
|
|
2
|
+
|
|
3
|
+
A small language model framework that lets you train and run your own decoder-only transformer locally.
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
Umbrellm is a language model implementation written in Python and PyTorch. The default model configuration uses 12 transformer layers, 8 attention heads, and a hidden size of 512 dimensions, resulting in roughly 20 million parameters.
|
|
8
|
+
|
|
9
|
+
The goal of the project is to provide an implementation that is easy to understand and modify. It includes the components needed to train a model from scratch, prepare datasets, run inference, and expose the model through a simple API.
|
|
10
|
+
|
|
11
|
+
The architecture uses:
|
|
12
|
+
|
|
13
|
+
* Rotary Position Embeddings (RoPE)
|
|
14
|
+
* SwiGLU feed-forward activations
|
|
15
|
+
* RMSNorm
|
|
16
|
+
* Causal multi-head self-attention
|
|
17
|
+
|
|
18
|
+
## Features
|
|
19
|
+
|
|
20
|
+
* Train models from scratch using your own datasets
|
|
21
|
+
* Save and resume training from checkpoints
|
|
22
|
+
* Learning rate scheduling and gradient clipping
|
|
23
|
+
* Optional mixed-precision training
|
|
24
|
+
* BPE tokenizer training on custom corpora
|
|
25
|
+
* Dataset compilation using the `.ullm` format
|
|
26
|
+
* Text generation with temperature, top-k, top-p, and repetition penalty controls
|
|
27
|
+
* Token-by-token streaming generation
|
|
28
|
+
* FastAPI-based REST API server
|
|
29
|
+
|
|
30
|
+
## Quick start
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
pip install umbrellm  # also installs torch, fastapi and uvicorn
|
|
34
|
+
|
|
35
|
+
umbrellm compile  # from a source checkout: python scripts/compile_dataset.py
|
|
36
|
+
umbrellm train --epochs 3  # from a source checkout: python scripts/train.py --epochs 3
|
|
37
|
+
umbrellm chat  # from a source checkout: python scripts/inference.py --interactive
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
## Python SDK
|
|
41
|
+
|
|
42
|
+
```python
|
|
43
|
+
import umbrellm
|
|
44
|
+
|
|
45
|
+
umbrellm.utllm.compile()
|
|
46
|
+
umbrellm.training.train({"epochs": 3})
|
|
47
|
+
|
|
48
|
+
response = umbrellm.generate("Explain black holes simply.")
|
|
49
|
+
print(response)
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
## Requirements
|
|
53
|
+
|
|
54
|
+
* Python 3.9 or newer
|
|
55
|
+
* PyTorch 2.0 or newer
|
|
56
|
+
* FastAPI and uvicorn (only required for the API server)
|
|
57
|
+
|
|
58
|
+
## License
|
|
59
|
+
|
|
60
|
+
Released under the MIT License.
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=61.0", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "umbrellm"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "A lightweight, trainable language model SDK"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.9"
|
|
11
|
+
license = {text = "MIT"}
|
|
12
|
+
authors = [
|
|
13
|
+
{name = "Fries", email = "freshfrenchfries@proton.me"}
|
|
14
|
+
]
|
|
15
|
+
classifiers = [
|
|
16
|
+
"Development Status :: 3 - Alpha",
|
|
17
|
+
"Intended Audience :: Developers",
|
|
18
|
+
"Intended Audience :: Science/Research",
|
|
19
|
+
"License :: OSI Approved :: MIT License",
|
|
20
|
+
"Programming Language :: Python :: 3",
|
|
21
|
+
"Programming Language :: Python :: 3.9",
|
|
22
|
+
"Programming Language :: Python :: 3.10",
|
|
23
|
+
"Programming Language :: Python :: 3.11",
|
|
24
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
25
|
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
dependencies = [
|
|
29
|
+
"torch>=2.0.0",
|
|
30
|
+
"fastapi>=0.100.0",
|
|
31
|
+
"uvicorn>=0.22.0",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[project.urls]
|
|
35
|
+
Homepage = "https://github.com/The-Umbrellm-Project/UmbreLLM"
|
|
36
|
+
Repository = "https://github.com/The-Umbrellm-Project/UmbreLLM"
|
|
37
|
+
Issues = "https://github.com/The-Umbrellm-Project/UmbreLLM/issues"
|
|
38
|
+
|
|
39
|
+
[project.scripts]
|
|
40
|
+
umbrellm = "umbrellm.__main__:main"
|
|
41
|
+
|
|
42
|
+
[tool.setuptools.packages.find]
|
|
43
|
+
where = ["."]
|
|
44
|
+
include = ["umbrellm*"]
|
umbrellm-0.1.0/setup.cfg
ADDED
umbrellm-0.1.0/setup.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from pathlib import Path

from setuptools import find_packages, setup

# Read the long description relative to this file rather than the current
# working directory, so the build works no matter where it is invoked from.
_HERE = Path(__file__).resolve().parent
long_description = (_HERE / "README.md").read_text(encoding="utf-8")

setup(
    name="umbrellm",
    version="0.1.0",
    author="Fries",
    author_email="freshfrenchfries@proton.me",
    description="A lightweight, trainable language model SDK",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/The-Umbrellm-Project/UmbreLLM",
    packages=find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.9",
    install_requires=[
        "torch>=2.0.0",
        "fastapi>=0.100.0",
        "uvicorn>=0.22.0",
    ],
    entry_points={
        # Installs the ``umbrellm`` command, dispatching to umbrellm/__main__.py.
        "console_scripts": [
            "umbrellm=umbrellm.__main__:main",
        ],
    },
)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Umbrellm – a lightweight, trainable language model SDK.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from umbrellm import interaction, tokenizer, models, training, utllm, config, device
|
|
6
|
+
from umbrellm._generate import generate, complete
|
|
7
|
+
|
|
8
|
+
__version__ = "0.1.0"
|
|
9
|
+
__all__ = [
|
|
10
|
+
"interaction",
|
|
11
|
+
"tokenizer",
|
|
12
|
+
"models",
|
|
13
|
+
"training",
|
|
14
|
+
"utllm",
|
|
15
|
+
"config",
|
|
16
|
+
"device",
|
|
17
|
+
"generate",
|
|
18
|
+
"complete",
|
|
19
|
+
]
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""
|
|
2
|
+
python -m umbrellm entry-point.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import sys
|
|
6
|
+
import argparse
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def main(argv=None):
    """Entry point for ``python -m umbrellm`` and the ``umbrellm`` console script.

    Args:
        argv: Optional list of argument strings, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``, so
            existing zero-argument callers behave exactly as before.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Sub-command modules are imported lazily, only for the command that runs.
    if args.cmd == "train":
        import umbrellm.training as tr

        tr.train({
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
        })
    elif args.cmd == "chat":
        _run_chat(args.model)
    elif args.cmd == "compile":
        import umbrellm.utllm as ut

        ut.compile()
    elif args.cmd == "stats":
        import json

        import umbrellm.utllm as ut

        print(json.dumps(ut.stats(), indent=2))
    elif args.cmd == "serve":
        from umbrellm.api import main as serve_main

        serve_main()
    else:
        # No sub-command given: show usage instead of failing.
        parser.print_help()


def _build_parser():
    """Create the argument parser with one sub-parser per CLI command."""
    parser = argparse.ArgumentParser(prog="umbrellm")
    sub = parser.add_subparsers(dest="cmd")

    p_train = sub.add_parser("train", help="Train the model")
    p_train.add_argument("--epochs", type=int, default=3)
    p_train.add_argument("--batch-size", type=int, default=4)
    p_train.add_argument("--lr", type=float, default=3e-4)

    p_chat = sub.add_parser("chat", help="Interactive chat")
    p_chat.add_argument("--model", default="umbrellm-20m")

    sub.add_parser("compile", help="Compile UTLLM dataset")
    sub.add_parser("stats", help="Dataset statistics")
    sub.add_parser("serve", help="Start API server")
    return parser


def _run_chat(model_name):
    """Run the interactive chat loop until the user exits or input ends."""
    import umbrellm.interaction as inter

    print(f"Umbrellm interactive chat (model: {model_name}). Type exit to quit.")
    history = []
    while True:
        try:
            user = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Finish the interrupted prompt line so the shell prompt that
            # follows starts on a fresh line.
            print()
            break
        if user.lower() in ("exit", "quit", "q"):
            break
        if not user:
            continue
        # The history handed to chat() holds only the earlier turns; the
        # current turn is appended after the reply arrives.
        res = inter.chat({"user": user, "model": model_name, "history": history})
        print(f"Umbrellm: {res['response']}")
        history.append({"user": user, "assistant": res["response"]})


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Top-level generate / complete helpers.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from umbrellm import interaction as _interaction
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def generate(prompt: str, **kwargs) -> str:
    """
    Send ``prompt`` as the user message and return the response text.

    Extra keyword arguments are forwarded unchanged inside the parameter
    dict passed to ``interaction.chat``; ``model`` falls back to
    ``"umbrellm-20m"`` when the caller does not supply it.

    Example:
        text = umbrellm.generate("Write a poem about stars.")
    """
    request = {"user": prompt}
    request.update(kwargs)
    request.setdefault("model", "umbrellm-20m")
    return _interaction.chat(request)["response"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def complete(prompt: str, **kwargs) -> str:
    """
    Ask the model to continue ``prompt`` and return the response text.

    Works like a chat request whose system prompt asks for a natural
    continuation. Callers may override ``model`` and ``system_prompt``
    through keyword arguments; any other keyword is forwarded unchanged
    to ``interaction.chat``.

    Example:
        text = umbrellm.complete("The meaning of life is")
    """
    request = {"user": prompt}
    request.update(kwargs)
    request.setdefault("model", "umbrellm-20m")
    request.setdefault(
        "system_prompt",
        "You are a helpful text completion engine. Continue the text naturally.",
    )
    return _interaction.chat(request)["response"]
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Internal inference runtime.
|
|
3
|
+
Handles model loading, prompt formatting, and generation.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import time
|
|
7
|
+
import os
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Iterator, Optional
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
_SYSTEM_DEFAULT = (
    "You are Umbrellm, a helpful AI assistant. "
    "Respond clearly, concisely, and helpfully."
)


class InferenceRuntime:
    """Lazily initialised inference runtime.

    Config, device, tokenizer and model are loaded on the first request and
    cached on the instance. ``chat`` returns a complete response dict;
    ``stream_chat`` yields decoded tokens one at a time.
    """

    # Marks "no checkpoint has been resolved by this runtime yet".
    _NOT_LOADED = object()

    def __init__(self):
        self._model = None
        # Checkpoint path the cached model came from; None means a freshly
        # built (random) model.
        self._model_source = self._NOT_LOADED
        self._cfg = None
        self._device = None
        self._tokenizer_loaded = False

    def _ensure_ready(self, model_name: str):
        """Load config, device, tokenizer and model if not already cached.

        The model is reloaded when ``model_name`` resolves to a different
        checkpoint than the cached one, so a response is never produced by a
        model other than the one requested.
        """
        from umbrellm.config import load as load_cfg
        from umbrellm.device import current as cur_device
        import umbrellm.tokenizer as tok_mod

        if self._cfg is None:
            self._cfg = load_cfg()

        if self._device is None:
            self._device = cur_device()

        if not self._tokenizer_loaded:
            try:
                tok_mod.load()
            except Exception as exc:
                # Loading stays best-effort (only one attempt is made), but
                # the failure is reported instead of being hidden.
                print(f"[runtime] Tokenizer load failed: {exc!r}")
            self._tokenizer_loaded = True

        source = self._resolve_checkpoint(model_name)
        stale = (
            self._model_source is not self._NOT_LOADED
            and source != self._model_source
        )
        if self._model is None or stale:
            self._model = None  # release the old weights before loading new ones
            self._model = self._load_from(source)
            self._model_source = source

    def _resolve_checkpoint(self, model_name: str) -> Optional[str]:
        """Return the first existing checkpoint path, or None if there is none.

        Search order: ``model_name`` itself (treated as a path), then
        ``latest.pt`` and ``best.pt`` inside the configured checkpoint dir.
        """
        ckpt_dir = self._cfg.get("checkpoint_dir", "checkpoints")
        candidates = (
            model_name,
            os.path.join(ckpt_dir, "latest.pt"),
            os.path.join(ckpt_dir, "best.pt"),
        )
        # NOTE(review): os.path.exists also accepts directories, so a folder
        # named like the model in the working directory is picked up here.
        for path in candidates:
            if os.path.exists(path):
                return path
        return None

    def _load_model(self, model_name: str):
        """Load the model for ``model_name`` onto the runtime device.

        Raises:
            RuntimeError: if PyTorch is not installed.
        """
        return self._load_from(self._resolve_checkpoint(model_name))

    def _load_from(self, checkpoint: Optional[str]):
        """Load ``checkpoint``, or build a fresh model when it is None.

        Raises:
            RuntimeError: if PyTorch is not installed.
        """
        # Only the torch import is guarded, so an unrelated ImportError is not
        # misreported as a missing PyTorch installation.
        try:
            import torch  # noqa: F401
        except ImportError as e:
            raise RuntimeError(
                "PyTorch is required. Install with: pip install torch"
            ) from e

        from umbrellm.model import build_model
        import umbrellm.models as mdl_mod

        if checkpoint is not None:
            model = mdl_mod.load(checkpoint)
        else:
            # No checkpoint found - build a fresh (random) model
            print("[runtime] No checkpoint found. Building fresh model for inference.")
            model = build_model(self._cfg)

        # Inference only: put loaded checkpoints in eval mode as well.
        # (Assumes models.load returns the same kind of module as build_model.)
        model.eval()
        return model.to(self._device)

    def _build_prompt(self, params: dict) -> str:
        """Format system prompt, history and the new user message as one string.

        ``system_prompt`` and ``history`` fall back to their defaults when they
        are missing or None.
        """
        system = params.get("system_prompt")
        if system is None:
            system = _SYSTEM_DEFAULT
        history = params.get("history") or []
        user = params.get("user", "")

        # NOTE(review): the system segment is followed directly by an
        # <|assistant|> tag - confirm this matches the training-time layout.
        parts = [f"<|system|>{system}<|assistant|>"]
        for turn in history:
            parts.append(f"<|user|>{turn.get('user', '')}")
            parts.append(f"<|assistant|>{turn.get('assistant', '')}")
        parts.append(f"<|user|>{user}")
        parts.append("<|assistant|>")
        return "".join(parts)

    def _setting(self, params: dict, key: str, default, cfg_key: Optional[str] = None):
        """Return a request value, else the config value, else ``default``.

        A request value of None counts as "not provided".
        """
        value = params.get(key)
        if value is None:
            value = self._cfg.get(cfg_key or key, default)
        return value

    @staticmethod
    def _eos_id():
        """Return the tokenizer's id for the EOS token, or None if absent."""
        import umbrellm.tokenizer as tok_mod
        from umbrellm.tokenizer import EOS_TOKEN

        return tok_mod._tokenizer.vocab.get(EOS_TOKEN)

    @staticmethod
    def _trim_at_stop_sequences(text: str, stop_sequences) -> str:
        """Cut ``text`` at the first occurrence of each stop sequence.

        Empty sequences are skipped: ``str.find("")`` is always 0, which would
        otherwise erase the whole response.
        """
        for seq in stop_sequences or ():
            if not seq:
                continue
            idx = text.find(seq)
            if idx != -1:
                text = text[:idx]
        return text

    def chat(self, params: dict) -> dict:
        """Generate a complete response for one chat request.

        Args:
            params: request dict. Keys read here: ``model``, ``max_tokens``,
                ``temperature``, ``top_k``, ``top_p``, ``repetition_penalty``,
                ``seed``, ``stop_sequences``; plus ``user``, ``history`` and
                ``system_prompt`` for the prompt.

        Returns:
            Dict with ``response``, ``tokens_generated``, ``generation_time``
            (seconds, rounded to 4 places) and ``model``.
        """
        model_name = params.get("model", "umbrellm-20m")
        self._ensure_ready(model_name)

        prompt = self._build_prompt(params)
        start = time.time()

        import torch
        import umbrellm.tokenizer as tok_mod

        input_ids = tok_mod.encode(prompt, add_special_tokens=False)
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self._device)

        eos_id = self._eos_id()
        stop_ids = [eos_id] if eos_id is not None else None

        with torch.no_grad():
            output = self._model.generate(
                input_tensor,
                max_new_tokens=self._setting(params, "max_tokens", 256, cfg_key="max_new_tokens"),
                temperature=self._setting(params, "temperature", 0.8),
                top_k=self._setting(params, "top_k", 50),
                top_p=self._setting(params, "top_p", 0.95),
                repetition_penalty=self._setting(params, "repetition_penalty", 1.1),
                stop_ids=stop_ids,
                seed=params.get("seed"),
            )

        # Keep only the tokens produced after the prompt.
        new_ids = output[0, len(input_ids):].tolist()
        response = tok_mod.decode(new_ids, skip_special_tokens=True)
        response = self._trim_at_stop_sequences(response, params.get("stop_sequences"))

        elapsed = time.time() - start
        return {
            "response": response.strip(),
            "tokens_generated": len(new_ids),
            "generation_time": round(elapsed, 4),
            "model": model_name,
        }

    def stream_chat(self, params: dict) -> Iterator[str]:
        """Yield decoded tokens one by one from a temperature/top-k sampling loop.

        Despite its simplicity this is sampling, not greedy decoding. Only
        ``temperature``, ``top_k`` and ``max_tokens`` are honoured here;
        ``top_p``, ``repetition_penalty``, ``seed`` and ``stop_sequences`` are
        not applied on this path. Generation stops at the EOS token or after
        ``max_tokens`` tokens.
        """
        model_name = params.get("model", "umbrellm-20m")
        self._ensure_ready(model_name)

        prompt = self._build_prompt(params)
        import torch
        import torch.nn.functional as F
        import umbrellm.tokenizer as tok_mod

        input_ids = tok_mod.encode(prompt, add_special_tokens=False)
        generated = torch.tensor([input_ids], dtype=torch.long, device=self._device)

        # Clamp so a temperature of 0 cannot divide by zero.
        temperature = max(self._setting(params, "temperature", 0.8), 1e-5)
        top_k = self._setting(params, "top_k", 50)
        max_new_tokens = self._setting(params, "max_tokens", 256, cfg_key="max_new_tokens")
        eos_id = self._eos_id()

        # Bind once so the whole stream uses one model instance.
        model = self._model
        max_ctx = model.max_seq_len

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Feed at most the last max_ctx tokens to the model.
                ctx = generated[:, -max_ctx:]
                logits, _ = model(ctx)
                logits = logits[:, -1, :] / temperature
                if top_k and top_k > 0:
                    # Mask everything below the k-th largest logit.
                    kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth, float("-inf"))
                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)
                tok_id = next_tok.item()
                if eos_id is not None and tok_id == eos_id:
                    break
                generated = torch.cat([generated, next_tok], dim=1)
                yield tok_mod.decode([tok_id], skip_special_tokens=True)
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Umbrellm REST API server (FastAPI).
|
|
3
|
+
|
|
4
|
+
Run:
|
|
5
|
+
python -m umbrellm.api
|
|
6
|
+
or:
|
|
7
|
+
uvicorn umbrellm.api:app --host 0.0.0.0 --port 8080
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import time
|
|
11
|
+
from typing import Optional, List
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from fastapi import FastAPI, HTTPException
|
|
15
|
+
from fastapi.responses import StreamingResponse
|
|
16
|
+
from pydantic import BaseModel
|
|
17
|
+
_FASTAPI_OK = True
|
|
18
|
+
except ImportError:
|
|
19
|
+
_FASTAPI_OK = False
|
|
20
|
+
|
|
21
|
+
if _FASTAPI_OK:
    app = FastAPI(
        title="Umbrellm API",
        description="REST API for the Umbrellm language model.",
        version="0.1.0",
    )

    class ChatRequest(BaseModel):
        """Body of ``POST /chat``; optional fields left as None are omitted
        from the parameters passed on to the interaction layer."""

        user: str
        model: str = "umbrellm-20m"
        history: Optional[List[dict]] = None
        system_prompt: Optional[str] = None
        temperature: Optional[float] = None
        top_k: Optional[int] = None
        top_p: Optional[float] = None
        repetition_penalty: Optional[float] = None
        max_tokens: Optional[int] = None
        stop_sequences: Optional[List[str]] = None
        stream: bool = False
        seed: Optional[int] = None

    class ChatResponse(BaseModel):
        """Body returned by a non-streaming ``POST /chat``."""

        response: str
        tokens_generated: int
        generation_time: float
        model: str

    def _request_params(req):
        """Dump a request model to a dict, dropping fields that are None.

        pydantic v2 renamed ``dict()`` to ``model_dump()`` and deprecated the
        old name; both major versions are supported here.
        """
        if hasattr(req, "model_dump"):
            return req.model_dump(exclude_none=True)
        return req.dict(exclude_none=True)

    @app.get("/")
    def root():
        """Return the service name and version."""
        return {"service": "Umbrellm API", "version": "0.1.0"}

    @app.get("/health")
    def health():
        """Liveness probe; returns the current server timestamp."""
        return {"status": "ok", "timestamp": time.time()}

    @app.get("/models")
    def list_models():
        """List the default model name followed by the available checkpoints."""
        import umbrellm.models as mdl

        checkpoints = mdl.list()
        return {"models": ["umbrellm-20m"] + checkpoints}

    @app.post("/chat", response_model=ChatResponse)
    def chat(req: ChatRequest):
        """Run one chat turn; streams plain-text tokens when ``stream`` is true."""
        import umbrellm.interaction as inter

        params = _request_params(req)
        if req.stream:
            # stream_chat is consumed lazily while the response is being sent.
            return StreamingResponse(inter.stream_chat(params), media_type="text/plain")
        return inter.chat(params)

    @app.post("/tokenize")
    def tokenize(body: dict):
        """Encode ``body["text"]`` and return the token ids and their count."""
        import umbrellm.tokenizer as tok

        text = body.get("text", "")
        if not isinstance(text, str):
            # Reject bad payloads with a client error instead of a 500.
            raise HTTPException(status_code=400, detail="'text' must be a string")
        ids = tok.encode(text)
        return {"tokens": ids, "count": len(ids)}

    @app.post("/detokenize")
    def detokenize(body: dict):
        """Decode the token ids in ``body["ids"]`` back to text."""
        import umbrellm.tokenizer as tok

        ids = body.get("ids", [])
        valid = isinstance(ids, list) and all(
            isinstance(i, int) and not isinstance(i, bool) for i in ids
        )
        if not valid:
            raise HTTPException(status_code=400, detail="'ids' must be a list of integers")
        return {"text": tok.decode(ids)}
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def main():
    """Start the API server with uvicorn.

    Host and port come from the loaded config (``api_host`` / ``api_port``),
    defaulting to ``0.0.0.0:8080``. If FastAPI or uvicorn is missing, an
    install hint is printed and the function returns without starting.
    """
    install_hint = "[api] FastAPI/uvicorn not installed. Run: pip install fastapi uvicorn"
    if not _FASTAPI_OK:
        print(install_hint)
        return
    # _FASTAPI_OK only covers fastapi/pydantic, so uvicorn is checked here.
    try:
        import uvicorn
    except ImportError:
        print(install_hint)
        return

    from umbrellm.config import load as load_cfg

    cfg = load_cfg()
    uvicorn.run(
        "umbrellm.api:app",
        # NOTE(review): the default binds to every network interface; set
        # api_host to "127.0.0.1" in the config for local-only access.
        host=cfg.get("api_host", "0.0.0.0"),
        port=cfg.get("api_port", 8080),
        reload=False,
    )


if __name__ == "__main__":
    main()
|