umbrellm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- umbrellm/__init__.py +19 -0
- umbrellm/__main__.py +75 -0
- umbrellm/_generate.py +30 -0
- umbrellm/_runtime.py +179 -0
- umbrellm/api.py +105 -0
- umbrellm/config.py +66 -0
- umbrellm/device.py +30 -0
- umbrellm/interaction.py +60 -0
- umbrellm/model.py +268 -0
- umbrellm/models.py +83 -0
- umbrellm/tokenizer.py +273 -0
- umbrellm/training.py +395 -0
- umbrellm/utllm.py +159 -0
- umbrellm-0.1.0.dist-info/METADATA +90 -0
- umbrellm-0.1.0.dist-info/RECORD +18 -0
- umbrellm-0.1.0.dist-info/WHEEL +5 -0
- umbrellm-0.1.0.dist-info/entry_points.txt +2 -0
- umbrellm-0.1.0.dist-info/top_level.txt +1 -0
umbrellm/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Umbrellm – a lightweight, trainable language model SDK.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from umbrellm import interaction, tokenizer, models, training, utllm, config, device
|
|
6
|
+
from umbrellm._generate import generate, complete
|
|
7
|
+
|
|
8
|
+
# Package version string.
__version__ = "0.1.0"

# Names re-exported by ``from umbrellm import *``: the seven submodules
# imported above plus the two top-level text helpers.
__all__ = [
    "interaction", "tokenizer", "models", "training",
    "utllm", "config", "device",
    "generate", "complete",
]
|
umbrellm/__main__.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""
|
|
2
|
+
python -m umbrellm entry-point.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import sys
|
|
6
|
+
import argparse
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def main(argv=None):
    """
    Command-line entry point for ``python -m umbrellm``.

    Sub-commands:
        train    Run training with --epochs / --batch-size / --lr.
        chat     Interactive prompt loop against --model.
        compile  Call ``umbrellm.utllm.compile()``.
        stats    Print ``umbrellm.utllm.stats()`` as indented JSON.
        serve    Start the API server via ``umbrellm.api.main()``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is the behaviour
            the console entry point relies on.

    With no sub-command the help text is printed and the function returns.
    """
    parser = argparse.ArgumentParser(prog="umbrellm")
    sub = parser.add_subparsers(dest="cmd")

    # train
    p_train = sub.add_parser("train", help="Train the model")
    p_train.add_argument("--epochs", type=int, default=3)
    p_train.add_argument("--batch-size", type=int, default=4)
    p_train.add_argument("--lr", type=float, default=3e-4)

    # chat
    p_chat = sub.add_parser("chat", help="Interactive chat")
    p_chat.add_argument("--model", default="umbrellm-20m")

    # compile / stats / serve take no options
    sub.add_parser("compile", help="Compile UTLLM dataset")
    sub.add_parser("stats", help="Dataset statistics")
    sub.add_parser("serve", help="Start API server")

    args = parser.parse_args(argv)

    # Sub-modules are imported lazily so that e.g. printing the help text
    # does not import the training or API modules.
    if args.cmd == "train":
        import umbrellm.training as tr
        tr.train({"epochs": args.epochs, "batch_size": args.batch_size,
                  "learning_rate": args.lr})

    elif args.cmd == "chat":
        import umbrellm.interaction as inter
        print(f"Umbrellm interactive chat (model: {args.model}). Type exit to quit.")
        history = []
        while True:
            try:
                user = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C end the session quietly.
                break
            if user.lower() in ("exit", "quit", "q"):
                break
            if not user:
                continue
            res = inter.chat({"user": user, "model": args.model, "history": history})
            print(f"Umbrellm: {res['response']}")
            # Each completed turn is fed back as context for the next one.
            history.append({"user": user, "assistant": res["response"]})

    elif args.cmd == "compile":
        import umbrellm.utllm as ut
        ut.compile()

    elif args.cmd == "stats":
        import umbrellm.utllm as ut
        import json
        s = ut.stats()
        print(json.dumps(s, indent=2))

    elif args.cmd == "serve":
        from umbrellm.api import main as serve_main
        serve_main()

    else:
        parser.print_help()


if __name__ == "__main__":
    main()
|
umbrellm/_generate.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Top-level generate / complete helpers.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from umbrellm import interaction as _interaction
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def generate(prompt: str, **kwargs) -> str:
    """
    Run one chat turn for *prompt* and return only the response text.

    Keyword arguments are forwarded as extra keys of the chat request;
    ``model`` defaults to "umbrellm-20m" when not supplied.

    Example:
        text = umbrellm.generate("Write a poem about stars.")
    """
    kwargs.setdefault("model", "umbrellm-20m")
    request = {"user": prompt}
    request.update(kwargs)
    reply = _interaction.chat(request)
    return reply["response"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def complete(prompt: str, **kwargs) -> str:
    """
    Continue the text prefix *prompt* and return the response text.

    Works like a chat turn, but installs a completion-oriented system
    prompt unless the caller passes their own ``system_prompt``.
    ``model`` defaults to "umbrellm-20m".

    Example:
        text = umbrellm.complete("The meaning of life is")
    """
    kwargs.setdefault("model", "umbrellm-20m")
    kwargs.setdefault(
        "system_prompt",
        "You are a helpful text completion engine. Continue the text naturally.",
    )
    request = {"user": prompt}
    request.update(kwargs)
    reply = _interaction.chat(request)
    return reply["response"]
|
umbrellm/_runtime.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Internal inference runtime.
|
|
3
|
+
Handles model loading, prompt formatting, and generation.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import time
|
|
7
|
+
import os
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Iterator, Optional
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# System prompt used when a request does not provide its own.
_SYSTEM_DEFAULT = (
    "You are Umbrellm, a helpful AI assistant. "
    "Respond clearly, concisely, and helpfully."
)


class InferenceRuntime:
    """
    Lazily initialised inference state: config, device, tokenizer and model.

    Nothing is loaded at construction time; the first ``chat`` or
    ``stream_chat`` call triggers ``_ensure_ready``.
    """

    def __init__(self):
        self._model = None
        self._cfg = None
        self._device = None
        self._tokenizer_loaded = False

    @staticmethod
    def _pick(params: dict, key: str, default):
        """Return ``params[key]`` unless it is missing or ``None``, else *default*."""
        value = params.get(key)
        return default if value is None else value

    @staticmethod
    def _trim_at_stops(text: str, stop_sequences) -> str:
        """
        Cut *text* at the first occurrence of each stop sequence, in order.

        ``None`` is treated as "no stop sequences". Empty strings are
        skipped, because ``str.find("")`` is 0 and would otherwise erase
        the whole response.
        """
        for seq in stop_sequences or ():
            if not seq:
                continue
            idx = text.find(seq)
            if idx != -1:
                text = text[:idx]
        return text

    def _ensure_ready(self, model_name: str):
        """Load config, device, tokenizer and model on first use."""
        from umbrellm.config import load as load_cfg
        from umbrellm.device import current as cur_device
        import umbrellm.tokenizer as tok_mod

        if self._cfg is None:
            self._cfg = load_cfg()

        if self._device is None:
            self._device = cur_device()

        # Tokenizer loading stays best-effort (attempted once), but a
        # failure is now reported instead of being swallowed silently.
        if not self._tokenizer_loaded:
            try:
                tok_mod.load()
            except Exception as exc:
                print(f"[runtime] Warning: tokenizer load failed: {exc!r}")
            self._tokenizer_loaded = True

        # NOTE(review): the model is loaded only once; a different
        # model_name on a later call is ignored and the first model keeps
        # serving. Confirm whether per-request model switching is wanted.
        if self._model is None:
            self._model = self._load_model(model_name)

    def _load_model(self, model_name: str):
        """
        Return a model moved to the runtime device, in eval mode.

        Lookup order: *model_name* itself if it is an existing path, then
        ``latest.pt`` and ``best.pt`` inside the configured checkpoint
        directory. If none exists a freshly built (untrained) model is
        returned.

        Raises:
            RuntimeError: if an ImportError occurs while loading
                (reported as PyTorch being required).
        """
        try:
            import torch
            from umbrellm.model import build_model
            import umbrellm.models as mdl_mod

            ckpt_dir = self._cfg.get("checkpoint_dir", "checkpoints")
            # NOTE(review): model_name may come from an API request and is
            # used as a filesystem path here; verify how models.load
            # deserialises files before exposing this to untrusted input.
            candidates = [
                model_name,
                os.path.join(ckpt_dir, "latest.pt"),
                os.path.join(ckpt_dir, "best.pt"),
            ]

            for path in candidates:
                if os.path.exists(path):
                    model = mdl_mod.load(path)
                    # Match the fresh-model path below: inference must not
                    # run with training-mode layers such as dropout.
                    model.eval()
                    return model.to(self._device)

            # No checkpoint found - build a fresh (random) model
            print("[runtime] No checkpoint found. Building fresh model for inference.")
            model = build_model(self._cfg)
            model.eval()
            return model.to(self._device)

        except ImportError as e:
            raise RuntimeError(
                "PyTorch is required. Install with: pip install torch"
            ) from e

    def _build_prompt(self, params: dict) -> str:
        """
        Render system prompt, history and the new user message into one string.

        ``system_prompt`` and ``history`` fall back to their defaults when
        missing or ``None``.
        """
        system = self._pick(params, "system_prompt", _SYSTEM_DEFAULT)
        history = self._pick(params, "history", [])
        user = params.get("user", "")

        # NOTE(review): the system segment is closed with <|assistant|>
        # before the first <|user|> tag; kept as-is, confirm it matches
        # the training format.
        parts = [f"<|system|>{system}<|assistant|>"]
        for turn in history:
            parts.append(f"<|user|>{turn.get('user', '')}")
            parts.append(f"<|assistant|>{turn.get('assistant', '')}")
        parts.append(f"<|user|>{user}")
        parts.append("<|assistant|>")
        return "".join(parts)

    def chat(self, params: dict) -> dict:
        """
        Generate a full response for one chat request.

        Sampling options missing from *params* (or set to ``None``) fall
        back to the loaded config, then to built-in defaults.

        Returns:
            dict with ``response``, ``tokens_generated``,
            ``generation_time`` (seconds, rounded to 4 places) and ``model``.
        """
        model_name = params.get("model", "umbrellm-20m")
        self._ensure_ready(model_name)

        prompt = self._build_prompt(params)
        # perf_counter is monotonic, so the elapsed time cannot be skewed
        # by wall-clock adjustments.
        start = time.perf_counter()

        import torch
        import umbrellm.tokenizer as tok_mod
        from umbrellm.tokenizer import EOS_TOKEN

        input_ids = tok_mod.encode(prompt, add_special_tokens=False)
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self._device)

        stop_ids = []
        eos_id = tok_mod._tokenizer.vocab.get(EOS_TOKEN)
        if eos_id is not None:
            stop_ids.append(eos_id)

        cfg = self._cfg
        with torch.no_grad():
            output = self._model.generate(
                input_tensor,
                max_new_tokens=self._pick(params, "max_tokens", cfg.get("max_new_tokens", 256)),
                temperature=self._pick(params, "temperature", cfg.get("temperature", 0.8)),
                top_k=self._pick(params, "top_k", cfg.get("top_k", 50)),
                top_p=self._pick(params, "top_p", cfg.get("top_p", 0.95)),
                repetition_penalty=self._pick(
                    params, "repetition_penalty", cfg.get("repetition_penalty", 1.1)
                ),
                stop_ids=stop_ids or None,
                seed=params.get("seed"),
            )

        # Keep only the ids produced after the prompt.
        new_ids = output[0, len(input_ids):].tolist()
        response = tok_mod.decode(new_ids, skip_special_tokens=True)
        response = self._trim_at_stops(response, params.get("stop_sequences"))

        elapsed = time.perf_counter() - start
        return {
            "response": response.strip(),
            "tokens_generated": len(new_ids),
            "generation_time": round(elapsed, 4),
            "model": model_name,
        }

    def stream_chat(self, params: dict) -> Iterator[str]:
        """
        Yield the decoded text of each newly sampled token.

        Uses temperature-scaled top-k sampling (``torch.multinomial``),
        one token per step, stopping at the EOS token or after
        ``max_tokens`` tokens. Unlike ``chat``, this loop does not apply
        ``top_p``, ``repetition_penalty``, ``seed`` or ``stop_sequences``.
        """
        model_name = params.get("model", "umbrellm-20m")
        self._ensure_ready(model_name)

        prompt = self._build_prompt(params)
        import torch
        import torch.nn.functional as F
        import umbrellm.tokenizer as tok_mod
        from umbrellm.tokenizer import EOS_TOKEN

        input_ids = tok_mod.encode(prompt, add_special_tokens=False)
        generated = torch.tensor([input_ids], dtype=torch.long, device=self._device)

        temperature = self._pick(params, "temperature", self._cfg.get("temperature", 0.8))
        top_k = self._pick(params, "top_k", self._cfg.get("top_k", 50))
        max_new_tokens = self._pick(params, "max_tokens", self._cfg.get("max_new_tokens", 256))
        stop_ids = set()
        eos_id = tok_mod._tokenizer.vocab.get(EOS_TOKEN)
        if eos_id is not None:
            stop_ids.add(eos_id)

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Feed at most the last max_seq_len tokens to the model.
                max_ctx = self._model.max_seq_len
                ctx = generated if generated.shape[1] <= max_ctx else generated[:, -max_ctx:]
                logits, _ = self._model(ctx)
                logits = logits[:, -1, :] / max(temperature, 1e-5)
                if top_k and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Mask everything below the k-th largest logit.
                    logits[logits < v[:, [-1]]] = float("-inf")
                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)
                tok_id = next_tok.item()
                if tok_id in stop_ids:
                    break
                generated = torch.cat([generated, next_tok], dim=1)
                yield tok_mod.decode([tok_id], skip_special_tokens=True)
|
umbrellm/api.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Umbrellm REST API server (FastAPI).
|
|
3
|
+
|
|
4
|
+
Run:
|
|
5
|
+
python -m umbrellm.api
|
|
6
|
+
or:
|
|
7
|
+
uvicorn umbrellm.api:app --host 0.0.0.0 --port 8080
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import time
|
|
11
|
+
from typing import Optional, List
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from fastapi import FastAPI, HTTPException
|
|
15
|
+
from fastapi.responses import StreamingResponse
|
|
16
|
+
from pydantic import BaseModel
|
|
17
|
+
_FASTAPI_OK = True
|
|
18
|
+
except ImportError:
|
|
19
|
+
_FASTAPI_OK = False
|
|
20
|
+
|
|
21
|
+
# The application, its schemas and its routes are only defined when the
# optional FastAPI / pydantic imports above succeeded.
if _FASTAPI_OK:
    app = FastAPI(
        title="Umbrellm API",
        description="REST API for the Umbrellm language model.",
        version="0.1.0",
    )

    # Request body for POST /chat. Optional fields left as None are dropped
    # before the request is handed to the interaction layer.
    class ChatRequest(BaseModel):
        user: str
        model: str = "umbrellm-20m"
        history: Optional[List[dict]] = None
        system_prompt: Optional[str] = None
        temperature: Optional[float] = None
        top_k: Optional[int] = None
        top_p: Optional[float] = None
        repetition_penalty: Optional[float] = None
        max_tokens: Optional[int] = None
        stop_sequences: Optional[List[str]] = None
        stream: bool = False
        seed: Optional[int] = None

    # Response body for non-streaming POST /chat.
    class ChatResponse(BaseModel):
        response: str
        tokens_generated: int
        generation_time: float
        model: str

    # Convert a request model to a plain dict without None-valued fields.
    # pydantic v2 renamed .dict() to .model_dump() and deprecated the old
    # name, so prefer the new method when it exists and fall back for v1.
    def _request_params(req) -> dict:
        if hasattr(req, "model_dump"):
            return req.model_dump(exclude_none=True)
        return req.dict(exclude_none=True)

    @app.get("/")
    def root():
        return {"service": "Umbrellm API", "version": "0.1.0"}

    @app.get("/health")
    def health():
        return {"status": "ok", "timestamp": time.time()}

    @app.get("/models")
    def list_models():
        import umbrellm.models as mdl
        checkpoints = mdl.list()
        return {"models": ["umbrellm-20m"] + checkpoints}

    @app.post("/chat", response_model=ChatResponse)
    def chat(req: ChatRequest):
        import umbrellm.interaction as inter
        params = _request_params(req)
        if req.stream:
            # stream_chat is already a generator of text pieces, so it can
            # be handed to StreamingResponse directly.
            return StreamingResponse(inter.stream_chat(params), media_type="text/plain")
        return inter.chat(params)

    @app.post("/tokenize")
    def tokenize(body: dict):
        import umbrellm.tokenizer as tok
        text = body.get("text", "")
        ids = tok.encode(text)
        return {"tokens": ids, "count": len(ids)}

    @app.post("/detokenize")
    def detokenize(body: dict):
        import umbrellm.tokenizer as tok
        ids = body.get("ids", [])
        text = tok.decode(ids)
        return {"text": text}
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def main():
    """
    Start the API server with uvicorn.

    Host and port come from the loaded config (``api_host`` / ``api_port``),
    defaulting to 0.0.0.0:8080, i.e. listening on all interfaces. If
    FastAPI or uvicorn is not installed, an install hint is printed and
    the function returns without starting anything.
    """
    missing_msg = "[api] FastAPI/uvicorn not installed. Run: pip install fastapi uvicorn"
    if not _FASTAPI_OK:
        print(missing_msg)
        return
    # uvicorn is optional too; the hint above already names it, so a
    # missing uvicorn gets the same message instead of an ImportError.
    try:
        import uvicorn
    except ImportError:
        print(missing_msg)
        return
    from umbrellm.config import load as load_cfg
    cfg = load_cfg()
    uvicorn.run(
        "umbrellm.api:app",
        host=cfg.get("api_host", "0.0.0.0"),
        port=cfg.get("api_port", 8080),
        reload=False,
    )


if __name__ == "__main__":
    main()
|
umbrellm/config.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration utilities for Umbrellm.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
# Relative path read by load() and written by save() when none is given.
_DEFAULT_CONFIG_PATH = "config/umbrellm.json"

# Built-in settings; load() copies these and overlays the user's file.
DEFAULT_CONFIG = {
    # model architecture
    "model_name": "umbrellm-20m",
    "vocab_size": 32000,
    "d_model": 512,
    "n_heads": 8,
    "n_layers": 12,
    "d_ff": 2048,
    "max_seq_len": 1024,
    "dropout": 0.1,
    # training
    "learning_rate": 3e-4,
    "batch_size": 4,
    "epochs": 3,
    # directories
    "checkpoint_dir": "checkpoints",
    "data_dir": "data",
    "log_dir": "logs",
    # runtime
    "device": "auto",
    "seed": 42,
    "mixed_precision": False,
    # training schedule
    "gradient_clip": 1.0,
    "warmup_steps": 100,
    "save_every": 500,
    "eval_every": 250,
    # API server
    "api_host": "0.0.0.0",
    "api_port": 8080,
    # generation defaults
    "temperature": 0.8,
    "top_k": 50,
    "top_p": 0.95,
    "repetition_penalty": 1.1,
    "max_new_tokens": 256,
}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def load(path: str = _DEFAULT_CONFIG_PATH) -> dict:
    """
    Load configuration from disk; falls back to defaults.

    Args:
        path: JSON file to read. If it does not exist, a copy of
            DEFAULT_CONFIG is returned.

    Returns:
        A new dict: DEFAULT_CONFIG overlaid with the values from the file.

    Raises:
        ValueError: if the file's top-level JSON value is not an object.
    """
    p = Path(path)
    if not p.exists():
        return dict(DEFAULT_CONFIG)
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(p, encoding="utf-8") as f:
        user_cfg = json.load(f)
    if not isinstance(user_cfg, dict):
        raise ValueError(
            f"Config file {p} must contain a JSON object, "
            f"got {type(user_cfg).__name__}"
        )
    merged = dict(DEFAULT_CONFIG)
    merged.update(user_cfg)
    return merged
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def save(cfg: dict, path: str = _DEFAULT_CONFIG_PATH) -> None:
    """
    Write *cfg* to *path* as 2-space-indented JSON.

    Missing parent directories are created first; a confirmation line is
    printed afterwards.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w") as out:
        json.dump(cfg, out, indent=2)
    print(f"[config] Saved to {target}")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def get(key: str, default=None):
    """
    Look up one setting by name.

    The configuration is re-loaded from the default path on every call;
    *default* is returned when *key* is absent.
    """
    settings = load()
    return settings.get(key, default)
|
umbrellm/device.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Device utilities for Umbrellm.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def current() -> str:
    """
    Pick the preferred compute device name.

    Preference order is CUDA, then Apple MPS (only probed when
    ``torch.backends`` has an ``mps`` attribute), then CPU. CPU is also
    the answer when PyTorch cannot be imported.

    Returns:
        "cuda", "mps", or "cpu"
    """
    try:
        import torch

        if torch.cuda.is_available():
            choice = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            choice = "mps"
        else:
            choice = "cpu"
    except ImportError:
        choice = "cpu"
    return choice
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def torch_device():
    """
    Return a torch.device object for the best available device.

    Raises:
        RuntimeError: if PyTorch cannot be imported (the original
            ImportError is attached as the cause).
    """
    # Only the import can raise ImportError here, so keep the try minimal.
    try:
        import torch
    except ImportError as err:
        raise RuntimeError("PyTorch is required. Install with: pip install torch") from err
    return torch.device(current())
|
umbrellm/interaction.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Chat and streaming interaction interface.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import time
|
|
6
|
+
from typing import Optional, List, Iterator, Union
|
|
7
|
+
|
|
8
|
+
# Process-wide InferenceRuntime, created on first use.
_RUNTIME = None


def _get_runtime():
    """Return the shared InferenceRuntime, constructing it on the first call."""
    global _RUNTIME
    if _RUNTIME is not None:
        return _RUNTIME
    # Imported here rather than at module top so importing this module
    # does not import the runtime module.
    from umbrellm._runtime import InferenceRuntime
    _RUNTIME = InferenceRuntime()
    return _RUNTIME
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def chat(params: dict) -> dict:
    """
    Run one chat turn through the shared runtime and return its result dict.

    *params* is passed to the runtime unchanged.

    Required keys:
        user (str): The user message.

    Optional keys:
        model (str): Model name or checkpoint path.
        history (list): Earlier turns as {"user": ..., "assistant": ...} dicts.
        system_prompt (str): System instruction.
        temperature (float): Sampling temperature.
        top_k (int): Top-K sampling.
        top_p (float): Nucleus sampling threshold.
        repetition_penalty (float): Repetition penalty.
        max_tokens (int): Maximum new tokens.
        stop_sequences (list[str]): Strings at which the response is cut.
        seed (int): RNG seed.

    This function always returns a complete result; use ``stream_chat``
    for token-by-token output.

    Returns:
        {
            "response": str,
            "tokens_generated": int,
            "generation_time": float,
            "model": str
        }
    """
    return _get_runtime().chat(params)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def stream_chat(params: dict) -> Iterator[str]:
    """
    Yield pieces of the response text as the shared runtime produces them.

    Accepts the same *params* dict as ``chat``. This is a generator, so
    nothing runs until iteration starts.

    Example:
        for token in umbrellm.interaction.stream_chat({"user": "Hello"}):
            print(token, end="", flush=True)
    """
    token_stream = _get_runtime().stream_chat(params)
    yield from token_stream
|