lemonade-sdk 7.0.0__py3-none-any.whl → 7.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lemonade-sdk might be problematic. Click here for more details.
- lemonade/cli.py +2 -0
- lemonade/tools/accuracy.py +335 -0
- lemonade/tools/server/instructions.py +294 -0
- lemonade/tools/server/llamacpp.py +315 -0
- lemonade/tools/server/port_utils.py +57 -0
- lemonade/tools/server/pydantic_models.py +83 -0
- lemonade/tools/server/serve.py +225 -167
- lemonade/tools/server/static/styles.css +313 -0
- lemonade/tools/server/thread_utils.py +87 -0
- lemonade/tools/server/tool_calls.py +50 -43
- lemonade/version.py +1 -1
- {lemonade_sdk-7.0.0.dist-info → lemonade_sdk-7.0.2.dist-info}/METADATA +4 -7
- {lemonade_sdk-7.0.0.dist-info → lemonade_sdk-7.0.2.dist-info}/RECORD +21 -14
- {lemonade_sdk-7.0.0.dist-info → lemonade_sdk-7.0.2.dist-info}/WHEEL +1 -1
- lemonade_server/cli.py +4 -2
- lemonade_server/model_manager.py +34 -17
- lemonade_server/server_models.json +52 -3
- {lemonade_sdk-7.0.0.dist-info → lemonade_sdk-7.0.2.dist-info}/entry_points.txt +0 -0
- {lemonade_sdk-7.0.0.dist-info → lemonade_sdk-7.0.2.dist-info}/licenses/LICENSE +0 -0
- {lemonade_sdk-7.0.0.dist-info → lemonade_sdk-7.0.2.dist-info}/licenses/NOTICE.md +0 -0
- {lemonade_sdk-7.0.0.dist-info → lemonade_sdk-7.0.2.dist-info}/top_level.txt +0 -0
lemonade/cli.py
CHANGED
|
@@ -19,6 +19,7 @@ import lemonade.cache as cache
|
|
|
19
19
|
from lemonade.tools.mmlu import AccuracyMMLU
|
|
20
20
|
from lemonade.tools.humaneval import AccuracyHumaneval
|
|
21
21
|
from lemonade.tools.perplexity import AccuracyPerplexity
|
|
22
|
+
from lemonade.tools.accuracy import LMEvalHarness
|
|
22
23
|
from lemonade.tools.prompt import LLMPrompt
|
|
23
24
|
from lemonade.tools.quark.quark_load import QuarkLoad
|
|
24
25
|
from lemonade.tools.quark.quark_quantize import QuarkQuantize
|
|
@@ -36,6 +37,7 @@ def main():
|
|
|
36
37
|
AccuracyMMLU,
|
|
37
38
|
AccuracyHumaneval,
|
|
38
39
|
AccuracyPerplexity,
|
|
40
|
+
LMEvalHarness,
|
|
39
41
|
LLMPrompt,
|
|
40
42
|
HuggingfaceBench,
|
|
41
43
|
OgaBench,
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import socket
|
|
5
|
+
import subprocess
|
|
6
|
+
import sys
|
|
7
|
+
import time
|
|
8
|
+
from typing import Optional
|
|
9
|
+
|
|
10
|
+
import requests
|
|
11
|
+
|
|
12
|
+
from lemonade.state import State
|
|
13
|
+
from lemonade.tools import Tool
|
|
14
|
+
import lemonade.common.printing as printing
|
|
15
|
+
import lemonade.common.build as build
|
|
16
|
+
|
|
17
|
+
from lemonade.tools.server.thread_utils import ServerRunner
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def is_port_in_use(port, host="localhost"):
    """
    Report whether a TCP connection to ``host:port`` succeeds.

    An IPv4 stream socket attempts to connect; ``connect_ex`` returns 0
    on success, which is taken to mean the port is in use.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        status = probe.connect_ex((host, port))
    finally:
        # Always release the probe socket, even if the connect call raises
        probe.close()
    return status == 0
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class LMEvalHarness(Tool):
    """
    Tool for evaluating LLMs using lm-eval-harness on industry standard benchmarks
    like MMLU, GSM8k, and more. See docs/lemonade/lm_eval.md for more details.

    The tool serves the model already loaded in ``state`` through a
    ``ServerRunner``, invokes the ``lm_eval`` command line against that server,
    and records the numeric metrics from the results file as state stats.
    """

    unique_name = "lm-eval-harness"

    def __init__(self):
        super().__init__(
            monitor_message="Evaluate model accuracy using EleutherAI's lm-eval-harness"
        )
        # Names of the stats saved by _process_results
        self.status_stats = []
        # ServerRunner for the current run(); None when no run is active
        self.server_runner = None

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        """Build the command line parser for this tool."""
        parser = __class__.helpful_parser(
            short_description="Evaluate model using lm-eval-harness",
            add_help=add_help,
        )

        parser.add_argument(
            "--task",
            type=str,
            required=True,
            help="Task(s) to evaluate on (e.g., gsm8k, mmlu)",
        )

        parser.add_argument(
            "--server-port", type=int, default=8000, help="Port to use for the server"
        )

        parser.add_argument(
            "--num-fewshot",
            type=int,
            default=0,
            help="Number of examples in few-shot prompts",
        )

        parser.add_argument(
            "--limit",
            type=int,
            default=None,
            help="Limit the number of examples per task",
        )

        parser.add_argument(
            "--log-samples",
            action="store_true",
            help="Log samples for each task to log file",
        )

        parser.add_argument(
            "--output-path",
            type=str,
            default=None,
            help="Path to save evaluation results",
        )

        return parser

    def _process_results(self, results_dir, state):
        """
        Process evaluation results and save to state stats.

        Looks inside the first sub-directory of ``results_dir`` for files named
        ``results_*.json``, loads the one that sorts last by name, and saves
        every numeric, non-alias metric (multiplied by 100, units "%") via
        ``state.save_stat``. Problems are logged rather than raised.

        Args:
            results_dir: Directory that lm_eval was told to write into.
            state: State object that receives the stats.
        """
        if not os.path.isdir(results_dir):
            printing.log_warning(f"Results directory not found at {results_dir}")
            return

        model_dirs = [
            d
            for d in os.listdir(results_dir)
            if os.path.isdir(os.path.join(results_dir, d))
        ]

        if not model_dirs:
            printing.log_warning(f"No model directories found in {results_dir}")
            return

        # NOTE(review): os.listdir order is arbitrary, so with several model
        # directories the choice here is not deterministic
        model_dir = os.path.join(results_dir, model_dirs[0])
        printing.log_info(f"Found model directory: {model_dir}")

        # Find the results JSON file with timestamp
        results_files = [
            f
            for f in os.listdir(model_dir)
            if f.startswith("results_") and f.endswith(".json")
        ]

        if not results_files:
            printing.log_warning(f"No results files found in {model_dir}")
            return

        # Sort by timestamp (embedded in the file name), newest first
        results_files.sort(reverse=True)
        results_file_path = os.path.join(model_dir, results_files[0])
        printing.log_info(f"Processing results from {results_file_path}")

        # Read and process results
        try:
            with open(results_file_path, "r", encoding="utf-8") as f:
                results = json.load(f)

            # A missing "results" section is treated as "no tasks" so the
            # averaging step below cannot raise KeyError
            task_results = results.get("results", {})

            # Extract and display metrics
            for task_name, metrics in task_results.items():
                printing.log_info(f"Results for {task_name}:")

                for metric, value in metrics.items():
                    if isinstance(value, (int, float)) and not metric.startswith(
                        "alias"
                    ):
                        # Format metric name for stats
                        clean_metric = metric.replace(",", "_")
                        stat_name = f"lm_eval_{task_name}_{clean_metric}"

                        # Save to state stats as percentage
                        state.save_stat(stat_name, float(value) * 100)
                        state.save_stat(f"{stat_name}_units", "%")
                        self.status_stats.append(stat_name)

                        printing.log_info(
                            f"  {metric}: {value:.4f} ({value*100:.2f}%)"
                        )

            # Save summary metrics if available
            avg_metrics = {}
            if "higher_is_better" in results:
                # Only metrics listed under "higher_is_better" are averaged
                for metric_type in results["higher_is_better"].values():
                    for metric in metric_type:
                        avg_metrics.setdefault(metric, [])

                for task_metrics in task_results.values():
                    for metric, value in task_metrics.items():
                        if isinstance(value, (int, float)) and not metric.startswith(
                            "alias"
                        ):
                            # Result keys look like "<metric>,<suffix>"; match
                            # on the part before the first comma
                            base_metric = metric.split(",")[0]
                            if base_metric in avg_metrics:
                                avg_metrics[base_metric].append(value)

            # Calculate and save averages
            for metric, values in avg_metrics.items():
                if values:
                    avg_value = sum(values) / len(values)
                    stat_name = f"lm_eval_average_{metric}"
                    state.save_stat(stat_name, float(avg_value) * 100)
                    state.save_stat(f"{stat_name}_units", "%")
                    self.status_stats.append(stat_name)
                    printing.log_info(
                        f"Average {metric}: {avg_value:.4f} ({avg_value*100:.2f}%)"
                    )

        except (IOError, json.JSONDecodeError) as e:
            printing.log_error(f"Error processing results: {e}")

    @staticmethod
    def _wait_for_server(server_url, max_retries=30, retry_delay=1):
        """
        Poll the server health endpoint until it answers with HTTP 200.

        Sleeps ``retry_delay`` seconds between attempts, whether the attempt
        failed with a request exception or with a non-200 status.

        Raises:
            RuntimeError: If no attempt out of ``max_retries`` got a 200.
        """
        printing.log_info(f"Checking if server is available at {server_url}...")
        for attempt in range(1, max_retries + 1):
            try:
                response = requests.get(f"{server_url}/api/v0/health", timeout=2)
                if response.status_code == 200:
                    printing.log_info(f"Server is ready after {attempt} attempts")
                    return
            except requests.exceptions.RequestException:
                # Server not accepting connections yet; retry below
                pass

            if attempt < max_retries:
                time.sleep(retry_delay)

        printing.log_error(f"Server did not start after {max_retries} attempts")
        raise RuntimeError("Failed to start the server")

    def run(
        self,
        state: State,
        task: str,
        server_port: int = 8000,
        server_host: str = "localhost",
        num_fewshot: int = 0,
        limit: Optional[int] = None,
        log_samples: bool = False,
        output_path: Optional[str] = None,
    ) -> State:
        """
        Serve the loaded model and evaluate it with the ``lm_eval`` CLI.

        Args:
            state: State holding the loaded ``model`` and ``tokenizer``.
            task: Task name(s) passed to ``lm_eval --tasks``.
            server_port: Port for the temporary server.
            server_host: Host for the temporary server.
            num_fewshot: Value passed to ``lm_eval --num_fewshot``.
            limit: Optional value passed to ``lm_eval --limit``.
            log_samples: When True, pass ``--log_samples`` to ``lm_eval``.
            output_path: Directory for results; defaults to an
                ``lm_eval_results`` folder inside the build output directory.

        Returns:
            The same ``state``, with metrics saved as stats when available.

        Raises:
            ValueError: If the model or tokenizer is missing from ``state``.
            RuntimeError: If the port is already in use or the server does not
                become healthy. Failures of the ``lm_eval`` process itself are
                logged, not raised.
        """

        model = state.model
        tokenizer = state.tokenizer

        if model is None or tokenizer is None:
            raise ValueError(
                "Model and tokenizer must be loaded in state before running lm-eval-harness"
            )

        # Set up output path
        if output_path is None:
            output_path = os.path.join(
                build.output_dir(state.cache_dir, state.build_name), "lm_eval_results"
            )

        os.makedirs(output_path, exist_ok=True)

        # Check if port is already in use
        if is_port_in_use(server_port, server_host):
            error_msg = (
                f"Port {server_port} is already in use. "
                "Please close all applications using this port and try again."
            )
            printing.log_error(error_msg)
            raise RuntimeError(error_msg)

        # Retroactively determine recipe based on model type to select correct iterator
        # The model is already loaded in server, so we only need recipe for iterator selection
        checkpoint = getattr(state, "checkpoint", "unknown")
        if "OrtGenaiModel" in str(type(model)):
            recipe = "oga-"
        else:
            recipe = "unknown"

        server_url = f"http://{server_host}:{server_port}"

        # lm_eval writes into this location and _process_results reads it back
        results_dir = os.path.join(output_path, f"{task}_results")

        # Build lm-eval-harness command
        model_args = ",".join(
            [
                f"model={checkpoint}",
                f"base_url={server_url}/api/v0/completions",
                "num_concurrent=1",
                "max_retries=5",
                "retry_timeout=10",
                "tokenized_requests=False",
            ]
        )
        cmd = [
            "lm_eval",
            "--model",
            "local-completions",
            "--tasks",
            task,
            "--model_args",
            model_args,
            "--num_fewshot",
            str(num_fewshot),
            "--output_path",
            results_dir,
        ]

        if limit is not None:
            cmd.extend(["--limit", str(limit)])

        if log_samples:
            cmd.extend(["--log_samples"])

        # Start the server thread
        self.server_runner = ServerRunner(
            model=model,
            tokenizer=tokenizer,
            checkpoint=checkpoint,
            recipe=recipe,
            host=server_host,
            port=server_port,
        )

        # Everything after the runner exists sits under one try/finally so the
        # server is shut down even when startup or the health check fails
        try:
            self.server_runner.start()

            # Wait for server to start and be responsive
            printing.log_info("Waiting for server initialization...")
            self._wait_for_server(server_url)

            printing.log_info(f"Running lm-eval-harness on {task}...")

            try:
                # On Windows, set UTF-8 mode to handle Unicode output, and decode
                # the captured output with the same codec the child was told to
                # use; elsewhere keep the locale default
                env = os.environ.copy()
                child_encoding = None
                if sys.platform == "win32":
                    env["PYTHONIOENCODING"] = "utf-8"
                    child_encoding = "utf-8"

                # Execute lm-eval-harness command
                result = subprocess.run(
                    cmd,
                    check=True,
                    text=True,
                    capture_output=True,
                    env=env,
                    encoding=child_encoding,
                    errors="replace",
                )

                # Log relevant output and skip any parts that might cause encoding issues
                try:
                    printing.log_info(result.stdout)
                except UnicodeEncodeError:
                    printing.log_info(
                        "Results obtained successfully but couldn't display due to encoding issues"
                    )

                # Process results from the correct location
                self._process_results(results_dir, state)

            except subprocess.CalledProcessError as e:
                printing.log_error(f"Error running lm-eval-harness: {e}")
                printing.log_error(f"stderr: {e.stderr}")
            except (IOError, ValueError, requests.RequestException) as e:
                printing.log_error(f"Error: {e}")
        finally:
            # Shut down server
            if self.server_runner and self.server_runner.is_alive():
                printing.log_info("Shutting down server runner...")
                self.server_runner.shutdown()

            # Make sure we don't have any lingering references to state's model/tokenizer
            # that could prevent garbage collection
            self.server_runner = None

        return state
|
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
import json
|
|
3
|
+
from fastapi.responses import HTMLResponse
|
|
4
|
+
from lemonade_server.model_manager import ModelManager
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_instructions_html(port=8000):
    """
    Show instructions on how to use the server.

    Builds the server landing page: a chat tab that streams replies from
    ``/api/v1/chat/completions`` and a model management tab that lists
    installed models (``/api/v1/models``) and installs suggested ones
    (``/api/v1/pull``).

    Args:
        port: Port number exposed to the page as ``window.SERVER_PORT``;
            the page's JavaScript uses it to build request URLs.

    Returns:
        An ``HTMLResponse`` whose content is the rendered page.
    """
    # Load server models from JSON
    server_models_path = (
        Path(__file__).parent.parent.parent.parent
        / "lemonade_server"
        / "server_models.json"
    )
    with open(server_models_path, "r", encoding="utf-8") as f:
        server_models = json.load(f)

    # Use shared filter function from model_manager.py
    filtered_models = ModelManager().filter_models_by_backend(server_models)

    # Pass filtered server_models to JS
    server_models_js = (
        f"<script>window.SERVER_MODELS = {json.dumps(filtered_models)};</script>"
    )

    # New lemon-themed HTML structure
    # NOTE: this is an f-string, so literal JS/CSS braces are doubled ({{ }})
    # and JS template placeholders are written as ${{name}}
    styled_html = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Lemonade Server</title>
        <link rel="icon" href="data:,">
        <link rel="stylesheet" href="/static/styles.css">
        <script>
        window.SERVER_PORT = {port};
        </script>
        {server_models_js}
    </head>
    <body>
        <nav class="navbar">
            <a href="https://github.com/lemonade-sdk/lemonade">GitHub</a>
            <a href="https://lemonade-server.ai/docs/">Docs</a>
            <a href="https://lemonade-server.ai/docs/server/server_models/">Models</a>
            <a href="https://lemonade-server.ai/docs/server/apps/">Featured Apps</a>
        </nav>
        <main class="main">
            <div class="title">🍋 Lemonade Server</div>
            <div class="tab-container">
                <div class="tabs">
                    <button class="tab active" id="tab-chat" onclick="showTab('chat')">LLM Chat</button>
                    <button class="tab" id="tab-models" onclick="showTab('models')">Model Management</button>
                </div>
                <div class="tab-content active" id="content-chat">
                    <div class="chat-container">
                        <div class="chat-history" id="chat-history"></div>
                        <div class="chat-input-row">
                            <select id="model-select"></select>
                            <input type="text" id="chat-input" placeholder="Type your message..." />
                            <button id="send-btn">Send</button>
                        </div>
                    </div>
                </div>
                <div class="tab-content" id="content-models">
                    <div class="model-mgmt-container">
                        <div class="model-mgmt-pane">
                            <h3>Installed Models</h3>
                            <table class="model-table" id="installed-models-table">
                                <colgroup><col style="width:100%"></colgroup>
                                <tbody id="installed-models-tbody"></tbody>
                            </table>
                        </div>
                        <div class="model-mgmt-pane">
                            <h3>Suggested Models</h3>
                            <table class="model-table" id="suggested-models-table">
                                <tbody id="suggested-models-tbody"></tbody>
                            </table>
                        </div>
                    </div>
                </div>
            </div>
        </main>
        <footer class="site-footer">
            <div class="dad-joke">When life gives you LLMs, make an LLM aide.</div>
            <div class="copyright">Copyright 2025 AMD</div>
        </footer>
        <script src="https://cdn.jsdelivr.net/npm/openai@4.21.0/dist/openai.min.js"></script>
        <script>
        // Tab switching logic
        function showTab(tab) {{
            document.getElementById('tab-chat').classList.remove('active');
            document.getElementById('tab-models').classList.remove('active');
            document.getElementById('content-chat').classList.remove('active');
            document.getElementById('content-models').classList.remove('active');
            if (tab === 'chat') {{
                document.getElementById('tab-chat').classList.add('active');
                document.getElementById('content-chat').classList.add('active');
            }} else {{
                document.getElementById('tab-models').classList.add('active');
                document.getElementById('content-models').classList.add('active');
            }}
        }}

        // Helper to get server base URL
        function getServerBaseUrl() {{
            const port = window.SERVER_PORT || 8000;
            return `http://localhost:${{port}}`;
        }}

        // Populate model dropdown from /api/v1/models endpoint
        async function loadModels() {{
            try {{
                const resp = await fetch(getServerBaseUrl() + '/api/v1/models');
                const data = await resp.json();
                const select = document.getElementById('model-select');
                select.innerHTML = '';
                if (!data.data || !Array.isArray(data.data)) {{
                    select.innerHTML = '<option>No models found (malformed response)</option>';
                    return;
                }}
                if (data.data.length === 0) {{
                    select.innerHTML = '<option>No models available</option>';
                    return;
                }}
                let defaultIndex = 0;
                data.data.forEach(function(model, index) {{
                    const modelId = model.id || model.name || model;
                    const opt = document.createElement('option');
                    opt.value = modelId;
                    opt.textContent = modelId;
                    if (modelId === 'Llama-3.2-1B-Instruct-Hybrid') {{
                        defaultIndex = index;
                    }}
                    select.appendChild(opt);
                }});
                select.selectedIndex = defaultIndex;
            }} catch (e) {{
                const select = document.getElementById('model-select');
                // Build the option with textContent so the error text is never parsed as HTML
                select.innerHTML = '';
                const errOpt = document.createElement('option');
                errOpt.textContent = 'Error loading models: ' + e.message;
                select.appendChild(errOpt);
                console.error('Error loading models:', e);
            }}
        }}
        loadModels();

        // Model Management Tab Logic
        async function refreshModelMgmtUI() {{
            // Get installed models from /api/v1/models
            let installed = [];
            try {{
                const resp = await fetch(getServerBaseUrl() + '/api/v1/models');
                const data = await resp.json();
                if (data.data && Array.isArray(data.data)) {{
                    installed = data.data.map(m => m.id || m.name || m);
                }}
            }} catch (e) {{}}
            // All models from server_models.json (window.SERVER_MODELS)
            const allModels = window.SERVER_MODELS || {{}};
            // Filter suggested models not installed
            const suggested = Object.keys(allModels).filter(
                k => allModels[k].suggested && !installed.includes(k)
            );
            // Render installed models as a table (two columns, second is invisible)
            const installedTbody = document.getElementById('installed-models-tbody');
            installedTbody.innerHTML = '';
            installed.forEach(function(mid) {{
                var tr = document.createElement('tr');
                var tdName = document.createElement('td');
                tdName.textContent = mid;
                var tdEmpty = document.createElement('td');
                tdEmpty.style.width = '0';
                tdEmpty.style.padding = '0';
                tdEmpty.style.border = 'none';
                tr.appendChild(tdName);
                tr.appendChild(tdEmpty);
                installedTbody.appendChild(tr);
            }});
            // Render suggested models as a table
            const suggestedTbody = document.getElementById('suggested-models-tbody');
            suggestedTbody.innerHTML = '';
            suggested.forEach(mid => {{
                const tr = document.createElement('tr');
                const tdName = document.createElement('td');
                tdName.textContent = mid;
                tdName.style.paddingRight = '1em';
                tdName.style.verticalAlign = 'middle';
                const tdBtn = document.createElement('td');
                tdBtn.style.width = '1%';
                tdBtn.style.verticalAlign = 'middle';
                const btn = document.createElement('button');
                btn.textContent = '+';
                btn.title = 'Install model';
                btn.onclick = async function() {{
                    btn.disabled = true;
                    btn.textContent = 'Installing...';
                    btn.classList.add('installing-btn');
                    try {{
                        await fetch(getServerBaseUrl() + '/api/v1/pull', {{
                            method: 'POST',
                            headers: {{ 'Content-Type': 'application/json' }},
                            body: JSON.stringify({{ model_name: mid }})
                        }});
                        await refreshModelMgmtUI();
                        await loadModels(); // update chat dropdown too
                    }} catch (e) {{
                        btn.textContent = 'Error';
                    }}
                }};
                tdBtn.appendChild(btn);
                tr.appendChild(tdName);
                tr.appendChild(tdBtn);
                suggestedTbody.appendChild(tr);
            }});
        }}
        // Initial load
        refreshModelMgmtUI();
        // Optionally, refresh when switching to the tab
        document.getElementById('tab-models').addEventListener('click', refreshModelMgmtUI);

        // Chat logic (streaming via fetch against the chat completions endpoint)
        const chatHistory = document.getElementById('chat-history');
        const chatInput = document.getElementById('chat-input');
        const sendBtn = document.getElementById('send-btn');
        const modelSelect = document.getElementById('model-select');
        let messages = [];

        function appendMessage(role, text) {{
            const div = document.createElement('div');
            div.className = 'chat-message ' + role;
            // Add a bubble for iMessage style
            const bubble = document.createElement('div');
            bubble.className = 'chat-bubble ' + role;
            // textContent, not innerHTML: chat text must never be parsed as markup
            bubble.textContent = text;
            div.appendChild(bubble);
            chatHistory.appendChild(div);
            chatHistory.scrollTop = chatHistory.scrollHeight;
        }}

        async function sendMessage() {{
            const text = chatInput.value.trim();
            if (!text) return;
            appendMessage('user', text);
            messages.push({{ role: 'user', content: text }});
            chatInput.value = '';
            sendBtn.disabled = true;
            let llmText = '';
            appendMessage('llm', '...');
            const llmDiv = chatHistory.lastChild.querySelector('.chat-bubble.llm');
            try {{
                // Use the correct endpoint for chat completions
                const resp = await fetch(getServerBaseUrl() + '/api/v1/chat/completions', {{
                    method: 'POST',
                    headers: {{ 'Content-Type': 'application/json' }},
                    body: JSON.stringify({{
                        model: modelSelect.value,
                        messages: messages,
                        stream: true
                    }})
                }});
                if (!resp.ok) throw new Error('HTTP ' + resp.status);
                if (!resp.body) throw new Error('No stream');
                const reader = resp.body.getReader();
                const decoder = new TextDecoder();
                // Holds a trailing partial line until the rest of it arrives
                let pending = '';
                llmDiv.textContent = '';
                while (true) {{
                    const {{ done, value }} = await reader.read();
                    if (done) break;
                    // stream: true keeps multi-byte characters split across reads intact
                    pending += decoder.decode(value, {{ stream: true }});
                    const lines = pending.split('\\n');
                    pending = lines.pop();
                    for (const line of lines) {{
                        const trimmed = line.trim();
                        if (!trimmed.startsWith('data:')) continue;
                        const payload = trimmed.slice(5).trim();
                        if (!payload || payload === '[DONE]') continue;
                        let piece;
                        try {{
                            // Assumes OpenAI-style chunks: choices[0].delta.content
                            const parsed = JSON.parse(payload);
                            const choice = parsed.choices && parsed.choices[0];
                            piece = choice && choice.delta && choice.delta.content;
                        }} catch (parseErr) {{
                            // Skip events that are not valid JSON
                            continue;
                        }}
                        if (piece) {{
                            llmText += piece;
                            llmDiv.textContent = llmText;
                        }}
                    }}
                }}
                messages.push({{ role: 'assistant', content: llmText }});
            }} catch (e) {{
                llmDiv.textContent = '[Error: ' + e.message + ']';
            }}
            sendBtn.disabled = false;
        }}
        sendBtn.onclick = sendMessage;
        chatInput.addEventListener('keydown', function(e) {{
            if (e.key === 'Enter') sendMessage();
        }});
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=styled_html)
|