lemonade-sdk 7.0.0 (lemonade_sdk-7.0.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lemonade-sdk might be problematic.

Files changed (61)
  1. lemonade/__init__.py +5 -0
  2. lemonade/api.py +125 -0
  3. lemonade/cache.py +85 -0
  4. lemonade/cli.py +135 -0
  5. lemonade/common/__init__.py +0 -0
  6. lemonade/common/analyze_model.py +26 -0
  7. lemonade/common/build.py +223 -0
  8. lemonade/common/cli_helpers.py +139 -0
  9. lemonade/common/exceptions.py +98 -0
  10. lemonade/common/filesystem.py +368 -0
  11. lemonade/common/labels.py +61 -0
  12. lemonade/common/onnx_helpers.py +176 -0
  13. lemonade/common/plugins.py +10 -0
  14. lemonade/common/printing.py +110 -0
  15. lemonade/common/status.py +490 -0
  16. lemonade/common/system_info.py +390 -0
  17. lemonade/common/tensor_helpers.py +83 -0
  18. lemonade/common/test_helpers.py +28 -0
  19. lemonade/profilers/__init__.py +1 -0
  20. lemonade/profilers/memory_tracker.py +257 -0
  21. lemonade/profilers/profiler.py +55 -0
  22. lemonade/sequence.py +363 -0
  23. lemonade/state.py +159 -0
  24. lemonade/tools/__init__.py +1 -0
  25. lemonade/tools/adapter.py +104 -0
  26. lemonade/tools/bench.py +284 -0
  27. lemonade/tools/huggingface_bench.py +267 -0
  28. lemonade/tools/huggingface_load.py +520 -0
  29. lemonade/tools/humaneval.py +258 -0
  30. lemonade/tools/llamacpp.py +261 -0
  31. lemonade/tools/llamacpp_bench.py +154 -0
  32. lemonade/tools/management_tools.py +273 -0
  33. lemonade/tools/mmlu.py +327 -0
  34. lemonade/tools/ort_genai/__init__.py +0 -0
  35. lemonade/tools/ort_genai/oga.py +1129 -0
  36. lemonade/tools/ort_genai/oga_bench.py +142 -0
  37. lemonade/tools/perplexity.py +146 -0
  38. lemonade/tools/prompt.py +228 -0
  39. lemonade/tools/quark/__init__.py +0 -0
  40. lemonade/tools/quark/quark_load.py +172 -0
  41. lemonade/tools/quark/quark_quantize.py +439 -0
  42. lemonade/tools/report/__init__.py +0 -0
  43. lemonade/tools/report/llm_report.py +203 -0
  44. lemonade/tools/report/table.py +739 -0
  45. lemonade/tools/server/__init__.py +0 -0
  46. lemonade/tools/server/serve.py +1354 -0
  47. lemonade/tools/server/tool_calls.py +146 -0
  48. lemonade/tools/tool.py +374 -0
  49. lemonade/version.py +1 -0
  50. lemonade_install/__init__.py +1 -0
  51. lemonade_install/install.py +774 -0
  52. lemonade_sdk-7.0.0.dist-info/METADATA +116 -0
  53. lemonade_sdk-7.0.0.dist-info/RECORD +61 -0
  54. lemonade_sdk-7.0.0.dist-info/WHEEL +5 -0
  55. lemonade_sdk-7.0.0.dist-info/entry_points.txt +4 -0
  56. lemonade_sdk-7.0.0.dist-info/licenses/LICENSE +201 -0
  57. lemonade_sdk-7.0.0.dist-info/licenses/NOTICE.md +21 -0
  58. lemonade_sdk-7.0.0.dist-info/top_level.txt +3 -0
  59. lemonade_server/cli.py +260 -0
  60. lemonade_server/model_manager.py +98 -0
  61. lemonade_server/server_models.json +142 -0
@@ -0,0 +1,1354 @@
1
+ import argparse
2
+ import asyncio
3
+ import statistics
4
+ import time
5
+ from threading import Thread, Event
6
+ import logging
7
+ import traceback
8
+ from typing import Optional, Union
9
+ import json
10
+
11
+ from fastapi import FastAPI, HTTPException, status, Request
12
+ from fastapi.responses import StreamingResponse, HTMLResponse
13
+ from fastapi.middleware.cors import CORSMiddleware
14
+ from pydantic import BaseModel
15
+ import uvicorn
16
+ from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
17
+ from tabulate import tabulate
18
+
19
+ from openai.types.completion import Completion, CompletionChoice
20
+ from openai.types.chat import ChatCompletion, ChatCompletionChunk
21
+ from openai.types.chat import ChatCompletionMessage
22
+ from openai.types.chat.chat_completion_message_tool_call import (
23
+ ChatCompletionMessageToolCall,
24
+ Function,
25
+ )
26
+ from openai.types.chat.chat_completion import Choice
27
+ from openai.types.chat.chat_completion_chunk import ChoiceDelta
28
+ from openai.types.completion_choice import Logprobs
29
+ from openai.types.model import Model
30
+ from openai.types.responses import (
31
+ Response,
32
+ ResponseOutputMessage,
33
+ ResponseOutputText,
34
+ ResponseCreatedEvent,
35
+ ResponseTextDeltaEvent,
36
+ ResponseCompletedEvent,
37
+ )
38
+
39
+ import lemonade.api as lemonade_api
40
+ from lemonade_server.model_manager import ModelManager
41
+ from lemonade.tools.management_tools import ManagementTool
42
+ from lemonade.tools.server.tool_calls import extract_tool_calls
43
+
44
+ # Set to a high number to allow for interesting experiences in real apps
45
+ # Tests should use the max_new_tokens argument to set a lower value
46
+ DEFAULT_MAX_NEW_TOKENS = 1500
47
+
48
+ DEFAULT_PORT = 8000
49
+ DEFAULT_LOG_LEVEL = "info"
50
+
51
+
52
+ class ServerModel(Model):
53
+ """
54
+ An extension of OpenAI's Model class that adds
55
+ checkpoint and recipe attributes.
56
+ """
57
+
58
+ checkpoint: str
59
+ recipe: str
60
+
61
+
62
+ class GeneratorThread(Thread):
63
+ """
64
+ Thread class designed for use with streaming generation within
65
+ an LLM server. It needs access to the streamer in order
66
+ to help the completions APIs escape the "for text in streamer" loop.
67
+ It also provides exception handling that works nicely with HTTP
68
+ servers by providing the stack trace and making the exception
69
+ information available to the main thread.
70
+ """
71
+
72
+ def __init__(self, streamer, *args, **kwargs):
73
+ super().__init__(*args, **kwargs)
74
+ self.exception = None
75
+ self.streamer = streamer
76
+
77
+ def run(self):
78
+ try:
79
+ if self._target:
80
+ self._target(*self._args, **self._kwargs)
81
+ except Exception as e: # pylint: disable=broad-except
82
+ self.exception = e
83
+ logging.error(f"Exception raised in generate thread: {e}")
84
+ traceback.print_exc()
85
+ self.streamer.done()
86
+
87
+
88
+ class StopOnEvent(StoppingCriteria):
89
+ """
90
+ Custom stopping criteria that halts text generation when a specified event is set.
91
+
92
+ This allows for external control of generation, such as stopping a generation
93
+ before it reaches the maximum token limit.
94
+ """
95
+
96
+ def __init__(self, stop_event: Event):
97
+ super().__init__()
98
+ self.stop_event = stop_event
99
+
100
+ def __call__(self, input_ids, scores, **kwargs):
101
+ return self.stop_event.is_set()
102
+
103
+
104
+ class PullConfig(BaseModel):
105
+ """
106
+ Configuration for installing a supported LLM.
107
+ """
108
+
109
+ model_name: str
110
+
111
+
112
+ class LoadConfig(BaseModel):
113
+ """
114
+ Configuration for loading a language model.
115
+
116
+ Specifies the model checkpoint, generation parameters,
117
+ and hardware/framework configuration (recipe) for model loading.
118
+ """
119
+
120
+ model_name: Optional[str] = None
121
+ checkpoint: Optional[str] = None
122
+ max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
123
+ recipe: Optional[str] = None
124
+ # Indicates the maximum prompt length allowed for that specific
125
+ # checkpoint + recipe combination
126
+ max_prompt_length: Optional[int] = None
127
+ # Indicates whether the model is a reasoning model, like DeepSeek
128
+ reasoning: Optional[bool] = False
129
+
130
+
131
+ class CompletionRequest(BaseModel):
132
+ """
133
+ Request model for text completion API endpoint.
134
+
135
+ Contains a prompt, a model identifier, and a streaming
136
+ flag to control response delivery.
137
+ """
138
+
139
+ prompt: str
140
+ model: str
141
+ echo: bool = False
142
+ stream: bool = False
143
+ logprobs: int | None = None
144
+ stop: list[str] | str | None = None
145
+ temperature: float | None = None
146
+ max_tokens: int | None = None
147
+
148
+
149
+ class ChatCompletionRequest(BaseModel):
150
+ """
151
+ Request model for chat completion API endpoint.
152
+
153
+ Contains a list of chat messages, a model identifier,
154
+ and a streaming flag to control response delivery.
155
+ """
156
+
157
+ messages: list[dict]
158
+ model: str
159
+ stream: bool = False
160
+ logprobs: int | None = None
161
+ stop: list[str] | str | None = None
162
+ temperature: float | None = None
163
+ tools: list[dict] | None = None
164
+ max_tokens: int | None = None
165
+ max_completion_tokens: int | None = None
166
+
167
+
168
+ class ResponsesRequest(BaseModel):
169
+ """
170
+ Request model for responses API endpoint.
171
+ """
172
+
173
+ input: list[dict] | str
174
+ model: str
175
+ max_output_tokens: int | None = None
176
+ temperature: float | None = None
177
+ stream: bool = False
178
+
179
+
180
+ class Server(ManagementTool):
181
+ """
182
+ Open a web server that apps can use to communicate with the LLM.
183
+
184
+ The server exposes these endpoints:
185
+ - /api/v0/pull: install an LLM by its Lemonade Server Model Name.
186
+ - /api/v0/load: load a model checkpoint.
187
+ - /api/v0/unload: unload a model checkpoint.
188
+ - /api/v0/health: check whether a model is loaded and ready to serve.
189
+ - /api/v0/stats: performance statistics for the generation.
190
+ - /api/v0/halt: stop an in-progress generation from making more tokens.
191
+ - /api/v0/completions: completion responses using HTTP chunked transfer encoding.
192
+ - /api/v0/chat/completions: chat completion responses using HTTP chunked transfer encoding.
193
+ - /api/v0/responses: responses API using HTTP chunked transfer encoding.
194
+ - /api/v0/models: list all available models.
195
+ """
196
+
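For orientation, here is a minimal client sketch against the custom endpoints listed in the docstring above, assuming the server is running locally on the default port 8000. The model name is a placeholder, not a real catalog entry.

```python
import requests

BASE = "http://localhost:8000/api/v0"

# Check whether a model is loaded and ready to serve
print(requests.get(f"{BASE}/health").json())

# Install a supported LLM by its Lemonade Server model name
# ("SOME-MODEL-NAME" is a placeholder; /api/v0/models lists real names)
requests.post(f"{BASE}/pull", json={"model_name": "SOME-MODEL-NAME"})

# Fetch telemetry for the most recent generation
print(requests.get(f"{BASE}/stats").json())

# Stop an in-progress generation
print(requests.get(f"{BASE}/halt").json())
```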
197
+ unique_name = "serve"
198
+
199
+ def __init__(self):
200
+ super().__init__()
201
+
202
+ # Initialize FastAPI app
203
+ self.app = FastAPI()
204
+
205
+ # Add CORS middleware
206
+ self.app.add_middleware(
207
+ CORSMiddleware,
208
+ allow_origins=["*"], # Allows all origins
209
+ allow_credentials=True,
210
+ allow_methods=["*"], # Allows all methods
211
+ allow_headers=["*"], # Allows all headers
212
+ )
213
+
214
+ # Set up custom routes
215
+ self.app.post("/api/v0/pull")(self.pull)
216
+ self.app.post("/api/v0/load")(self.load_llm)
217
+ self.app.post("/api/v0/unload")(self.unload_llm)
218
+ self.app.get("/api/v0/health")(self.health)
219
+ self.app.get("/api/v0/halt")(self.halt_generation)
220
+ self.app.get("/api/v0/stats")(self.send_stats)
221
+ self.app.post("/api/v0/completions")(self.completions)
222
+ self.app.post("/api/v0/responses")(self.responses)
223
+
224
+ # Set up OpenAI-compatible routes
225
+ self.app.post("/api/v0/chat/completions")(self.chat_completions)
226
+ self.app.post("/api/v0/completions")(self.completions)
227
+ self.app.get("/api/v0/models")(self.models)
228
+
229
+ # Set up instructions
230
+ self.app.get("/")(self.instructions)
231
+
232
+ # Performance stats that are set during generation and can be
233
+ # fetched in /stats
234
+ self.time_to_first_token = None
235
+ self.tokens_per_second = None
236
+ self.input_tokens = None
237
+ self.output_tokens = None
238
+ self.decode_token_times = None
239
+
240
+ # Input truncation settings
241
+ self.truncate_inputs = False
242
+
243
+ # Store debug logging state
244
+ self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
245
+
246
+ # Flag that tells the LLM to stop generating text and end the response
247
+ self.stop_event = Event()
248
+
249
+ self.llm_loaded: LoadConfig = None
250
+ self.tokenizer = None
251
+
252
+ # Placeholders for model and configs
253
+ self.model = None
254
+
255
+ # Initialize semaphore for tracking active generations
256
+ self.max_concurrent_generations = 1
257
+ self._generate_semaphore = asyncio.Semaphore(self.max_concurrent_generations)
258
+
259
+ # Dictionary of installed LLMs, mapping model name to information about that model
260
+ # Does not include non-installed models
261
+ self.local_models = ModelManager().downloaded_models_enabled
262
+
263
+ # Add lock for load/unload operations
264
+ self._load_lock = asyncio.Lock()
265
+
266
+ @staticmethod
267
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
268
+ parser = __class__.helpful_parser(
269
+ short_description="Launch an industry-standard LLM server",
270
+ add_help=add_help,
271
+ )
272
+
273
+ parser.add_argument(
274
+ "--port",
275
+ required=False,
276
+ type=int,
277
+ default=DEFAULT_PORT,
278
+ help=f"Port number to run the server on (default: {DEFAULT_PORT})",
279
+ )
280
+ parser.add_argument(
281
+ "--log-level",
282
+ required=False,
283
+ type=str,
284
+ default=DEFAULT_LOG_LEVEL,
285
+ choices=["critical", "error", "warning", "info", "debug", "trace"],
286
+ help=f"Logging level (default: {DEFAULT_LOG_LEVEL})",
287
+ )
288
+
289
+ return parser
290
+
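The same tool can also be started programmatically rather than through this argument parser. A minimal sketch, assuming the import path shipped in this wheel; the port and log level simply mirror the CLI flags above:

```python
from lemonade.tools.server.serve import Server

# Launch the server directly; this call blocks until uvicorn exits.
Server().run(port=8000, log_level="debug")
```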
291
+ def run(
292
+ self,
293
+ # ManagementTool has a required cache_dir arg, but
294
+ # we always use the default cache directory
295
+ _=None,
296
+ port: int = DEFAULT_PORT,
297
+ log_level: str = DEFAULT_LOG_LEVEL,
298
+ truncate_inputs: bool = False,
299
+ ):
300
+ # Store truncation settings
301
+ self.truncate_inputs = truncate_inputs
302
+
303
+ # Define TRACE level
304
+ logging.TRACE = 9 # Lower than DEBUG which is 10
305
+ logging.addLevelName(logging.TRACE, "TRACE")
306
+
307
+ # Add a convenience function at the module level
308
+ def trace(message, *args, **kwargs):
309
+ logging.log(logging.TRACE, message, *args, **kwargs)
310
+
311
+ logging.trace = trace
312
+
313
+ # Configure logging to match uvicorn's format
314
+ logging_level = getattr(logging, log_level.upper())
315
+ logging.basicConfig(
316
+ level=logging_level,
317
+ format="%(levelprefix)s %(message)s",
318
+ datefmt="%Y-%m-%d %H:%M:%S",
319
+ )
320
+
321
+ # Add uvicorn's log formatter
322
+ logging.root.handlers[0].formatter = uvicorn.logging.DefaultFormatter(
323
+ fmt="%(levelprefix)s %(message)s",
324
+ use_colors=True,
325
+ )
326
+
327
+ # Ensure the log level is properly set
328
+ logging.getLogger().setLevel(logging_level)
329
+
330
+ # Update debug logging state after setting log level
331
+ self.debug_logging_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
332
+
333
+ if self.debug_logging_enabled:
334
+ # Print the elapsed time for each request
335
+ self.setup_middleware_timer()
336
+
337
+ uvicorn.run(self.app, host="localhost", port=port, log_level=log_level)
338
+
339
+ async def _show_telemetry(self):
340
+ """
341
+ Show telemetry data in debug mode.
342
+ """
343
+ # Exit early if debug logging is disabled or no telemetry data is available
344
+ if not self.debug_logging_enabled or self.tokens_per_second is None:
345
+ return
346
+
347
+ # Prepare telemetry data (transposed format)
348
+ telemetry = [
349
+ ["Input tokens", self.input_tokens],
350
+ ["Output tokens", self.output_tokens],
351
+ ["TTFT (s)", f"{self.time_to_first_token:.2f}"],
352
+ ["TPS", f"{self.tokens_per_second:.2f}"],
353
+ ]
354
+
355
+ table = tabulate(
356
+ telemetry, headers=["Metric", "Value"], tablefmt="fancy_grid"
357
+ ).split("\n")
358
+
359
+ # Show telemetry in debug while complying with uvicorn's log indentation
360
+ logging.debug("\n ".join(table))
361
+
362
+ def instructions(self):
363
+ """
364
+ Show instructions on how to use the server.
365
+ """
366
+ html_content = """
367
+ <!DOCTYPE html>
368
+ <html>
369
+ <head>
370
+ <title>Lemonade Server</title>
371
+ <link rel="icon" href="data:,">
372
+ </head>
373
+ <body>
374
+ <h1>🍋 Welcome to Lemonade Server!</h1>
375
+ <p>
376
+ A standards-compliant server that provides REST APIs for LLM communication.
377
+ To get started, simply point your OpenAI-compatible application at the server's endpoint.
378
+ </p>
379
+ <div class="links">
380
+ <h3>Documentation:</h3>
381
+ <ul>
382
+ <li><a href="https://github.com/lemonade-sdk/lemonade/tree/main/docs/server/apps/README.md">Examples & Usage</a></li>
383
+ <li><a href="https://github.com/lemonade-sdk/lemonade/blob/main/docs/server/server_integration.md">Integration Guide</a></li>
384
+ <li><a href="https://github.com/lemonade-sdk/lemonade/blob/main/docs/server/server_spec.md">Server Specification</a></li>
385
+ </ul>
386
+ </div>
387
+ </body>
388
+ </html>
389
+ """
390
+ return HTMLResponse(content=html_content, status_code=200)
391
+
392
+ def initialize_load_config(
393
+ self, request: Union[ChatCompletionRequest, CompletionRequest]
394
+ ) -> LoadConfig:
395
+ """
396
+ Turn the Request object into a partially-complete LoadConfig.
397
+
398
+ The load_llm() method is responsible for filling in the rest of
399
+ LoadConfig's parameters.
400
+ """
401
+
402
+ # Get model config
403
+ if "/" in request.model:
404
+ # We know the model is a Hugging Face checkpoint if it contains a /
405
+ lc = LoadConfig(checkpoint=request.model)
406
+ else:
407
+ # The model should be a reference to a built-in model
408
+ lc = LoadConfig(model_name=request.model)
409
+
410
+ return lc
411
+
412
+ async def completions(self, completion_request: CompletionRequest):
413
+ """
414
+ Stream completion responses using HTTP chunked transfer encoding.
415
+ """
416
+
417
+ lc = self.initialize_load_config(completion_request)
418
+
419
+ # Load the model if it's different from the currently loaded one
420
+ await self.load_llm(lc, internal_call=True)
421
+
422
+ # Check if the model supports reasoning
423
+ reasoning_first_token = self.llm_loaded.reasoning
424
+
425
+ # If the model supports reasoning, we:
426
+ # 1. add a <think> tag to the model's context
427
+ # 2. ensure that the first token is a <think> token
428
+ text = completion_request.prompt
429
+ if reasoning_first_token:
430
+ text += "<think>"
431
+
432
+ # Prepare generation arguments
433
+ generation_args = {
434
+ "message": text,
435
+ "stop": completion_request.stop,
436
+ "temperature": completion_request.temperature,
437
+ "max_new_tokens": completion_request.max_tokens,
438
+ }
439
+
440
+ if completion_request.stream:
441
+
442
+ if completion_request.logprobs:
443
+ logging.warning("logprobs is not supported for streaming completion")
444
+ if completion_request.echo:
445
+ logging.warning(
446
+ "`Echo` parameter is not supported for streaming completions"
447
+ )
448
+
449
+ # Stream the response
450
+ async def generate():
451
+ # Declare that this is the same variable as in the enclosing scope
452
+ # This is necessary because the variable is modified
453
+ # in the inner function
454
+ nonlocal reasoning_first_token
455
+
456
+ async for token in self._generate_tokens(**generation_args):
457
+ choice = CompletionChoice(
458
+ text=("<think>" + token if reasoning_first_token else token),
459
+ index=0,
460
+ finish_reason="stop",
461
+ logprobs=None,
462
+ )
463
+
464
+ completion = Completion(
465
+ id="0",
466
+ choices=[choice],
467
+ model=self.llm_loaded.checkpoint,
468
+ object="text_completion",
469
+ created=int(time.time()),
470
+ )
471
+
472
+ # Format as SSE
473
+ reasoning_first_token = False
474
+ yield f"data: {completion.model_dump_json()}\n\n".encode("utf-8")
475
+
476
+ # Send the [DONE] marker
477
+ yield b"data: [DONE]\n\n"
478
+
479
+ return StreamingResponse(
480
+ generate(),
481
+ media_type="text/event-stream",
482
+ headers={
483
+ "Cache-Control": "no-cache",
484
+ "Connection": "keep-alive",
485
+ },
486
+ )
487
+
488
+ # If streaming is not requested, collect all generated tokens into a single response
489
+ else:
490
+ full_response = text if completion_request.echo else ""
491
+ async for token in self._generate_tokens(**generation_args):
492
+ full_response += token
493
+
494
+ # If logprobs are requested, create a logprobs object
495
+ logprobs = None
496
+ if completion_request.logprobs:
497
+
498
+ # Compute the logprobs
499
+ text_offset, token_logprobs, tokens, top_logprobs = (
500
+ self.model.compute_logprobs(
501
+ text=full_response,
502
+ tokenizer=self.tokenizer,
503
+ logprobs=completion_request.logprobs,
504
+ )
505
+ )
506
+ logprobs = Logprobs.model_construct(
507
+ text_offset=text_offset,
508
+ token_logprobs=token_logprobs,
509
+ tokens=tokens,
510
+ top_logprobs=top_logprobs,
511
+ )
512
+
513
+ choice = CompletionChoice(
514
+ text=full_response,
515
+ index=0,
516
+ finish_reason="stop",
517
+ logprobs=logprobs,
518
+ )
519
+
520
+ return Completion(
521
+ id="0",
522
+ choices=[choice],
523
+ model=self.llm_loaded.checkpoint,
524
+ object="text_completion",
525
+ created=int(time.time()),
526
+ )
527
+
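A sketch of consuming the streaming branch above over plain HTTP, parsing the server-sent events framing emitted by generate() (`data: ...` lines terminated by `data: [DONE]`); the model name is a placeholder.

```python
import json
import requests

resp = requests.post(
    "http://localhost:8000/api/v0/completions",
    json={"model": "SOME-MODEL-NAME", "prompt": "Once upon a time", "stream": True},
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue  # skip the blank separator between SSE events
    payload = line.decode("utf-8").removeprefix("data: ")
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)
    print(chunk["choices"][0]["text"], end="", flush=True)
```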
528
+ async def chat_completions(self, chat_completion_request: ChatCompletionRequest):
529
+ """
530
+ Stream chat completion responses using HTTP chunked transfer encoding.
531
+ """
532
+
533
+ if chat_completion_request.tools and chat_completion_request.stream:
534
+ logging.warning(
535
+ "tools are only supported on non-streaming chat completions"
536
+ )
537
+ if chat_completion_request.logprobs:
538
+ logging.warning("logprobs is not supported on chat completion")
539
+
540
+ lc = self.initialize_load_config(chat_completion_request)
541
+
542
+ # Load the model if it's different from the currently loaded one
543
+ await self.load_llm(lc, internal_call=True)
544
+
545
+ # Convert chat messages to text using the model's chat template
546
+ text = self.apply_chat_template(
547
+ chat_completion_request.messages,
548
+ tools=(
549
+ chat_completion_request.tools
550
+ if not chat_completion_request.stream
551
+ else None
552
+ ),
553
+ )
554
+
555
+ # If the model supports reasoning, we:
556
+ # 1. add a <think> tag to the model's context
557
+ # 2. ensure that the first token is a <think> token
558
+ reasoning_first_token = self.llm_loaded.reasoning
559
+
560
+ if reasoning_first_token:
561
+ text += "<think>"
562
+ # Set the max_new_tokens parameter
563
+ if (
564
+ chat_completion_request.max_completion_tokens
565
+ and chat_completion_request.max_tokens
566
+ ):
567
+ raise HTTPException(
568
+ status_code=status.HTTP_400_BAD_REQUEST,
569
+ detail=(
570
+ "Both max_tokens and max_completion_tokens were provided. "
571
+ "Please use only one of these parameters.",
572
+ ),
573
+ )
574
+ max_new_tokens = (
575
+ chat_completion_request.max_completion_tokens
576
+ if chat_completion_request.max_completion_tokens
577
+ else chat_completion_request.max_tokens
578
+ )
579
+
580
+ # Prepare generation arguments
581
+ generation_args = {
582
+ "message": text,
583
+ "stop": chat_completion_request.stop,
584
+ "temperature": chat_completion_request.temperature,
585
+ "max_new_tokens": max_new_tokens,
586
+ }
587
+
588
+ if chat_completion_request.stream:
589
+
590
+ # Stream the response
591
+ async def generate():
592
+ # Declare that this is the same variable as in the enclosing scope
593
+ # This is necessary because the variable is modified
594
+ # in the inner function
595
+ nonlocal reasoning_first_token
596
+
597
+ async for token in self._generate_tokens(**generation_args):
598
+
599
+ # Create a ChatCompletionChunk
600
+ chunk = ChatCompletionChunk.model_construct(
601
+ id="0",
602
+ object="chat.completion.chunk",
603
+ created=int(time.time()),
604
+ model=self.llm_loaded.checkpoint,
605
+ choices=[
606
+ Choice.model_construct(
607
+ index=0,
608
+ delta=ChoiceDelta(
609
+ content=(
610
+ "<think>" + token
611
+ if reasoning_first_token
612
+ else token
613
+ ),
614
+ function_call=None,
615
+ role="assistant",
616
+ tool_calls=None,
617
+ refusal=None,
618
+ ),
619
+ finish_reason=None,
620
+ logprobs=None,
621
+ )
622
+ ],
623
+ )
624
+
625
+ # Format as SSE
626
+ reasoning_first_token = False
627
+ yield f"data: {chunk.model_dump_json()}\n\n".encode("utf-8")
628
+
629
+ # Send the [DONE] marker
630
+ yield b"data: [DONE]\n\n"
631
+
632
+ return StreamingResponse(
633
+ generate(),
634
+ media_type="text/event-stream",
635
+ headers={
636
+ "Cache-Control": "no-cache",
637
+ "Connection": "keep-alive",
638
+ },
639
+ )
640
+
641
+ # If streaming is not requested, collect all generated tokens into a single response
642
+ else:
643
+ full_response = "<think>" if reasoning_first_token else ""
644
+ async for token in self._generate_tokens(**generation_args):
645
+ full_response += token
646
+
647
+ # Extract tool calls from the response
648
+ openai_tool_calls = None
649
+ if chat_completion_request.tools:
650
+ tool_calls, full_response = extract_tool_calls(
651
+ full_response, self.tokenizer.auto_tokenizer.added_tokens_decoder
652
+ )
653
+ if tool_calls:
654
+ openai_tool_calls = []
655
+ for tool_call in tool_calls:
656
+ openai_tool_calls.append(
657
+ ChatCompletionMessageToolCall(
658
+ id="-",
659
+ function=Function(
660
+ arguments=json.dumps(tool_call["arguments"]),
661
+ name=tool_call["name"],
662
+ ),
663
+ type="function",
664
+ )
665
+ )
666
+
667
+ ccm = ChatCompletionMessage(
668
+ content=full_response,
669
+ role="assistant",
670
+ refusal=None,
671
+ audio=None,
672
+ function_call=None,
673
+ tool_calls=openai_tool_calls,
674
+ )
675
+
676
+ choice = Choice(
677
+ finish_reason="stop",
678
+ index=0,
679
+ message=ccm,
680
+ logprobs=None,
681
+ )
682
+
683
+ return ChatCompletion(
684
+ id="0",
685
+ choices=[choice],
686
+ model=self.llm_loaded.checkpoint,
687
+ object="chat.completion",
688
+ created=int(time.time()),
689
+ )
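Because this endpoint mirrors the OpenAI chat completions schema, the standard `openai` client can be pointed at the `/api/v0` prefix. A rough sketch, assuming no API key is enforced (the key string is a dummy) and using a placeholder model name:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/api/v0", api_key="unused")

stream = client.chat.completions.create(
    model="SOME-MODEL-NAME",  # placeholder; see /api/v0/models for real names
    messages=[{"role": "user", "content": "Write a haiku about lemons."}],
    stream=True,
)
for chunk in stream:
    # Each chunk carries a ChoiceDelta, as constructed in generate() above
    print(chunk.choices[0].delta.content or "", end="", flush=True)
```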
690
+
691
+ def apply_chat_template(
692
+ self, messages: list[dict], tools: list[dict] | None = None
693
+ ):
694
+ """
695
+ Apply the model's chat template to the messages.
696
+ """
697
+ if self.tokenizer.chat_template:
698
+
699
+ return self.tokenizer.apply_chat_template(
700
+ messages,
701
+ tokenize=False,
702
+ add_generation_prompt=True,
703
+ tools=tools,
704
+ )
705
+
706
+ # Fallback to a standardized template if the model doesn't provide one
707
+ logging.warning("No chat template found. Using default template.")
708
+ formatted_messages = []
709
+ for msg in messages:
710
+ role = msg.get("role", "user")
711
+ content = msg.get("content", "")
712
+ role_marker = "<|assistant|>" if role == "assistant" else "<|user|>"
713
+ formatted_messages.append(f"{role_marker}\n{content} <|end|>")
714
+ return "\n".join(formatted_messages) + "\n<|assistant|>"
715
+
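As a worked example of the fallback branch above (used only when the tokenizer ships no chat template of its own), a short exchange is flattened like this:

```python
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
    {"role": "user", "content": "Tell me a joke"},
]
# The fallback formatting produces:
# <|user|>
# Hi <|end|>
# <|assistant|>
# Hello <|end|>
# <|user|>
# Tell me a joke <|end|>
# <|assistant|>
```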
716
+ async def responses(self, responses_request: ResponsesRequest):
717
+ """
718
+ Stream responses using HTTP chunked transfer encoding.
719
+ """
720
+
721
+ lc = self.initialize_load_config(responses_request)
722
+
723
+ # Load the model if it's different from the currently loaded one
724
+ await self.load_llm(lc, internal_call=True)
725
+
726
+ # Convert chat messages to text using the model's chat template
727
+ if isinstance(responses_request.input, str):
728
+ text = responses_request.input
729
+ else:
730
+ text = self.apply_chat_template(responses_request.input)
731
+
732
+ # If the model supports reasoning, we:
733
+ # 1. add a <think> tag to the model's context
734
+ # 2. ensure that the first token is a <think> token
735
+ reasoning_first_token = self.llm_loaded.reasoning
736
+
737
+ if reasoning_first_token:
738
+ text += "<think>"
739
+
740
+ # Prepare generation arguments
741
+ generation_args = {
742
+ "message": text,
743
+ "temperature": responses_request.temperature,
744
+ "max_new_tokens": responses_request.max_output_tokens,
745
+ }
746
+
747
+ if responses_request.stream:
748
+
749
+ # Stream the response
750
+ async def generate():
751
+ # Declare that this is the same variable as in the enclosing scope
752
+ # This is necessary because the variable is modified
753
+ # in the inner function
754
+ nonlocal reasoning_first_token
755
+
756
+ # Send initial creation event
757
+ response = Response(
758
+ id="0",
759
+ model=self.llm_loaded.checkpoint,
760
+ created_at=int(time.time()),
761
+ object="response",
762
+ output=[],
763
+ parallel_tool_calls=True,
764
+ tool_choice="auto",
765
+ tools=[],
766
+ )
767
+ created_event = ResponseCreatedEvent(
768
+ response=response,
769
+ type="response.created",
770
+ )
771
+ yield f"data: {created_event.model_dump_json()}\n\n".encode("utf-8")
772
+
773
+ full_response = "<think>" if reasoning_first_token else ""
774
+
775
+ async for token in self._generate_tokens(**generation_args):
776
+
777
+ # Create an event
778
+ delta_event = ResponseTextDeltaEvent(
779
+ content_index=0,
780
+ delta=("<think>" + token if reasoning_first_token else token),
781
+ item_id="0 ",
782
+ output_index=0,
783
+ type="response.output_text.delta",
784
+ )
785
+ full_response += token
786
+
787
+ # Format as SSE
788
+ reasoning_first_token = False
789
+ yield f"data: {delta_event.model_dump_json()}\n\n".encode("utf-8")
790
+
791
+ # Send the completed event
792
+ response_output_message = ResponseOutputMessage(
793
+ id="0",
794
+ content=[
795
+ ResponseOutputText(
796
+ annotations=[],
797
+ text=full_response,
798
+ type="output_text",
799
+ )
800
+ ],
801
+ role="assistant",
802
+ status="completed",
803
+ type="message",
804
+ )
805
+ response = Response(
806
+ id="0",
807
+ model=self.llm_loaded.checkpoint,
808
+ created_at=int(time.time()),
809
+ object="response",
810
+ output=[response_output_message],
811
+ parallel_tool_calls=True,
812
+ tool_choice="auto",
813
+ tools=[],
814
+ )
815
+ completed_event = ResponseCompletedEvent(
816
+ response=response,
817
+ type="response.completed",
818
+ )
819
+ yield f"data: {completed_event.model_dump_json()}\n\n".encode("utf-8")
820
+
821
+ # Send the [DONE] marker
822
+ yield b"data: [DONE]\n\n"
823
+
824
+ return StreamingResponse(
825
+ generate(),
826
+ media_type="text/event-stream",
827
+ headers={
828
+ "Cache-Control": "no-cache",
829
+ "Connection": "keep-alive",
830
+ },
831
+ )
832
+
833
+ # If streaming is not requested, collect all generated tokens into a single response
834
+ else:
835
+ full_response = "<think>" if reasoning_first_token else ""
836
+ async for token in self._generate_tokens(**generation_args):
837
+ full_response += token
838
+
839
+ # Send the completed event
840
+ response_output_message = ResponseOutputMessage(
841
+ id="0",
842
+ content=[
843
+ ResponseOutputText(
844
+ annotations=[],
845
+ text=full_response,
846
+ type="output_text",
847
+ )
848
+ ],
849
+ role="assistant",
850
+ status="completed",
851
+ type="message",
852
+ )
853
+ return Response(
854
+ id="0",
855
+ model=self.llm_loaded.checkpoint,
856
+ created_at=int(time.time()),
857
+ object="response",
858
+ output=[response_output_message],
859
+ parallel_tool_calls=True,
860
+ tool_choice="auto",
861
+ tools=[],
862
+ )
863
+
864
+ async def _generate_tokens(
865
+ self,
866
+ message: str,
867
+ stop: list[str] | str | None = None,
868
+ max_new_tokens: int | None = None,
869
+ temperature: float | None = None,
870
+ ):
871
+ """
872
+ Core streaming completion logic, separated from response handling.
873
+ Returns an async generator that yields tokens.
874
+ """
875
+ model = self.model
876
+ tokenizer = self.tokenizer
877
+
878
+ # Reset the early-exit flag before we start each generation
879
+ self.stop_event.clear()
880
+
881
+ input_ids = tokenizer(message, return_tensors="pt").input_ids
882
+
883
+ # Process stop sequences
884
+ stop_sequences = []
885
+ if stop is not None:
886
+ if isinstance(stop, str):
887
+ stop_sequences = [stop]
888
+ else:
889
+ stop_sequences = stop[:4] # Limit to 4 sequences as per spec
890
+
891
+ # Set up the generation parameters
892
+ if "oga-" in self.llm_loaded.recipe:
893
+ from lemonade.tools.ort_genai.oga import OrtGenaiStreamer
894
+
895
+ streamer = OrtGenaiStreamer(tokenizer)
896
+ self.input_tokens = len(input_ids)
897
+ else:
898
+ streamer = TextIteratorStreamer(
899
+ tokenizer,
900
+ skip_prompt=True,
901
+ )
902
+ self.input_tokens = len(input_ids[0])
903
+
904
+ if (
905
+ self.llm_loaded.max_prompt_length
906
+ and self.input_tokens > self.llm_loaded.max_prompt_length
907
+ ):
908
+ if self.truncate_inputs:
909
+ # Truncate input ids
910
+ truncate_amount = self.input_tokens - self.llm_loaded.max_prompt_length
911
+ input_ids = input_ids[: self.llm_loaded.max_prompt_length]
912
+
913
+ # Update token count
914
+ self.input_tokens = len(input_ids)
915
+
916
+ # Show warning message
917
+ truncation_message = (
918
+ f"Input exceeded {self.llm_loaded.max_prompt_length} tokens. "
919
+ f"Truncated {truncate_amount} tokens."
920
+ )
921
+ logging.warning(truncation_message)
922
+ else:
923
+ raise RuntimeError(
924
+ f"Prompt tokens ({self.input_tokens}) cannot be greater "
925
+ f"than the model's max prompt length ({self.llm_loaded.max_prompt_length})"
926
+ )
927
+
928
+ # Log the input tokens early to avoid this not showing due to potential crashes
929
+ logging.debug(f"Input Tokens: {self.input_tokens}")
930
+ logging.trace(f"Input Message: {message}")
931
+
932
+ stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])
933
+
934
+ generation_kwargs = {
935
+ "input_ids": input_ids,
936
+ "streamer": streamer,
937
+ "max_new_tokens": (
938
+ max_new_tokens if max_new_tokens else DEFAULT_MAX_NEW_TOKENS
939
+ ),
940
+ "min_new_tokens": 1,
941
+ "pad_token_id": tokenizer.eos_token_id,
942
+ "stopping_criteria": stopping_criteria,
943
+ "temperature": temperature,
944
+ }
945
+
946
+ # Initialize performance variables
947
+ generation_start_time = time.perf_counter()
948
+ first_token = True
949
+ self.decode_token_times = []
950
+ self.output_tokens = 0
951
+
952
+ # Begin generation
953
+ thread = GeneratorThread(
954
+ streamer, target=model.generate, kwargs=generation_kwargs
955
+ )
956
+ thread.start()
957
+
958
+ # Acquire the generation semaphore
959
+ await self._generate_semaphore.acquire()
960
+ active_generations = (
961
+ self.max_concurrent_generations
962
+ - self._generate_semaphore._value # pylint: disable=protected-access
963
+ )
964
+
965
+ logging.debug(f"Active generations: {active_generations}")
966
+
967
+ try:
968
+ # Generate the response using streaming
969
+ new_text = ""
970
+ for new_text in streamer:
971
+ # Yield control back to the event loop
972
+ # This gives the FastAPI server a chance to send the chunks to the client
973
+ await asyncio.sleep(0)
974
+
975
+ # Capture performance stats about this token
976
+ self.output_tokens = self.output_tokens + 1
977
+ if first_token:
978
+ self.time_to_first_token = (
979
+ time.perf_counter() - generation_start_time
980
+ )
981
+ first_token = False
982
+ else:
983
+ self.decode_token_times.append(
984
+ time.perf_counter() - next_token_start_time
985
+ )
986
+ next_token_start_time = time.perf_counter()
987
+
988
+ # Remove the EOS token from the response if needed
989
+ if hasattr(self.tokenizer, "eos_token"):
990
+ new_text = new_text.replace(self.tokenizer.eos_token, "")
991
+
992
+ # Check for stop sequences
993
+ if stop_sequences:
994
+ for stop_seq in stop_sequences:
995
+ if stop_seq in new_text:
996
+ # Make sure we yield the text up to before the stop sequence
997
+ new_text = new_text[: new_text.find(stop_seq)]
998
+ self.stop_event.set()
999
+
1000
+ yield new_text
1001
+
1002
+ # Allow the user to finish the response early
1003
+ if self.stop_event.is_set():
1004
+ logging.info("Stopping generation early.")
1005
+ break
1006
+
1007
+ if len(self.decode_token_times) > 0:
1008
+ self.tokens_per_second = 1 / statistics.mean(self.decode_token_times)
1009
+ else:
1010
+ self.tokens_per_second = 0
1011
+
1012
+ finally:
1013
+ thread.join()
1014
+
1015
+ # Release the semaphore when generation is complete (or if an error occurs)
1016
+ self._generate_semaphore.release()
1017
+ active_generations = (
1018
+ self.max_concurrent_generations
1019
+ - self._generate_semaphore._value # pylint: disable=protected-access
1020
+ )
1021
+
1022
+ # Check if an exception occurred in the generation thread
1023
+ # If it did, raise it as an HTTPException so that the client
1024
+ # knows they won't be getting a completion
1025
+ if thread.exception:
1026
+ raise HTTPException(
1027
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
1028
+ detail=f"Completion failure: {thread.exception}",
1029
+ )
1030
+
1031
+ # Display telemetry if in debug mode
1032
+ await self._show_telemetry()
1033
+
1034
+ async def send_stats(self):
1035
+ """
1036
+ Send performance statistics to the client.
1037
+ """
1038
+ return {
1039
+ "time_to_first_token": self.time_to_first_token,
1040
+ "tokens_per_second": self.tokens_per_second,
1041
+ "input_tokens": self.input_tokens,
1042
+ "output_tokens": self.output_tokens,
1043
+ "decode_token_times": self.decode_token_times,
1044
+ }
1045
+
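For reference, the /api/v0/stats payload returned above has the following shape after one generation; the numbers here are purely illustrative.

```python
example_stats = {
    "time_to_first_token": 0.42,   # seconds until the first streamed token
    "tokens_per_second": 35.1,     # mean decode throughput
    "input_tokens": 128,
    "output_tokens": 256,
    "decode_token_times": [0.028, 0.029, 0.031],  # per-token decode times (s)
}
```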
1046
+ async def halt_generation(self):
1047
+ """
1048
+ Allow the client to halt an in-progress generation.
1049
+ """
1050
+
1051
+ self.stop_event.set()
1052
+
1053
+ return {
1054
+ "terminated": True,
1055
+ }
1056
+
1057
+ async def health(self):
1058
+ """
1059
+ Report server health information to the client.
1060
+ """
1061
+ self.stop_event.set()
1062
+
1063
+ return {
1064
+ "status": "ok",
1065
+ "checkpoint_loaded": (
1066
+ self.llm_loaded.checkpoint if self.llm_loaded else None
1067
+ ),
1068
+ "model_loaded": (
1069
+ self.llm_loaded.model_name
1070
+ if (self.llm_loaded and self.llm_loaded.model_name)
1071
+ else None
1072
+ ),
1073
+ }
1074
+
1075
+ def model_load_failure(self, model_reference: str, message: Optional[str] = None):
1076
+ """
1077
+ Clean up after a model load failure, then log it and raise
1078
+ an HTTPException with details.
1079
+ """
1080
+ self.llm_loaded = None
1081
+ self.tokenizer = None
1082
+ self.model = None
1083
+
1084
+ default_message = f"model {model_reference} not found"
1085
+ if message:
1086
+ detail = message
1087
+ else:
1088
+ detail = default_message
1089
+
1090
+ logging.exception(f"Tried to load LLM {model_reference} and failed: {detail}")
1091
+
1092
+ raise HTTPException(
1093
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
1094
+ detail=detail,
1095
+ )
1096
+
1097
+ def recipe_missing_error(self, model_reference: str):
1098
+ self.model_load_failure(
1099
+ model_reference,
1100
+ message=(
1101
+ f"Attempted to load model by checkpoint name {model_reference}, "
1102
+ "however the required 'recipe' parameter was not provided"
1103
+ ),
1104
+ )
1105
+
1106
+ async def pull(self, config: PullConfig):
1107
+ """
1108
+ Install a supported LLM by its Lemonade Model Name.
1109
+ """
1110
+
1111
+ # Install the model
1112
+ ModelManager().download_models([config.model_name])
1113
+
1114
+ # Refresh the list of downloaded models, to ensure it
1115
+ # includes the model we just installed
1116
+ self.local_models = ModelManager().downloaded_models_enabled
1117
+
1118
+ async def load_llm(self, config: LoadConfig, internal_call=False):
1119
+ """
1120
+ Load an LLM into system memory.
1121
+ config: the information required to load the model
1122
+ internal_call: indicates whether the call to this function came from
1123
+ an endpoint (False) or a method of this class (True)
1124
+
1125
+ There are 4 ways this method can be called:
1126
+ 1. An external application asks to load a model by name, using the load endpoint
1127
+ a. This only differs from #2 in that an external application may
1128
+ provide more parameters than in #2, so we need to validate
1129
+ that those parameters are ok.
1130
+ b. Load the model
1131
+
1132
+ 2. An external application asks to load a model by name,
1133
+ using the completions or chat_completions endpoints
1134
+ a. Look up the name in the built-in model dictionary to create
1135
+ a fully-populated LoadConfig.
1136
+ b. Load the model
1137
+
1138
+ 3. An external application asks to load a model by checkpoint and recipe,
1139
+ using the load endpoint
1140
+ a. Populate the checkpoint and recipe into a LoadConfig
1141
+ b. Load the model
1142
+
1143
+ 4. Completions or ChatCompletions asks to "load" a model by checkpoint
1144
+ a. This is only available when #3 has already been executed
1145
+ b. Verify that the checkpoint is already loaded,
1146
+ and raise an exception if it hasn't (don't load anything new)
1147
+ """
1148
+ try:
1149
+ await self._load_lock.acquire()
1150
+
1151
+ # Acquire all generate locks
1152
+ for _ in range(self.max_concurrent_generations):
1153
+ await self._generate_semaphore.acquire()
1154
+
1155
+ # We will populate a LoadConfig that has all of the required fields
1156
+ config_to_use: LoadConfig
1157
+
1158
+ # First, validate that the arguments are valid
1159
+ if config.model_name:
1160
+ # Get the dictionary of supported model from disk
1161
+ supported_models = ModelManager().supported_models
1162
+
1163
+ # Refer to the model by name, since we know the name
1164
+ model_reference = config.model_name
1165
+
1166
+ if config.checkpoint or config.recipe:
1167
+ # Option #1, verify that there are no parameter mismatches
1168
+ built_in_config = supported_models[config.model_name]
1169
+ if config.checkpoint != built_in_config["checkpoint"]:
1170
+ self.model_load_failure(
1171
+ model_reference,
1172
+ message=(
1173
+ f"Load request for model_name={config.model_name} "
1174
+ "included a mismatched "
1175
+ f"checkpoint={config.checkpoint} parameter. Remove the checkpoint "
1176
+ f"parameter, or change it to {built_in_config['checkpoint']}."
1177
+ ),
1178
+ )
1179
+ if config.recipe != built_in_config["recipe"]:
1180
+ self.model_load_failure(
1181
+ model_reference,
1182
+ message=(
1183
+ f"Load request for model_name={config.model_name} "
1184
+ "included a mismatched "
1185
+ f"recipe={config.recipe} parameter. Remove the checkpoint "
1186
+ f"parameter, or change it to {built_in_config['recipe']}."
1187
+ ),
1188
+ )
1189
+
1190
+ # Use the config as-is
1191
+ config_to_use = config
1192
+ else:
1193
+ # Option #2, look up the config from the supported models dictionary
1194
+ config_to_use = LoadConfig(**supported_models[config.model_name])
1195
+
1196
+ elif config.checkpoint:
1197
+ # Refer to the model by checkpoint
1198
+ model_reference = config.checkpoint
1199
+
1200
+ if config.recipe and not internal_call:
1201
+ # Option 3, use the config as-is, but add a custom model name
1202
+ config_to_use = config
1203
+ config_to_use.model_name = "Custom"
1204
+ elif internal_call:
1205
+ # Option 4, make sure the right checkpoint is loaded and then return
1206
+ if (
1207
+ self.llm_loaded
1208
+ and config.checkpoint == self.llm_loaded.checkpoint
1209
+ ):
1210
+ return {
1211
+ "status": "success",
1212
+ "message": f"Model already loaded: {model_reference}",
1213
+ }
1214
+ else:
1215
+ self.model_load_failure(
1216
+ model_reference,
1217
+ message=(
1218
+ "Attempted run completions by using model=<checkpoint name>, "
1219
+ "however, "
1220
+ "this feature only works if the model has already been loaded "
1221
+ "using the load endpoint."
1222
+ ),
1223
+ )
1224
+ else:
1225
+ self.recipe_missing_error(model_reference)
1226
+ else:
1227
+ self.model_load_failure(
1228
+ None,
1229
+ message="Load requests must contain either a model_name or a "
1230
+ "checkpoint parameter",
1231
+ )
1232
+
1233
+ # Caching mechanism: if the checkpoint is already loaded there is nothing else to do
1234
+ if (
1235
+ self.llm_loaded
1236
+ and config_to_use.checkpoint == self.llm_loaded.checkpoint
1237
+ ):
1238
+ return {
1239
+ "status": "success",
1240
+ "message": f"Model already loaded: {model_reference}",
1241
+ }
1242
+
1243
+ # Unload the current model if needed
1244
+ if self.llm_loaded:
1245
+ await self.unload_llm(require_lock=False)
1246
+
1247
+ logging.info(f"Loading llm: {model_reference}")
1248
+ try:
1249
+ self.model, self.tokenizer = lemonade_api.from_pretrained(
1250
+ checkpoint=config_to_use.checkpoint, recipe=config_to_use.recipe
1251
+ )
1252
+ self.llm_loaded = config_to_use
1253
+
1254
+ return {
1255
+ "status": "success",
1256
+ "message": f"Loaded model: {model_reference}",
1257
+ }
1258
+ except Exception: # pylint: disable=broad-exception-caught
1259
+ self.model_load_failure(model_reference)
1260
+
1261
+ finally:
1262
+ self._load_lock.release()
1263
+
1264
+ # Release all generate locks
1265
+ for _ in range(self.max_concurrent_generations):
1266
+ self._generate_semaphore.release()
1267
+
1268
+ # Refresh the list of downloaded models, to ensure it
1269
+ # includes the model we just loaded
1270
+ if config.model_name not in self.local_models:
1271
+ self.local_models = ModelManager().downloaded_models_enabled
1272
+
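Summarizing the docstring above, a /api/v0/load request takes one of two shapes; the checkpoint, recipe, and model names below are placeholders, not values shipped with this package.

```python
import requests

BASE = "http://localhost:8000/api/v0"

# Load a built-in model by its Lemonade Server model name (cases 1 and 2)
requests.post(f"{BASE}/load", json={"model_name": "SOME-MODEL-NAME"})

# Load an arbitrary checkpoint by supplying both checkpoint and recipe (case 3)
requests.post(
    f"{BASE}/load",
    json={"checkpoint": "some-org/some-checkpoint", "recipe": "some-recipe"},
)
```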
1273
+ async def unload_llm(self, require_lock: bool = True):
1274
+ try:
1275
+ if require_lock:
1276
+ await self._load_lock.acquire()
1277
+
1278
+ # Acquire all generate locks
1279
+ for _ in range(self.max_concurrent_generations):
1280
+ await self._generate_semaphore.acquire()
1281
+
1282
+ self.llm_loaded = None
1283
+ self.tokenizer = None
1284
+ self.model = None
1285
+ return {"status": "success", "message": "Unloaded model"}
1286
+ except Exception as e: # pylint: disable=broad-exception-caught
1287
+ return {
1288
+ "status": "error",
1289
+ "message": f"Failed to unload model: {str(e)}",
1290
+ }
1291
+ finally:
1292
+ if require_lock:
1293
+ self._load_lock.release()
1294
+
1295
+ # Release all generate locks
1296
+ for _ in range(self.max_concurrent_generations):
1297
+ self._generate_semaphore.release()
1298
+
1299
+ async def models(self):
1300
+ """
1301
+ Return a list of available models in OpenAI-compatible format.
1302
+ """
1303
+ models_list = []
1304
+ for model in self.local_models:
1305
+ m = ServerModel(
1306
+ id=model,
1307
+ owned_by="lemonade",
1308
+ object="model",
1309
+ created=int(time.time()),
1310
+ checkpoint=self.local_models[model]["checkpoint"],
1311
+ recipe=self.local_models[model]["recipe"],
1312
+ )
1313
+ models_list.append(m)
1314
+
1315
+ return {"object": "list", "data": models_list}
1316
+
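The resulting /api/v0/models payload is an OpenAI-style model list extended with the checkpoint and recipe fields from ServerModel; the values below are illustrative.

```python
example_models_response = {
    "object": "list",
    "data": [
        {
            "id": "SOME-MODEL-NAME",  # placeholder model name
            "object": "model",
            "created": 1735689600,
            "owned_by": "lemonade",
            "checkpoint": "some-org/some-checkpoint",  # placeholder
            "recipe": "some-recipe",                   # placeholder
        }
    ],
}
```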
1317
+ def setup_middleware_timer(self):
1318
+ logging.info("Middleware set up")
1319
+
1320
+ @self.app.middleware("http")
1321
+ async def log_request_time(request: Request, call_next):
1322
+ """
1323
+ Log the request processing time for any request.
1324
+ For streaming responses, wraps the body iterator to measure total time.
1325
+ Only applies the wrapper in debug mode.
1326
+ """
1327
+ start_time = time.perf_counter()
1328
+ response = await call_next(request)
1329
+
1330
+ if (
1331
+ self.debug_logging_enabled
1332
+ and hasattr(response, "body_iterator")
1333
+ and response.body_iterator is not None
1334
+ ):
1335
+ original_iterator = response.body_iterator
1336
+
1337
+ async def wrapped_iterator():
1338
+ async for chunk in original_iterator:
1339
+ yield chunk
1340
+ request_time = time.perf_counter() - start_time
1341
+ logging.debug(
1342
+ f"Total request time (streamed): {request_time:.4f} seconds"
1343
+ )
1344
+
1345
+ response.body_iterator = wrapped_iterator()
1346
+ else:
1347
+ request_time = time.perf_counter() - start_time
1348
+ if self.debug_logging_enabled:
1349
+ logging.debug(f"Total request time: {request_time:.4f} seconds")
1350
+ return response
1351
+
1352
+
1353
+ # This file was originally licensed under Apache 2.0. It has been modified.
1354
+ # Modifications Copyright (c) 2025 AMD