umbrellm 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
umbrellm/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """
2
+ Umbrellm – a lightweight, trainable language model SDK.
3
+ """
4
+
5
+ from umbrellm import interaction, tokenizer, models, training, utllm, config, device
6
+ from umbrellm._generate import generate, complete
7
+
8
+ __version__ = "0.1.0"
9
+ __all__ = [
10
+ "interaction",
11
+ "tokenizer",
12
+ "models",
13
+ "training",
14
+ "utllm",
15
+ "config",
16
+ "device",
17
+ "generate",
18
+ "complete",
19
+ ]
umbrellm/__main__.py ADDED
@@ -0,0 +1,75 @@
1
+ """
2
+ python -m umbrellm entry-point.
3
+ """
4
+
5
+ import sys
6
+ import argparse
7
+
8
+
9
def _run_chat(model_name):
    """Run the interactive chat loop on stdin/stdout until the user quits."""
    import umbrellm.interaction as inter

    print(f"Umbrellm interactive chat (model: {model_name}). Type exit to quit.")
    history = []
    while True:
        try:
            user = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the session instead of printing a traceback.
            break
        if user.lower() in ("exit", "quit", "q"):
            break
        if not user:
            continue
        res = inter.chat({"user": user, "model": model_name, "history": history})
        print(f"Umbrellm: {res['response']}")
        # Every completed turn is sent back as context on the next request.
        history.append({"user": user, "assistant": res["response"]})


def main(argv=None):
    """Run the ``python -m umbrellm`` command-line interface.

    Sub-commands are ``train``, ``chat``, ``compile``, ``stats`` and
    ``serve``; with no sub-command the help text is printed.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is what this function
            always did before the parameter existed.
    """
    parser = argparse.ArgumentParser(prog="umbrellm")
    sub = parser.add_subparsers(dest="cmd")

    # train
    p_train = sub.add_parser("train", help="Train the model")
    p_train.add_argument("--epochs", type=int, default=3)
    p_train.add_argument("--batch-size", type=int, default=4)
    p_train.add_argument("--lr", type=float, default=3e-4)

    # chat
    p_chat = sub.add_parser("chat", help="Interactive chat")
    p_chat.add_argument("--model", default="umbrellm-20m")

    # compile
    sub.add_parser("compile", help="Compile UTLLM dataset")

    # stats
    sub.add_parser("stats", help="Dataset statistics")

    # serve
    sub.add_parser("serve", help="Start API server")

    args = parser.parse_args(argv)

    # Package modules are imported inside the branch that needs them, so
    # printing the help text never imports any of them.
    if args.cmd == "train":
        import umbrellm.training as tr
        tr.train({"epochs": args.epochs, "batch_size": args.batch_size,
                  "learning_rate": args.lr})

    elif args.cmd == "chat":
        _run_chat(args.model)

    elif args.cmd == "compile":
        import umbrellm.utllm as ut
        ut.compile()

    elif args.cmd == "stats":
        import umbrellm.utllm as ut
        import json
        s = ut.stats()
        print(json.dumps(s, indent=2))

    elif args.cmd == "serve":
        from umbrellm.api import main as serve_main
        serve_main()

    else:
        parser.print_help()
72
+
73
+
74
+ if __name__ == "__main__":
75
+ main()
umbrellm/_generate.py ADDED
@@ -0,0 +1,30 @@
1
+ """
2
+ Top-level generate / complete helpers.
3
+ """
4
+
5
+ from umbrellm import interaction as _interaction
6
+
7
+
8
def generate(prompt: str, **kwargs) -> str:
    """
    Produce a response for ``prompt`` and return only its text.

    Extra keyword arguments are forwarded as chat parameters; ``model``
    falls back to "umbrellm-20m" when the caller does not supply it.

    Example:
        text = umbrellm.generate("Write a poem about stars.")
    """
    request = {"user": prompt, **kwargs}
    request.setdefault("model", "umbrellm-20m")
    return _interaction.chat(request)["response"]
18
+
19
+
20
def complete(prompt: str, **kwargs) -> str:
    """
    Continue the text prefix ``prompt`` and return only the generated text.

    Works like a chat request whose ``system_prompt`` defaults to a
    text-completion instruction; ``model`` defaults to "umbrellm-20m".
    Caller-supplied keyword arguments take precedence over both defaults.

    Example:
        text = umbrellm.complete("The meaning of life is")
    """
    request = {"user": prompt, **kwargs}
    fallbacks = (
        ("model", "umbrellm-20m"),
        ("system_prompt", "You are a helpful text completion engine. Continue the text naturally."),
    )
    for key, value in fallbacks:
        request.setdefault(key, value)
    return _interaction.chat(request)["response"]
umbrellm/_runtime.py ADDED
@@ -0,0 +1,179 @@
1
+ """
2
+ Internal inference runtime.
3
+ Handles model loading, prompt formatting, and generation.
4
+ """
5
+
6
+ import time
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Iterator, Optional
10
+
11
+
12
+ _SYSTEM_DEFAULT = (
13
+ "You are Umbrellm, a helpful AI assistant. "
14
+ "Respond clearly, concisely, and helpfully."
15
+ )
16
+
17
+
18
class InferenceRuntime:
    """Lazily initialised inference engine.

    Configuration, device, tokenizer and model are all loaded on the first
    request rather than at construction time, so creating an instance is
    cheap and does not import torch.
    """

    def __init__(self):
        self._model = None
        # Checkpoint path the cached model was loaded from (None = fresh model).
        self._model_source = None
        self._cfg = None
        self._device = None
        self._tokenizer_loaded = False

    def _ensure_ready(self, model_name: str):
        """Load config, device, tokenizer and model if not already cached.

        The model is reloaded when ``model_name`` now resolves to a different
        checkpoint than the cached one; previously the first model loaded was
        silently reused for every later ``model`` value.
        """
        from umbrellm.config import load as load_cfg
        from umbrellm.device import current as cur_device
        import umbrellm.tokenizer as tok_mod

        if self._cfg is None:
            self._cfg = load_cfg()

        if self._device is None:
            self._device = cur_device()

        # Ensure tokenizer is loaded
        if not self._tokenizer_loaded:
            try:
                tok_mod.load()
            except Exception as exc:
                # Still best-effort (no retry, no crash), but the failure is
                # reported instead of being swallowed silently.
                print(f"[runtime] Tokenizer load failed: {exc}")
            self._tokenizer_loaded = True

        if (self._model is None
                or self._resolve_checkpoint(model_name) != self._model_source):
            self._model = self._load_model(model_name)

    def _resolve_checkpoint(self, model_name: str):
        """Return the first existing checkpoint path for ``model_name``, or None.

        Search order: ``model_name`` itself when it exists on disk, then
        ``latest.pt`` and ``best.pt`` inside the configured checkpoint dir.
        """
        ckpt_dir = self._cfg.get("checkpoint_dir", "checkpoints")
        candidates = [
            os.path.join(ckpt_dir, "latest.pt"),
            os.path.join(ckpt_dir, "best.pt"),
        ]
        # Also check if model_name is a path
        # NOTE(review): model_name can come straight from an API request, so
        # this hands an arbitrary existing path to umbrellm.models.load -
        # confirm that loader is safe for untrusted files.
        if os.path.exists(model_name):
            candidates.insert(0, model_name)

        for path in candidates:
            if os.path.exists(path):
                return path
        return None

    def _load_model(self, model_name: str):
        """Load a checkpoint for ``model_name`` or build a fresh model.

        Returns:
            The model in eval mode, moved to ``self._device``.

        Raises:
            RuntimeError: if an import needed for loading fails.
        """
        try:
            import torch  # noqa: F401 - imported so a missing install surfaces here
            from umbrellm.model import build_model
            import umbrellm.models as mdl_mod

            path = self._resolve_checkpoint(model_name)
            if path is not None:
                model = mdl_mod.load(path)
            else:
                # No checkpoint found - build a fresh (random) model
                print("[runtime] No checkpoint found. Building fresh model for inference.")
                model = build_model(self._cfg)

            # Inference must run in eval mode for loaded checkpoints as well;
            # assumes the loader returns the same module type as build_model.
            model.eval()
            model = model.to(self._device)
            self._model_source = path
            return model

        except ImportError as e:
            raise RuntimeError(
                "PyTorch is required. Install with: pip install torch"
            ) from e

    def _build_prompt(self, params: dict) -> str:
        """Format system prompt, history and the new user turn into one string.

        A missing or ``None`` ``system_prompt`` uses the module default, and a
        missing or ``None`` ``history`` is treated as empty.
        """
        system = params.get("system_prompt")
        if system is None:
            system = _SYSTEM_DEFAULT
        history = params.get("history") or []
        user = params.get("user", "")

        # NOTE(review): an <|assistant|> marker directly follows the system
        # text; kept as-is - confirm it matches the training prompt format.
        parts = [f"<|system|>{system}<|assistant|>"]
        for turn in history:
            parts.append(f"<|user|>{turn.get('user', '')}")
            parts.append(f"<|assistant|>{turn.get('assistant', '')}")
        parts.append(f"<|user|>{user}")
        parts.append("<|assistant|>")
        return "".join(parts)

    def chat(self, params: dict) -> dict:
        """Generate one complete response for a chat request.

        Args:
            params: Request dict. Reads ``user``, ``model``, ``history``,
                ``system_prompt``, ``max_tokens``, ``temperature``, ``top_k``,
                ``top_p``, ``repetition_penalty``, ``seed`` and
                ``stop_sequences``; sampling values missing from ``params``
                fall back to the loaded config.

        Returns:
            Dict with ``response``, ``tokens_generated``, ``generation_time``
            (seconds, rounded to 4 places) and ``model``.
        """
        model_name = params.get("model", "umbrellm-20m")
        self._ensure_ready(model_name)

        prompt = self._build_prompt(params)
        start = time.time()

        import torch
        import umbrellm.tokenizer as tok_mod
        from umbrellm.tokenizer import EOS_TOKEN

        input_ids = tok_mod.encode(prompt, add_special_tokens=False)
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self._device)

        stop_ids = []
        eos_id = tok_mod._tokenizer.vocab.get(EOS_TOKEN)
        if eos_id is not None:
            stop_ids.append(eos_id)

        with torch.no_grad():
            output = self._model.generate(
                input_tensor,
                max_new_tokens=params.get("max_tokens", self._cfg.get("max_new_tokens", 256)),
                temperature=params.get("temperature", self._cfg.get("temperature", 0.8)),
                top_k=params.get("top_k", self._cfg.get("top_k", 50)),
                top_p=params.get("top_p", self._cfg.get("top_p", 0.95)),
                repetition_penalty=params.get("repetition_penalty",
                                              self._cfg.get("repetition_penalty", 1.1)),
                stop_ids=stop_ids or None,
                seed=params.get("seed"),
            )

        # Keep only what follows the prompt tokens in the first output row.
        new_ids = output[0, len(input_ids):].tolist()
        response = tok_mod.decode(new_ids, skip_special_tokens=True)

        # Trim at stop sequences
        for seq in params.get("stop_sequences") or []:
            if not seq:
                # "" is found at index 0 and would erase the whole response.
                continue
            idx = response.find(seq)
            if idx != -1:
                response = response[:idx]

        elapsed = time.time() - start
        return {
            "response": response.strip(),
            "tokens_generated": len(new_ids),
            "generation_time": round(elapsed, 4),
            "model": model_name,
        }

    def stream_chat(self, params: dict) -> Iterator[str]:
        """Yield decoded tokens one at a time from a sampling loop.

        Each step applies temperature scaling and top-k filtering, then draws
        the next token with ``torch.multinomial`` (this is sampling, not
        greedy decoding). Only ``temperature``, ``top_k`` and ``max_tokens``
        are read here; ``top_p``, ``repetition_penalty``, ``seed`` and
        ``stop_sequences`` are not applied on this path. Generation stops at
        the EOS token or after ``max_tokens`` new tokens.
        """
        model_name = params.get("model", "umbrellm-20m")
        self._ensure_ready(model_name)

        prompt = self._build_prompt(params)
        import torch
        import torch.nn.functional as F
        import umbrellm.tokenizer as tok_mod
        from umbrellm.tokenizer import EOS_TOKEN

        input_ids = tok_mod.encode(prompt, add_special_tokens=False)
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self._device)

        temperature = params.get("temperature", self._cfg.get("temperature", 0.8))
        top_k = params.get("top_k", self._cfg.get("top_k", 50))
        max_new_tokens = params.get("max_tokens", self._cfg.get("max_new_tokens", 256))
        stop_ids = set()
        eos_id = tok_mod._tokenizer.vocab.get(EOS_TOKEN)
        if eos_id is not None:
            stop_ids.add(eos_id)

        # Loop invariants: context window size and the clamped temperature
        # (the clamp avoids dividing by zero when temperature is 0).
        max_len = self._model.max_seq_len
        temp = max(temperature, 1e-5)

        generated = input_tensor.clone()
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Feed at most the last max_len tokens to the model.
                ctx = generated if generated.shape[1] <= max_len else generated[:, -max_len:]
                logits, _ = self._model(ctx)
                logits = logits[:, -1, :] / temp
                if top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Mask everything below the k-th largest logit.
                    logits[logits < v[:, [-1]]] = float("-inf")
                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)
                tok_id = next_tok.item()
                if tok_id in stop_ids:
                    break
                generated = torch.cat([generated, next_tok], dim=1)
                token_text = tok_mod.decode([tok_id], skip_special_tokens=True)
                yield token_text
umbrellm/api.py ADDED
@@ -0,0 +1,105 @@
1
+ """
2
+ Umbrellm REST API server (FastAPI).
3
+
4
+ Run:
5
+ python -m umbrellm.api
6
+ or:
7
+ uvicorn umbrellm.api:app --host 0.0.0.0 --port 8080
8
+ """
9
+
10
+ import time
11
+ from typing import Optional, List
12
+
13
+ try:
14
+ from fastapi import FastAPI, HTTPException
15
+ from fastapi.responses import StreamingResponse
16
+ from pydantic import BaseModel
17
+ _FASTAPI_OK = True
18
+ except ImportError:
19
+ _FASTAPI_OK = False
20
+
21
# Everything below needs FastAPI/pydantic, so it is only defined when the
# guarded import above succeeded. Route handlers are documented with "#"
# comments instead of docstrings so the generated OpenAPI schema is unchanged.
if _FASTAPI_OK:
    app = FastAPI(
        title="Umbrellm API",
        description="REST API for the Umbrellm language model.",
        version="0.1.0",
    )

    # Request body for POST /chat. Optional fields left as None are dropped
    # before the request is handed to the interaction layer.
    class ChatRequest(BaseModel):
        user: str
        model: str = "umbrellm-20m"
        history: Optional[List[dict]] = None
        system_prompt: Optional[str] = None
        temperature: Optional[float] = None
        top_k: Optional[int] = None
        top_p: Optional[float] = None
        repetition_penalty: Optional[float] = None
        max_tokens: Optional[int] = None
        stop_sequences: Optional[List[str]] = None
        stream: bool = False
        seed: Optional[int] = None

    # Response body for a non-streaming POST /chat.
    class ChatResponse(BaseModel):
        response: str
        tokens_generated: int
        generation_time: float
        model: str

    # Service banner.
    @app.get("/")
    def root():
        return {"service": "Umbrellm API", "version": "0.1.0"}

    # Liveness probe with the current server time.
    @app.get("/health")
    def health():
        return {"status": "ok", "timestamp": time.time()}

    # Built-in model name followed by whatever umbrellm.models.list() reports.
    @app.get("/models")
    def list_models():
        import umbrellm.models as mdl
        checkpoints = mdl.list()
        return {"models": ["umbrellm-20m"] + checkpoints}

    # Chat endpoint: returns JSON, or a plain-text token stream when
    # req.stream is true.
    @app.post("/chat", response_model=ChatResponse)
    def chat(req: ChatRequest):
        import umbrellm.interaction as inter
        # pydantic v2 renamed .dict() to .model_dump() and deprecated the old
        # name; use whichever this installation provides.
        dump = getattr(req, "model_dump", None) or req.dict
        params = dump(exclude_none=True)
        if req.stream:
            # stream_chat already returns an iterator of text chunks.
            return StreamingResponse(inter.stream_chat(params), media_type="text/plain")
        return inter.chat(params)

    # Encode body["text"] into token ids.
    @app.post("/tokenize")
    def tokenize(body: dict):
        import umbrellm.tokenizer as tok
        text = body.get("text", "")
        ids = tok.encode(text)
        return {"tokens": ids, "count": len(ids)}

    # Decode body["ids"] back into text.
    @app.post("/detokenize")
    def detokenize(body: dict):
        import umbrellm.tokenizer as tok
        ids = body.get("ids", [])
        text = tok.decode(ids)
        return {"text": text}
87
+
88
+
89
def main():
    """Start the API server with uvicorn on the configured host and port.

    Prints an install hint and returns without raising when FastAPI or
    uvicorn is missing. Host and port come from the ``api_host`` and
    ``api_port`` config keys, defaulting to "0.0.0.0" and 8080.
    """
    install_hint = "[api] FastAPI/uvicorn not installed. Run: pip install fastapi uvicorn"
    if not _FASTAPI_OK:
        print(install_hint)
        return
    try:
        import uvicorn
    except ImportError:
        # _FASTAPI_OK only covers fastapi/pydantic; uvicorn can still be
        # missing, which previously escaped as an uncaught ImportError.
        print(install_hint)
        return
    from umbrellm.config import load as load_cfg
    cfg = load_cfg()
    uvicorn.run(
        "umbrellm.api:app",
        host=cfg.get("api_host", "0.0.0.0"),
        port=cfg.get("api_port", 8080),
        reload=False,
    )
102
+
103
+
104
+ if __name__ == "__main__":
105
+ main()
umbrellm/config.py ADDED
@@ -0,0 +1,66 @@
1
+ """
2
+ Configuration utilities for Umbrellm.
3
+ """
4
+
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+
9
+ _DEFAULT_CONFIG_PATH = "config/umbrellm.json"
10
+
11
# Built-in defaults; values read from a config file override these per key.
# Key order is kept stable because it is the order written back out by save().
DEFAULT_CONFIG = {
    # Model identity and size settings.
    "model_name": "umbrellm-20m",
    "vocab_size": 32000,
    "d_model": 512,
    "n_heads": 8,
    "n_layers": 12,
    "d_ff": 2048,
    "max_seq_len": 1024,
    "dropout": 0.1,
    # Training run settings.
    "learning_rate": 3e-4,
    "batch_size": 4,
    "epochs": 3,
    # Working directories (relative paths).
    "checkpoint_dir": "checkpoints",
    "data_dir": "data",
    "log_dir": "logs",
    # Runtime / optimisation settings.
    "device": "auto",
    "seed": 42,
    "mixed_precision": False,
    "gradient_clip": 1.0,
    "warmup_steps": 100,
    "save_every": 500,
    "eval_every": 250,
    # API server bind address.
    "api_host": "0.0.0.0",
    "api_port": 8080,
    # Generation defaults, used when a request does not set them.
    "temperature": 0.8,
    "top_k": 50,
    "top_p": 0.95,
    "repetition_penalty": 1.1,
    "max_new_tokens": 256,
}
41
+
42
+
43
def load(path: str = _DEFAULT_CONFIG_PATH) -> dict:
    """Load configuration from disk; falls back to defaults.

    Args:
        path: Location of the JSON config file.

    Returns:
        A new dict holding ``DEFAULT_CONFIG`` overlaid with the file's
        top-level keys, or just a copy of the defaults when the file does
        not exist.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    merged = dict(DEFAULT_CONFIG)
    try:
        # Open directly instead of checking exists() first: no window in
        # which the file can disappear between the check and the read.
        # UTF-8 is explicit so parsing does not depend on the platform locale.
        with open(path, encoding="utf-8") as f:
            user_cfg = json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        return merged
    merged.update(user_cfg)
    return merged
53
+
54
+
55
def save(cfg: dict, path: str = _DEFAULT_CONFIG_PATH) -> None:
    """Persist configuration to disk.

    Creates the parent directory when needed, writes ``cfg`` as indented
    JSON and prints the destination path.

    Args:
        cfg: Configuration mapping to serialise.
        path: Destination file path.
    """
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so the file does not depend on the platform's locale
    # encoding.
    with open(p, "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
    print(f"[config] Saved to {p}")
62
+
63
+
64
def get(key: str, default=None):
    """Look up one setting in the freshly loaded configuration.

    The configuration is re-read through ``load()`` on every call;
    ``default`` is returned when ``key`` is absent.
    """
    settings = load()
    return settings.get(key, default)
umbrellm/device.py ADDED
@@ -0,0 +1,30 @@
1
+ """
2
+ Device utilities for Umbrellm.
3
+ """
4
+
5
+
6
def current() -> str:
    """
    Pick the preferred compute device name.

    CUDA wins when available, then Apple MPS, otherwise the CPU. A missing
    PyTorch install also yields "cpu".

    Returns:
        "cuda", "mps", or "cpu"
    """
    chosen = "cpu"
    try:
        import torch
        if torch.cuda.is_available():
            chosen = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            # The hasattr check covers torch builds without an MPS backend.
            chosen = "mps"
    except ImportError:
        pass
    return chosen
22
+
23
+
24
def torch_device():
    """Return a torch.device object for the best available device.

    Raises:
        RuntimeError: if PyTorch is not installed; the original ImportError
            is attached as the cause.
    """
    try:
        import torch
    except ImportError as exc:
        raise RuntimeError(
            "PyTorch is required. Install with: pip install torch"
        ) from exc
    return torch.device(current())
umbrellm/interaction.py ADDED
@@ -0,0 +1,60 @@
1
+ """
2
+ Chat and streaming interaction interface.
3
+ """
4
+
5
+ import time
6
+ from typing import Optional, List, Iterator, Union
7
+
8
+ _RUNTIME = None
9
+
10
+
11
def _get_runtime():
    """Return the shared module-level runtime, creating it on first use."""
    global _RUNTIME
    if _RUNTIME is not None:
        return _RUNTIME
    # Imported here so that importing this module does not load the runtime.
    from umbrellm._runtime import InferenceRuntime
    _RUNTIME = InferenceRuntime()
    return _RUNTIME
17
+
18
+
19
def chat(params: dict) -> dict:
    """
    Run one chat request through the shared runtime and return its result.

    Required keys:
        user (str): The user message.

    Optional keys:
        model (str): Model name or checkpoint path.
        history (list): Earlier turns as {"user": ..., "assistant": ...} dicts.
        system_prompt (str): System instruction.
        temperature (float): Sampling temperature.
        top_k (int): Top-K sampling.
        top_p (float): Nucleus sampling threshold.
        repetition_penalty (float): Repetition penalty.
        max_tokens (int): Maximum new tokens.
        stop_sequences (list[str]): Stop strings.
        seed (int): RNG seed.

    A "stream" key is accepted but has no effect on this function, which
    always returns a dict; use stream_chat() for token streaming.

    Returns:
        {
            "response": str,
            "tokens_generated": int,
            "generation_time": float,
            "model": str
        }
    """
    return _get_runtime().chat(params)
49
+
50
+
51
def stream_chat(params: dict) -> Iterator[str]:
    """
    Yield the pieces of a chat response as the shared runtime produces them.

    This is a generator, so the runtime is only created once iteration
    starts.

    Example:
        for token in umbrellm.interaction.stream_chat({"user": "Hello"}):
            print(token, end="", flush=True)
    """
    yield from _get_runtime().stream_chat(params)