lemonade-sdk 8.0.2__py3-none-any.whl → 8.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


lemonade/cli.py CHANGED
@@ -90,9 +90,9 @@ https://github.com/lemonade-sdk/lemonade/blob/main/docs/README.md""",
  )
 
  profiler_instances = [
- profiler(global_args[profiler.unique_name])
+ profiler(global_args[profiler.unique_name.replace("-", "_")])
  for profiler in profilers
- if global_args.get(profiler.unique_name, None) is not None
+ if global_args.get(profiler.unique_name.replace("-", "_"), None) is not None
  ]
 
  if len(evaluation_tools) > 0:
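The cli.py change above works around argparse's option-name handling: flags registered with dashes are stored in the parsed namespace with underscores, so lookups keyed on a profiler's unique_name must convert the name first. A minimal, self-contained sketch of that behavior (the --memory-profile flag and the unique_name value are hypothetical, not lemonade's actual profiler names):

import argparse

# argparse converts dashes in option names to underscores in the namespace,
# so "--memory-profile" is stored under "memory_profile".
parser = argparse.ArgumentParser()
parser.add_argument("--memory-profile", default=None)  # hypothetical profiler flag
global_args = vars(parser.parse_args(["--memory-profile", "out.csv"]))

unique_name = "memory-profile"  # hypothetical Profiler.unique_name
print(global_args.get(unique_name))                    # None: the key has a dash
print(global_args.get(unique_name.replace("-", "_")))  # "out.csv"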
@@ -48,7 +48,10 @@ class Profiler(abc.ABC):
  This method is called so that the profiler can create its output files.
  The state is passed so that build info can be gathered and stats can be written.
  The timestamp can be used for filename in current working directory.
- The start times contain a list of tools and start times.
+ The start times parameter is a dict with the keys being the tools names and
+ the values being the time the tool started. There is an initial "warmup" key
+ that has a start time before the first tool and a "cool down" key that contains the
+ time when the last tool ended.
  """
 
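The revised docstring describes start_times as a dict of per-tool start timestamps bracketed by "warmup" and "cool down" entries. A hedged sketch of what such a structure might look like and how a profiler could derive per-tool durations from it (tool names and offsets are hypothetical):

import time

# Hypothetical shape of the start_times dict: one entry per tool, plus a
# "warmup" entry before the first tool and a "cool down" entry marking when
# the last tool ended.
t0 = time.time()
start_times = {
    "warmup": t0,
    "oga-load": t0 + 1.2,
    "llm-prompt": t0 + 8.7,
    "cool down": t0 + 42.0,
}

# A profiler can turn consecutive entries into per-tool durations.
names = list(start_times)
durations = {
    names[i]: start_times[names[i + 1]] - start_times[names[i]]
    for i in range(len(names) - 1)
}
print(durations)  # {"warmup": 1.2, "oga-load": 7.5, "llm-prompt": 33.3}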
 
@@ -24,7 +24,7 @@ class AccuracyHumaneval(Tool):
  - pass@10: Percentage of problems solved within 10 generation attempts
  - pass@100: Percentage of problems solved within 100 generation attempts
 
- See docs/lemonade/humaneval_accuracy.md for more details
+ See docs/dev_cli/humaneval_accuracy.md for more details
  """
 
  unique_name = "accuracy-humaneval"
lemonade/tools/mmlu.py CHANGED
@@ -27,7 +27,7 @@ def min_handle_none(*args: int):
 
  class AccuracyMMLU(Tool):
  """
- See docs/lemonade/mmlu_accuracy.md for more details
+ See docs/dev_cli/mmlu_accuracy.md for more details
  """
 
  unique_name = "accuracy-mmlu"
@@ -1,12 +1,6 @@
  # onnxruntime_genai is not lint-friendly yet and PyLint can't
  # find any of the class methods
  # pylint: disable=no-member
- #
- # Model builder constraints:
- # 11/10/24 Need transformers <4.45.0 OR onnxruntime-genai 0.5.0 (which must be built from source)
- # (transformers v4.45 changes the format of the tokenizer.json file which will be supported in
- # onnxruntime-genai 0.5)
- #
 
  import argparse
  import os
@@ -51,8 +45,8 @@ def import_error_heler(e: Exception):
  """
  raise ImportError(
  f"{e}\n Please install lemonade-sdk with "
- "one of the llm-oga extras, for example:\n"
- "pip install lemonade-sdk[llm-oga-cpu]\n"
+ "one of the oga extras, for example:\n"
+ "pip install lemonade-sdk[dev,oga-cpu]\n"
  "See https://lemonade_server.ai/install_options.html for details"
  )
 
@@ -64,7 +58,7 @@ class OgaLoad(FirstTool):
  Input: path to a checkpoint.
  Supported choices for cpu and igpu from HF model repository:
  LLM models on Huggingface supported by model_builder. See documentation
- (https://github.com/lemonade-sdk/lemonade/blob/main/docs/ort_genai_igpu.md)
+ (https://github.com/lemonade-sdk/lemonade/blob/main/docs/dev_cli/ort_genai_igpu.md)
  for supported models.
  Supported choices for npu from HF model repository:
  Models on Hugging Face that follow the "amd/**-onnx-ryzen-strix" pattern
@@ -17,7 +17,7 @@ class AccuracyPerplexity(Tool):
 
  Output state produced: None
 
- See docs/lemonade/perplexity.md for more details.
+ See docs/dev_cli/perplexity.md for more details.
  """
 
  unique_name = "accuracy-perplexity"
@@ -63,7 +63,7 @@ class AccuracyPerplexity(Tool):
  # try-except will allow a few more LLMs to work
  max_length = 2048
  # Set stride to half of the maximum input length for overlapping window processing
- # Refer to docs/perplexity.md for more information on sliding window
+ # Refer to docs/dev_cli/perplexity.md for more information on sliding window
  stride = max_length // 2
  # Determine the total sequence length of the tokenized input
  seq_len = encodings.input_ids.size(1)
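The surrounding context explains the overlapping-window setup: windows of max_length tokens advance by stride = max_length // 2 so each token is scored exactly once. A small sketch of the window arithmetic, independent of any model (the numbers are illustrative):

# Window bounds for scoring a long sequence in overlapping chunks: each window
# covers max_length tokens, starts stride tokens after the previous one, and
# only the tokens not seen by the previous window count toward the score.
def window_bounds(seq_len, max_length=2048, stride=1024):
    prev_end = 0
    for begin in range(0, seq_len, stride):
        end = min(begin + max_length, seq_len)
        yield begin, end, end - prev_end  # (window start, window end, newly scored tokens)
        prev_end = end
        if end == seq_len:
            break

# e.g. a 5000-token input: (0, 2048, 2048), (1024, 3072, 1024), (2048, 4096, 1024), (3072, 5000, 904)
print(list(window_bounds(5000)))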
lemonade/tools/prompt.py CHANGED
@@ -176,12 +176,21 @@ class LLMPrompt(Tool):
 
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids
  if isinstance(input_ids, (list, str)):
- # OGA models return a list of tokens
+ # OGA models return a list of tokens (older versions)
  # Our llama.cpp adapter returns a string
  len_tokens_in = len(input_ids)
- else:
+ elif hasattr(input_ids, "shape"):
  # HF models return a 2-D tensor
- len_tokens_in = input_ids.shape[1]
+ # OGA models with newer versions may return numpy arrays
+ if len(input_ids.shape) == 1:
+ # 1-D array from newer OGA versions
+ len_tokens_in = len(input_ids)
+ else:
+ # 2-D tensor from HF models
+ len_tokens_in = input_ids.shape[1]
+ else:
+ # Fallback: try to get length directly
+ len_tokens_in = len(input_ids)
 
  len_tokens_out = []
  response_texts = []
@@ -202,9 +211,15 @@ class LLMPrompt(Tool):
  random_seed += 1
 
  # Flatten the input and response
- input_ids_array = (
- input_ids if isinstance(input_ids, (list, str)) else input_ids[0]
- )
+ if isinstance(input_ids, (list, str)):
+ input_ids_array = input_ids
+ elif hasattr(input_ids, "shape") and len(input_ids.shape) == 1:
+ # 1-D array from newer OGA versions - already flat
+ input_ids_array = input_ids
+ else:
+ # 2-D tensor from HF models - take first row
+ input_ids_array = input_ids[0]
+
  response_array = response if isinstance(response, str) else response[0]
 
  # Separate the prompt from the response
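The two prompt.py hunks above normalize the several shapes input_ids can take. A hedged, standalone sketch of the same decision logic as a helper (the function name is illustrative, not part of the package):

# Token ids may arrive as a str (llama.cpp adapter), a list (older OGA),
# a 1-D numpy array (newer OGA), or a 2-D tensor (Hugging Face).
# Return a flat sequence plus its length.
def normalize_input_ids(input_ids):
    if isinstance(input_ids, (list, str)):
        return input_ids, len(input_ids)
    if hasattr(input_ids, "shape") and len(input_ids.shape) == 1:
        return input_ids, len(input_ids)        # already flat
    return input_ids[0], input_ids.shape[1]     # 2-D tensor: take the first row

print(normalize_input_ids([1, 2, 3]))           # ([1, 2, 3], 3)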
@@ -18,7 +18,7 @@ class QuarkLoad(Tool):
  Output:
  - state of the loaded model
 
- See docs/quark.md for more details.
+ See docs/dev_cli/quark.md for more details.
  """
 
  unique_name = "quark-load"
@@ -25,7 +25,7 @@ class QuarkQuantize(Tool):
  Output:
  - Modifies `state` with quantized and optionally exported model.
 
- See docs/quark.md for more details.
+ See docs/dev_cli/quark.md for more details.
  """
 
  unique_name = "quark-quantize"
@@ -94,7 +94,7 @@ class QuarkQuantize(Tool):
  help="Number of samples for calibration.",
  )
 
- # See docs/quark.md for more details.
+ # See docs/dev_cli/quark.md for more details.
  parser.add_argument(
  "--quant-scheme",
  type=str,
@@ -74,6 +74,7 @@ class SimpleStat(TableColumn):
  align="center",
  omit_if_lean=False,
  wrap=None,
+ stat_fn=None,
  ):
  self.column_header = column_header
  self.stat = stat
@@ -81,6 +82,7 @@ class SimpleStat(TableColumn):
  self.align = align
  self.omit_if_lean = omit_if_lean
  self.wrap = wrap or self.default_wrap
+ self.stat_fn = stat_fn
 
  def get_str(self, build_stats, lean=False):
  if lean and self.omit_if_lean:
@@ -88,6 +90,8 @@ class SimpleStat(TableColumn):
  data = build_stats.get(self.stat, None)
  if data is None:
  return ""
+ if self.stat_fn:
+ data = self.stat_fn(data)
  cell_str = "\n".join(
  [_wrap(f"{x:{self.format_str}}", self.wrap) for x in _to_list(data)]
  )
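The new stat_fn hook lets a column reduce a list-valued statistic to a single value before formatting (the later LemonadePerfTable hunk uses stat_fn=sum for total generated tokens). A standalone sketch of the idea, with hypothetical stat keys and values:

# Optionally reduce a list-valued stat (e.g. tokens generated per iteration)
# to one number before formatting each table cell.
def render_cell(build_stats, stat_key, format_str, stat_fn=None):
    data = build_stats.get(stat_key)
    if data is None:
        return ""
    if stat_fn:
        data = stat_fn(data)
    items = data if isinstance(data, list) else [data]
    return "\n".join(f"{x:{format_str}}" for x in items)

print(render_cell({"response_tokens": [128, 256, 64]}, "response_tokens", "d"))
# one value per line: 128 / 256 / 64
print(render_cell({"response_tokens": [128, 256, 64]}, "response_tokens", "d", stat_fn=sum))
# "448" -> a single total, as the new Total Generated Tokens column does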
@@ -233,6 +237,47 @@ class AdditionalStat(TableColumn):
  return "\n".join(cell_entry)
 
 
+ class DictListStat(TableColumn):
+ """
+ A statistic that is a list of dicts and values from a given list of keys will be
+ pulled out of each dict and placed in the cell
+ """
+
+ def __init__(
+ self,
+ column_header,
+ statistic_name,
+ key_format_list,
+ align="center",
+ omit_if_lean=False,
+ wrap=None,
+ ):
+ self.column_header = column_header
+ self.statistic_name = statistic_name
+ self.key_format_list = key_format_list
+ self.align = align
+ self.omit_if_lean = omit_if_lean
+ self.wrap = wrap or self.default_wrap
+
+ def get_str(self, build_stats, lean=False):
+ if lean and self.omit_if_lean:
+ return None
+ stat = build_stats.get(self.statistic_name, None)
+ if not stat:
+ return ""
+ cell_entry = []
+ for stat_dict in stat:
+ line = [
+ format_str.format(stat_dict[key])
+ for key, format_str in self.key_format_list
+ ]
+ cell_entry.append(" ".join(line))
+ return "\n".join(cell_entry)
+
+ def get_keys(self):
+ return [self.statistic_name]
+
+
  ################################################################################
  # ABSTRACT BASE CLASS FOR DEFINING A TABLE
  ################################################################################
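DictListStat renders a list-of-dicts statistic one dict per line, pulling out the keys named in key_format_list. A hedged sketch of the formatting it performs, with made-up data:

# Made-up list-of-dicts statistic and the (key, format_str) pairs such a
# column would use; each dict becomes one line in the cell.
per_request = [
    {"tokens": 128, "tps": 41.7},
    {"tokens": 256, "tps": 39.2},
]
key_format_list = [("tokens", "{:d}"), ("tps", "{:.1f}")]

cell = "\n".join(
    " ".join(fmt.format(entry[key]) for key, fmt in key_format_list)
    for entry in per_request
)
print(cell)
# 128 41.7
# 256 39.2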
@@ -350,6 +395,28 @@ class Table(ABC):
  headers.append(column.column_header)
  col_align += (column.align,)
 
+ # Stat column headers
+ stat_columns = self.table_descriptor.get("stat_columns", [])
+ stat_columns_include = []
+ for column in stat_columns:
+ # Check to see that at least one build has data for the column
+ keep_column = False
+ if not (self.lean and column.omit_if_lean):
+ keys = column.get_keys()
+ for build_stats in self.all_stats:
+ found = [(key in build_stats) for key in keys]
+ if any(found):
+ keep_column = True
+ headers.append(column.column_header)
+ col_align += (column.align,)
+ break
+ stat_columns_include.append(keep_column)
+ stat_columns = [
+ column
+ for column, include in zip(stat_columns, stat_columns_include)
+ if include
+ ]
+
 
  # Final headers
  last_columns = self.table_descriptor.get("last_columns", [])
@@ -386,6 +453,12 @@ class Table(ABC):
  if entry_str is not None:
  row.append(entry_str)
 
+ # Per stat columns
+ for entry in stat_columns:
+ entry_str = entry.get_str(build_stats, self.lean)
+ if entry_str is not None:
+ row.append(entry_str)
+
 
  # Final columns
  for entry in last_columns:
@@ -514,6 +587,12 @@ class LemonadePerfTable(Table):
  Keys.STD_DEV_TOKENS_PER_SECOND,
  ".2f",
  ),
+ SimpleStat(
+ _wrap("Total Generated Tokens", 9),
+ Keys.RESPONSE_TOKENS,
+ "d",
+ stat_fn=sum,
+ ),
  SimpleStat(
  _wrap("Memory Used (GB)", 8), Keys.MAX_MEMORY_USED_GBYTE, ".3f"
  ),
@@ -537,6 +616,7 @@ class LemonadePerfTable(Table):
  )
  ],
  },
+ "stat_columns": [],
  "last_columns": [
  SimpleStat(
  "System Info",
@@ -16,11 +16,29 @@ from fastapi.responses import StreamingResponse
 
  from openai import OpenAI
 
- from lemonade_server.pydantic_models import ChatCompletionRequest, PullConfig
+ from lemonade_server.pydantic_models import (
+ ChatCompletionRequest,
+ PullConfig,
+ EmbeddingsRequest,
+ RerankingRequest,
+ )
  from lemonade_server.model_manager import ModelManager
  from lemonade.tools.server.utils.port import find_free_port
 
- LLAMA_VERSION = "b5699"
+ LLAMA_VERSION = "b5787"
+
+
+ def llamacpp_address(port: int) -> str:
+ """
+ Generate the base URL for the llamacpp server.
+
+ Args:
+ port: The port number the llamacpp server is running on
+
+ Returns:
+ The base URL for the llamacpp server
+ """
+ return f"http://127.0.0.1:{port}/v1"
 
 
  def get_llama_server_paths():
@@ -210,15 +228,20 @@ def _log_subprocess_output(
  """
 
  if process.stdout:
- for line in iter(process.stdout.readline, ""):
- if line:
- line_stripped = line.strip()
- logging.debug("%s: %s", prefix, line_stripped)
+ try:
+ for line in iter(process.stdout.readline, ""):
+ if line:
+ line_stripped = line.strip()
+ logging.debug("%s: %s", prefix, line_stripped)
 
- telemetry.parse_telemetry_line(line_stripped)
+ telemetry.parse_telemetry_line(line_stripped)
 
- if process.poll() is not None:
- break
+ if process.poll() is not None:
+ break
+ except UnicodeDecodeError as e:
+ logging.debug("Unicode decode error reading subprocess output: %s", str(e))
+ except Exception as e: # pylint: disable=broad-exception-caught
+ logging.error("Unexpected error reading subprocess output: %s", str(e))
 
 
  def _wait_for_load(llama_server_process: subprocess.Popen, port: int):
@@ -239,10 +262,24 @@ def _wait_for_load(llama_server_process: subprocess.Popen, port: int):
 
 
  def _launch_llama_subprocess(
- snapshot_files: dict, use_gpu: bool, telemetry: LlamaTelemetry
+ snapshot_files: dict,
+ use_gpu: bool,
+ telemetry: LlamaTelemetry,
+ supports_embeddings: bool = False,
+ supports_reranking: bool = False,
  ) -> subprocess.Popen:
  """
- Launch llama server subprocess with GPU or CPU configuration
+ Launch llama server subprocess with appropriate configuration.
+
+ Args:
+ snapshot_files: Dictionary of model files to load
+ use_gpu: Whether to use GPU acceleration
+ telemetry: Telemetry object for tracking performance metrics
+ supports_embeddings: Whether the model supports embeddings
+ supports_reranking: Whether the model supports reranking
+
+ Returns:
+ Subprocess handle for the llama server
  """
 
  # Get the current executable path (handles both Windows and Ubuntu structures)
@@ -266,6 +303,14 @@ def _launch_llama_subprocess(
  # reasoning_content field
  base_command.extend(["--reasoning-format", "none"])
 
+ # Add embeddings support if the model supports it
+ if supports_embeddings:
+ base_command.append("--embeddings")
+
+ # Add reranking support if the model supports it
+ if supports_reranking:
+ base_command.append("--reranking")
+
  # Configure GPU layers: 99 for GPU, 0 for CPU-only
  ngl_value = "99" if use_gpu else "0"
  command = base_command + ["-ngl", ngl_value]
@@ -287,6 +332,8 @@ def _launch_llama_subprocess(
  stdout=subprocess.PIPE,
  stderr=subprocess.STDOUT,
  text=True,
+ encoding="utf-8",
+ errors="replace",
  bufsize=1,
  env=env,
  )
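This hunk, together with the try/except added in _log_subprocess_output above, hardens subprocess output handling: decoding child output as UTF-8 with errors="replace" turns malformed bytes into U+FFFD instead of raising mid-read. A hedged sketch of the pattern (the command is a placeholder, not the package's actual invocation):

import logging
import subprocess

proc = subprocess.Popen(
    ["llama-server", "--help"],   # hypothetical command
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
    encoding="utf-8",
    errors="replace",   # malformed bytes become U+FFFD instead of raising
    bufsize=1,
)
try:
    for line in iter(proc.stdout.readline, ""):
        if line:
            logging.debug("llama-server: %s", line.strip())
        if proc.poll() is not None:
            break
except UnicodeDecodeError as e:
    logging.debug("Unicode decode error reading subprocess output: %s", e)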
@@ -303,7 +350,6 @@ def _launch_llama_subprocess(
 
 
  def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
-
  # Validate platform support before proceeding
  validate_platform_support()
 
@@ -360,15 +406,26 @@ def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
  logging.info("Cleaned up zip file")
 
  # Download the gguf to the hugging face cache
- snapshot_files = ModelManager().download_gguf(model_config)
+ model_manager = ModelManager()
+ snapshot_files = model_manager.download_gguf(model_config)
  logging.debug(f"GGUF file paths: {snapshot_files}")
 
+ # Check if model supports embeddings
+ supported_models = model_manager.supported_models
+ model_info = supported_models.get(model_config.model_name, {})
+ supports_embeddings = "embeddings" in model_info.get("labels", [])
+ supports_reranking = "reranking" in model_info.get("labels", [])
+
  # Start the llama-serve.exe process
  logging.debug(f"Using llama_server for GGUF model: {llama_server_exe_path}")
 
  # Attempt loading on GPU first
  llama_server_process = _launch_llama_subprocess(
- snapshot_files, use_gpu=True, telemetry=telemetry
+ snapshot_files,
+ use_gpu=True,
+ telemetry=telemetry,
+ supports_embeddings=supports_embeddings,
+ supports_reranking=supports_reranking,
  )
 
  # Check the /health endpoint until GPU server is ready
@@ -383,8 +440,16 @@ def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
  f"Loading {model_config.model_name} on GPU didn't work, re-attempting on CPU"
  )
 
+ if os.environ.get("LEMONADE_LLAMACPP_NO_FALLBACK"):
+ # Used for testing, when the test should fail if GPU didn't work
+ raise Exception("llamacpp GPU loading failed")
+
  llama_server_process = _launch_llama_subprocess(
- snapshot_files, use_gpu=False, telemetry=telemetry
+ snapshot_files,
+ use_gpu=False,
+ telemetry=telemetry,
+ supports_embeddings=supports_embeddings,
+ supports_reranking=supports_reranking,
  )
 
  # Check the /health endpoint until CPU server is ready
@@ -405,7 +470,7 @@ def server_load(model_config: PullConfig, telemetry: LlamaTelemetry):
  def chat_completion(
  chat_completion_request: ChatCompletionRequest, telemetry: LlamaTelemetry
  ):
- base_url = f"http://127.0.0.1:{telemetry.port}/v1"
+ base_url = llamacpp_address(telemetry.port)
  client = OpenAI(
  base_url=base_url,
  api_key="lemonade",
@@ -456,3 +521,70 @@ def chat_completion(
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  detail=f"Chat completion error: {str(e)}",
  )
+
+
+ def embeddings(embeddings_request: EmbeddingsRequest, telemetry: LlamaTelemetry):
+ """
+ Generate embeddings using the llamacpp server.
+
+ Args:
+ embeddings_request: The embeddings request containing input text/tokens
+ telemetry: Telemetry object containing the server port
+
+ Returns:
+ Embeddings response from the llamacpp server
+ """
+ base_url = llamacpp_address(telemetry.port)
+ client = OpenAI(
+ base_url=base_url,
+ api_key="lemonade",
+ )
+
+ # Convert Pydantic model to dict and remove unset/null values
+ request_dict = embeddings_request.model_dump(exclude_unset=True, exclude_none=True)
+
+ try:
+ # Call the embeddings endpoint
+ response = client.embeddings.create(**request_dict)
+ return response
+
+ except Exception as e: # pylint: disable=broad-exception-caught
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Embeddings error: {str(e)}",
+ )
+
+
+ def reranking(reranking_request: RerankingRequest, telemetry: LlamaTelemetry):
+ """
+ Rerank documents based on their relevance to a query using the llamacpp server.
+
+ Args:
+ reranking_request: The reranking request containing query and documents
+ telemetry: Telemetry object containing the server port
+
+ Returns:
+ Reranking response from the llamacpp server containing ranked documents and scores
+ """
+ base_url = llamacpp_address(telemetry.port)
+
+ try:
+ # Convert Pydantic model to dict and exclude unset/null values
+ request_dict = reranking_request.model_dump(
+ exclude_unset=True, exclude_none=True
+ )
+
+ # Call the reranking endpoint directly since it's not supported by the OpenAI API
+ response = requests.post(
+ f"{base_url}/rerank",
+ json=request_dict,
+ )
+ response.raise_for_status()
+ return response.json()
+
+ except Exception as e:
+ logging.error("Error during reranking: %s", str(e))
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Reranking error: {str(e)}",
+ ) from e
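The new module-level helpers forward requests to the local llama.cpp server: embeddings go through the OpenAI-compatible /v1/embeddings route, while reranking is POSTed straight to /rerank because the OpenAI client has no such endpoint. A hedged sketch of the underlying calls, assuming a llama.cpp server already listening on a hypothetical port and placeholder model names:

import requests
from openai import OpenAI

base_url = "http://127.0.0.1:8080/v1"   # what llamacpp_address(port) would return

# Embeddings use the OpenAI-compatible /v1/embeddings route.
client = OpenAI(base_url=base_url, api_key="lemonade")
emb = client.embeddings.create(model="embedding-model", input="hello world")

# Reranking is not part of the OpenAI API, so it is POSTed directly to /rerank.
rerank = requests.post(
    f"{base_url}/rerank",
    json={
        "query": "What is the capital of France?",
        "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
    },
)
print(rerank.json())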
@@ -54,6 +54,8 @@ from lemonade_server.pydantic_models import (
  LoadConfig,
  CompletionRequest,
  ChatCompletionRequest,
+ EmbeddingsRequest,
+ RerankingRequest,
  ResponsesRequest,
  PullConfig,
  DeleteConfig,
@@ -231,8 +233,13 @@ class Server(ManagementTool):
 
  # OpenAI-compatible routes
  self.app.post(f"{prefix}/chat/completions")(self.chat_completions)
+ self.app.post(f"{prefix}/embeddings")(self.embeddings)
  self.app.get(f"{prefix}/models")(self.models)
 
+ # JinaAI routes (jina.ai/reranker/)
+ self.app.post(f"{prefix}/reranking")(self.reranking)
+ self.app.post(f"{prefix}/rerank")(self.reranking)
+
  @staticmethod
  def parser(add_help: bool = True) -> argparse.ArgumentParser:
  parser = __class__.helpful_parser(
@@ -796,6 +803,72 @@ class Server(ManagementTool):
  created=int(time.time()),
  )
 
+ async def embeddings(self, embeddings_request: EmbeddingsRequest):
+ """
+ Generate embeddings for the provided input.
+ """
+ # Initialize load config from embeddings request
+ lc = LoadConfig(model_name=embeddings_request.model)
+
+ # Load the model if it's different from the currently loaded one
+ await self.load_llm(lc)
+
+ if self.llm_loaded.recipe == "llamacpp":
+ try:
+ return llamacpp.embeddings(embeddings_request, self.llama_telemetry)
+ except Exception as e: # pylint: disable=broad-exception-caught
+ # Check if model has embeddings label
+ model_info = ModelManager().supported_models.get(
+ self.llm_loaded.model_name, {}
+ )
+ if "embeddings" not in model_info.get("labels", []):
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail="You tried to generate embeddings for a model that is "
+ "not labeled as an embeddings model. Please use another model "
+ "or re-register the current model with the 'embeddings' label.",
+ ) from e
+ else:
+ raise e
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail=f"Embeddings not supported for recipe: {self.llm_loaded.recipe}",
+ )
+
+ async def reranking(self, reranking_request: RerankingRequest):
+ """
+ Rerank documents based on their relevance to a query using the llamacpp server.
+ """
+ # Initialize load config from reranking request
+ lc = LoadConfig(model_name=reranking_request.model)
+
+ # Load the model if it's different from the currently loaded one
+ await self.load_llm(lc)
+
+ if self.llm_loaded.recipe == "llamacpp":
+ try:
+ return llamacpp.reranking(reranking_request, self.llama_telemetry)
+ except Exception as e: # pylint: disable=broad-exception-caught
+ # Check if model has reranking label
+ model_info = ModelManager().supported_models.get(
+ self.llm_loaded.model_name, {}
+ )
+ if "reranking" not in model_info.get("labels", []):
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail="You tried to use reranking for a model that is "
+ "not labeled as a reranking model. Please use another model "
+ "or re-register the current model with the 'reranking' label.",
+ ) from e
+ else:
+ raise e
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail=f"Reranking not supported for recipe: {self.llm_loaded.recipe}",
+ )
+
  def apply_chat_template(
  self, messages: list[dict], tools: list[dict] | None = None
  ):
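Taken together, the server-side changes expose embeddings and reranking endpoints for models registered with the corresponding labels. A hedged end-to-end sketch of calling them on a locally running Lemonade server; the port, the /api/v1 prefix, and the model names are assumptions, not values confirmed by this diff:

import requests

base = "http://localhost:8000/api/v1"   # assumed default host, port, and prefix

# Embeddings: the model must be registered with the "embeddings" label.
emb = requests.post(
    f"{base}/embeddings",
    json={"model": "my-embedding-model", "input": "hello world"},
)
print(len(emb.json()["data"][0]["embedding"]))   # vector length

# Reranking: the model must be registered with the "reranking" label;
# "/reranking" and "/rerank" are registered as aliases for the same handler.
rr = requests.post(
    f"{base}/rerank",
    json={
        "model": "my-reranker-model",
        "query": "What is the capital of France?",
        "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
    },
)
print(rr.json())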