llama-stack 0.2.21__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. llama_stack/apis/benchmarks/benchmarks.py +8 -0
  2. llama_stack/apis/scoring_functions/scoring_functions.py +8 -0
  3. llama_stack/cli/stack/_build.py +7 -0
  4. llama_stack/cli/verify_download.py +7 -10
  5. llama_stack/core/datatypes.py +10 -3
  6. llama_stack/core/library_client.py +0 -2
  7. llama_stack/core/routers/__init__.py +4 -1
  8. llama_stack/core/routers/inference.py +12 -7
  9. llama_stack/core/routing_tables/benchmarks.py +4 -0
  10. llama_stack/core/routing_tables/common.py +4 -0
  11. llama_stack/core/routing_tables/scoring_functions.py +4 -0
  12. llama_stack/distributions/ci-tests/build.yaml +1 -0
  13. llama_stack/distributions/ci-tests/run.yaml +7 -0
  14. llama_stack/distributions/starter/build.yaml +1 -0
  15. llama_stack/distributions/starter/run.yaml +7 -0
  16. llama_stack/distributions/starter/starter.py +18 -0
  17. llama_stack/distributions/starter-gpu/build.yaml +1 -0
  18. llama_stack/distributions/starter-gpu/run.yaml +7 -0
  19. llama_stack/distributions/watsonx/run.yaml +9 -0
  20. llama_stack/distributions/watsonx/watsonx.py +10 -2
  21. llama_stack/providers/inline/eval/meta_reference/eval.py +7 -0
  22. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +3 -0
  23. llama_stack/providers/inline/tool_runtime/rag/context_retriever.py +6 -6
  24. llama_stack/providers/inline/tool_runtime/rag/memory.py +101 -46
  25. llama_stack/providers/registry/batches.py +1 -1
  26. llama_stack/providers/registry/inference.py +22 -11
  27. llama_stack/providers/registry/scoring.py +1 -1
  28. llama_stack/providers/remote/eval/nvidia/eval.py +11 -2
  29. llama_stack/providers/remote/inference/azure/__init__.py +15 -0
  30. llama_stack/providers/remote/inference/azure/azure.py +64 -0
  31. llama_stack/providers/remote/inference/azure/config.py +63 -0
  32. llama_stack/providers/remote/inference/azure/models.py +28 -0
  33. llama_stack/providers/remote/inference/bedrock/bedrock.py +49 -2
  34. llama_stack/providers/remote/inference/tgi/tgi.py +43 -15
  35. llama_stack/providers/remote/inference/together/models.py +70 -44
  36. llama_stack/providers/remote/inference/together/together.py +79 -130
  37. llama_stack/providers/remote/inference/vertexai/vertexai.py +29 -4
  38. llama_stack/providers/remote/inference/vllm/vllm.py +11 -186
  39. llama_stack/providers/remote/inference/watsonx/config.py +2 -2
  40. llama_stack/providers/remote/inference/watsonx/watsonx.py +18 -2
  41. llama_stack/providers/utils/inference/inference_store.py +129 -19
  42. llama_stack/providers/utils/inference/openai_mixin.py +53 -8
  43. llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +14 -0
  44. llama_stack/providers/utils/telemetry/tracing.py +24 -10
  45. llama_stack/providers/utils/vector_io/vector_utils.py +2 -4
  46. llama_stack/testing/inference_recorder.py +43 -32
  47. {llama_stack-0.2.21.dist-info → llama_stack-0.2.22.dist-info}/METADATA +5 -5
  48. {llama_stack-0.2.21.dist-info → llama_stack-0.2.22.dist-info}/RECORD +52 -48
  49. {llama_stack-0.2.21.dist-info → llama_stack-0.2.22.dist-info}/WHEEL +0 -0
  50. {llama_stack-0.2.21.dist-info → llama_stack-0.2.22.dist-info}/entry_points.txt +0 -0
  51. {llama_stack-0.2.21.dist-info → llama_stack-0.2.22.dist-info}/licenses/LICENSE +0 -0
  52. {llama_stack-0.2.21.dist-info → llama_stack-0.2.22.dist-info}/top_level.txt +0 -0
@@ -93,3 +93,11 @@ class Benchmarks(Protocol):
         :param metadata: The metadata to use for the benchmark.
         """
         ...
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE")
+    async def unregister_benchmark(self, benchmark_id: str) -> None:
+        """Unregister a benchmark.
+
+        :param benchmark_id: The ID of the benchmark to unregister.
+        """
+        ...
@@ -197,3 +197,11 @@ class ScoringFunctions(Protocol):
         :param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
         """
         ...
+
+    @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE")
+    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
+        """Unregister a scoring function.
+
+        :param scoring_fn_id: The ID of the scoring function to unregister.
+        """
+        ...
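The two new DELETE routes above can be exercised over plain HTTP once a server is running. A minimal sketch, assuming a local Llama Stack server on the default port with the usual /v1 prefix and the httpx client; the benchmark and scoring-function IDs are placeholders, none of this is part of the diff itself:

    import httpx

    BASE_URL = "http://localhost:8321/v1"  # assumption: local server, default port, /v1 route prefix

    # Remove a previously registered benchmark and scoring function via the new routes.
    httpx.delete(f"{BASE_URL}/eval/benchmarks/my-benchmark").raise_for_status()
    httpx.delete(f"{BASE_URL}/scoring-functions/llm-as-judge::my-fn").raise_for_status()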
@@ -45,6 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.exec import formulate_run_args, run_command
 from llama_stack.core.utils.image_types import LlamaStackImageType
 from llama_stack.providers.datatypes import Api
+from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

 DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"

@@ -294,6 +295,12 @@ def _generate_run_config(
         if build_config.external_providers_dir
         else EXTERNAL_PROVIDERS_DIR,
     )
+    if not run_config.inference_store:
+        run_config.inference_store = SqliteSqlStoreConfig(
+            **SqliteSqlStoreConfig.sample_run_config(
+                __distro_dir__=(DISTRIBS_BASE_DIR / image_name).as_posix(), db_name="inference_store.db"
+            )
+        )
     # build providers dict
     provider_registry = get_provider_registry(build_config)
     for api in apis:
@@ -48,15 +48,12 @@ def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
     parser.set_defaults(func=partial(run_verify_cmd, parser=parser))


-def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
-    # NOTE: MD5 is used here only for download integrity verification,
-    # not for security purposes
-    # TODO: switch to SHA256
-    md5_hash = hashlib.md5(usedforsecurity=False)
+def calculate_sha256(filepath: Path, chunk_size: int = 8192) -> str:
+    sha256_hash = hashlib.sha256()
     with open(filepath, "rb") as f:
         for chunk in iter(lambda: f.read(chunk_size), b""):
-            md5_hash.update(chunk)
-    return md5_hash.hexdigest()
+            sha256_hash.update(chunk)
+    return sha256_hash.hexdigest()


 def load_checksums(checklist_path: Path) -> dict[str, str]:
@@ -64,10 +61,10 @@ def load_checksums(checklist_path: Path) -> dict[str, str]:
     with open(checklist_path) as f:
         for line in f:
             if line.strip():
-                md5sum, filepath = line.strip().split(" ", 1)
+                sha256sum, filepath = line.strip().split(" ", 1)
                 # Remove leading './' if present
                 filepath = filepath.lstrip("./")
-                checksums[filepath] = md5sum
+                checksums[filepath] = sha256sum
     return checksums


@@ -88,7 +85,7 @@ def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -
         matches = False

         if exists:
-            actual_hash = calculate_md5(full_path)
+            actual_hash = calculate_sha256(full_path)
             matches = actual_hash == expected_hash

         results.append(
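The checklist format consumed by load_checksums is unchanged by this diff; only the hash algorithm moves from MD5 to SHA-256. A minimal standalone sketch of the end-to-end check these functions now perform, assuming a checklist whose lines look like "<sha256> ./path/to/file"; the model directory below is a placeholder:

    import hashlib
    from pathlib import Path

    def sha256_of(path: Path, chunk_size: int = 8192) -> str:
        # Same streaming pattern as calculate_sha256 above.
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()

    model_dir = Path("~/.llama/checkpoints/my-model").expanduser()  # placeholder path
    for line in (model_dir / "checklist.chk").read_text().splitlines():
        if not line.strip():
            continue
        expected, rel_path = line.strip().split(" ", 1)
        rel_path = rel_path.lstrip("./")
        ok = sha256_of(model_dir / rel_path) == expected
        print(f"{'OK ' if ok else 'BAD'} {rel_path}")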
@@ -431,6 +431,12 @@ class ServerConfig(BaseModel):
     )


+class InferenceStoreConfig(BaseModel):
+    sql_store_config: SqlStoreConfig
+    max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
+    num_writers: int = Field(default=4, description="Number of concurrent background writers")
+
+
 class StackRunConfig(BaseModel):
     version: int = LLAMA_STACK_RUN_CONFIG_VERSION

@@ -464,11 +470,12 @@ Configuration for the persistence store used by the distribution registry. If no
 a default SQLite store will be used.""",
     )

-    inference_store: SqlStoreConfig | None = Field(
+    inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
         default=None,
         description="""
-Configuration for the persistence store used by the inference API. If not specified,
-a default SQLite store will be used.""",
+Configuration for the persistence store used by the inference API. Can be either a
+InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
+If not specified, a default SQLite store will be used.""",
     )

     # registry of "resources" in the distribution
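A minimal sketch of how the new union might be populated programmatically. The class names and import paths are taken from this diff; the SQLite path and the db_path field name (which follows the sqlite store snippets elsewhere in this diff) are assumptions:

    from llama_stack.core.datatypes import InferenceStoreConfig
    from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

    # New style: wrap the SQL store config and tune the async write queue.
    inference_store = InferenceStoreConfig(
        sql_store_config=SqliteSqlStoreConfig(db_path="/tmp/inference_store.db"),  # placeholder path
        max_write_queue_size=10000,  # defaults shown in the Field definitions above
        num_writers=4,
    )

    # Old style (still accepted, now documented as deprecated): a bare SqlStoreConfig.
    legacy_store = SqliteSqlStoreConfig(db_path="/tmp/inference_store.db")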
@@ -10,7 +10,6 @@ import json
 import logging  # allow-direct-logging
 import os
 import sys
-from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from io import BytesIO
 from pathlib import Path
@@ -148,7 +147,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self.async_client = AsyncLlamaStackAsLibraryClient(
             config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
         )
-        self.pool_executor = ThreadPoolExecutor(max_workers=4)
         self.provider_data = provider_data

         self.loop = asyncio.new_event_loop()
@@ -78,7 +78,10 @@ async def get_auto_router_impl(

     # TODO: move pass configs to routers instead
     if api == Api.inference and run_config.inference_store:
-        inference_store = InferenceStore(run_config.inference_store, policy)
+        inference_store = InferenceStore(
+            config=run_config.inference_store,
+            policy=policy,
+        )
         await inference_store.initialize()
         api_to_dep_impl["store"] = inference_store

@@ -63,7 +63,7 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
-from llama_stack.providers.utils.telemetry.tracing import get_current_span
+from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span

 logger = get_logger(name=__name__, category="core::routers")

@@ -90,6 +90,11 @@ class InferenceRouter(Inference):

     async def shutdown(self) -> None:
         logger.debug("InferenceRouter.shutdown")
+        if self.store:
+            try:
+                await self.store.shutdown()
+            except Exception as e:
+                logger.warning(f"Error during InferenceStore shutdown: {e}")

     async def register_model(
         self,
@@ -160,7 +165,7 @@ class InferenceRouter(Inference):
         metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)
         return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

     async def _count_tokens(
@@ -431,7 +436,7 @@ class InferenceRouter(Inference):
             model=model_obj,
         )
         for metric in metrics:
-            await self.telemetry.log_event(metric)
+            enqueue_event(metric)

         # these metrics will show up in the client response.
         response.metrics = (
@@ -537,7 +542,7 @@ class InferenceRouter(Inference):
             model=model_obj,
         )
         for metric in metrics:
-            await self.telemetry.log_event(metric)
+            enqueue_event(metric)
         # these metrics will show up in the client response.
         response.metrics = (
             metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
@@ -664,7 +669,7 @@ class InferenceRouter(Inference):
                 "completion_tokens",
                 "total_tokens",
             ]:  # Only log completion and total tokens
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)

         # Return metrics in response
         async_metrics = [
@@ -710,7 +715,7 @@ class InferenceRouter(Inference):
         )
         for metric in completion_metrics:
             if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)

         # Return metrics in response
         return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
@@ -806,7 +811,7 @@ class InferenceRouter(Inference):
                 model=model,
             )
             for metric in metrics:
-                await self.telemetry.log_event(metric)
+                enqueue_event(metric)

             yield chunk
         finally:
@@ -56,3 +56,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
             provider_resource_id=provider_benchmark_id,
         )
         await self.register_object(benchmark)
+
+    async def unregister_benchmark(self, benchmark_id: str) -> None:
+        existing_benchmark = await self.get_benchmark(benchmark_id)
+        await self.unregister_object(existing_benchmark)
@@ -64,6 +64,10 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
         return await p.unregister_shield(obj.identifier)
     elif api == Api.datasetio:
         return await p.unregister_dataset(obj.identifier)
+    elif api == Api.eval:
+        return await p.unregister_benchmark(obj.identifier)
+    elif api == Api.scoring:
+        return await p.unregister_scoring_function(obj.identifier)
     elif api == Api.tool_runtime:
         return await p.unregister_toolgroup(obj.identifier)
     else:
@@ -60,3 +60,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
         )
         scoring_fn.provider_id = provider_id
         await self.register_object(scoring_fn)
+
+    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
+        existing_scoring_fn = await self.get_scoring_function(scoring_fn_id)
+        await self.unregister_object(existing_scoring_fn)
@@ -17,6 +17,7 @@ distribution_spec:
     - provider_type: remote::vertexai
     - provider_type: remote::groq
    - provider_type: remote::sambanova
+    - provider_type: remote::azure
     - provider_type: inline::sentence-transformers
     vector_io:
     - provider_type: inline::faiss
@@ -81,6 +81,13 @@ providers:
     config:
       url: https://api.sambanova.ai/v1
       api_key: ${env.SAMBANOVA_API_KEY:=}
+  - provider_id: ${env.AZURE_API_KEY:+azure}
+    provider_type: remote::azure
+    config:
+      api_key: ${env.AZURE_API_KEY:=}
+      api_base: ${env.AZURE_API_BASE:=}
+      api_version: ${env.AZURE_API_VERSION:=}
+      api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   vector_io:
@@ -18,6 +18,7 @@ distribution_spec:
     - provider_type: remote::vertexai
     - provider_type: remote::groq
     - provider_type: remote::sambanova
+    - provider_type: remote::azure
     - provider_type: inline::sentence-transformers
     vector_io:
     - provider_type: inline::faiss
@@ -81,6 +81,13 @@ providers:
     config:
       url: https://api.sambanova.ai/v1
       api_key: ${env.SAMBANOVA_API_KEY:=}
+  - provider_id: ${env.AZURE_API_KEY:+azure}
+    provider_type: remote::azure
+    config:
+      api_key: ${env.AZURE_API_KEY:=}
+      api_base: ${env.AZURE_API_BASE:=}
+      api_version: ${env.AZURE_API_VERSION:=}
+      api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   vector_io:
@@ -59,6 +59,7 @@ ENABLED_INFERENCE_PROVIDERS = [
     "cerebras",
     "nvidia",
     "bedrock",
+    "azure",
 ]

 INFERENCE_PROVIDER_IDS = {
@@ -68,6 +69,7 @@ INFERENCE_PROVIDER_IDS = {
     "cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
     "nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
     "vertexai": "${env.VERTEX_AI_PROJECT:+vertexai}",
+    "azure": "${env.AZURE_API_KEY:+azure}",
 }


@@ -277,5 +279,21 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
                 "http://localhost:11434",
                 "Ollama URL",
             ),
+            "AZURE_API_KEY": (
+                "",
+                "Azure API Key",
+            ),
+            "AZURE_API_BASE": (
+                "",
+                "Azure API Base",
+            ),
+            "AZURE_API_VERSION": (
+                "",
+                "Azure API Version",
+            ),
+            "AZURE_API_TYPE": (
+                "azure",
+                "Azure API Type",
+            ),
         },
     )
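The `${env.AZURE_API_KEY:+azure}` provider ID used above means the Azure provider is only instantiated when that variable is set. A minimal sketch of enabling it before launching the starter distribution; the endpoint and version values are placeholders, not taken from this diff:

    import os

    os.environ["AZURE_API_KEY"] = "<your-azure-openai-key>"               # required to enable the provider
    os.environ["AZURE_API_BASE"] = "https://<resource>.openai.azure.com"  # placeholder endpoint
    os.environ["AZURE_API_VERSION"] = "<api-version>"                     # placeholder value
    os.environ["AZURE_API_TYPE"] = "azure"                                # default registered above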
@@ -18,6 +18,7 @@ distribution_spec:
     - provider_type: remote::vertexai
     - provider_type: remote::groq
     - provider_type: remote::sambanova
+    - provider_type: remote::azure
     - provider_type: inline::sentence-transformers
     vector_io:
     - provider_type: inline::faiss
@@ -81,6 +81,13 @@ providers:
     config:
       url: https://api.sambanova.ai/v1
      api_key: ${env.SAMBANOVA_API_KEY:=}
+  - provider_id: ${env.AZURE_API_KEY:+azure}
+    provider_type: remote::azure
+    config:
+      api_key: ${env.AZURE_API_KEY:=}
+      api_base: ${env.AZURE_API_BASE:=}
+      api_version: ${env.AZURE_API_VERSION:=}
+      api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   vector_io:
@@ -10,6 +10,7 @@ apis:
 - telemetry
 - tool_runtime
 - vector_io
+- files
 providers:
   inference:
   - provider_id: watsonx
@@ -94,6 +95,14 @@ providers:
     provider_type: inline::rag-runtime
   - provider_id: model-context-protocol
     provider_type: remote::model-context-protocol
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/watsonx/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/files_metadata.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/registry.db
@@ -9,6 +9,7 @@ from pathlib import Path
 from llama_stack.apis.models import ModelType
 from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
 from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
+from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )
@@ -16,7 +17,7 @@ from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
 from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES


-def get_distribution_template() -> DistributionTemplate:
+def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
     providers = {
         "inference": [
             BuildProvider(provider_type="remote::watsonx"),
@@ -42,6 +43,7 @@ def get_distribution_template() -> DistributionTemplate:
             BuildProvider(provider_type="inline::rag-runtime"),
             BuildProvider(provider_type="remote::model-context-protocol"),
         ],
+        "files": [BuildProvider(provider_type="inline::localfs")],
     }

     inference_provider = Provider(
@@ -79,9 +81,14 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )

+    files_provider = Provider(
+        provider_id="meta-reference-files",
+        provider_type="inline::localfs",
+        config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+    )
     default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
-        name="watsonx",
+        name=name,
         distro_type="remote_hosted",
         description="Use watsonx for running LLM inference",
         container_image=None,
@@ -92,6 +99,7 @@ def get_distribution_template() -> DistributionTemplate:
         "run.yaml": RunConfigSettings(
             provider_overrides={
                 "inference": [inference_provider, embedding_provider],
+                "files": [files_provider],
             },
             default_models=default_models + [embedding_model],
             default_tool_groups=default_tool_groups,
@@ -75,6 +75,13 @@ class MetaReferenceEvalImpl(
         )
         self.benchmarks[task_def.identifier] = task_def

+    async def unregister_benchmark(self, benchmark_id: str) -> None:
+        if benchmark_id in self.benchmarks:
+            del self.benchmarks[benchmark_id]
+
+        key = f"{EVAL_TASKS_PREFIX}{benchmark_id}"
+        await self.kvstore.delete(key)
+
     async def run_eval(
         self,
         benchmark_id: str,
@@ -63,6 +63,9 @@ class LlmAsJudgeScoringImpl(
     async def register_scoring_function(self, function_def: ScoringFn) -> None:
         self.llm_as_judge_fn.register_scoring_fn_def(function_def)

+    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
+        self.llm_as_judge_fn.unregister_scoring_fn_def(scoring_fn_id)
+
     async def score_batch(
         self,
         dataset_id: str,
@@ -8,7 +8,7 @@
 from jinja2 import Template

 from llama_stack.apis.common.content_types import InterleavedContent
-from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.inference import OpenAIUserMessageParam
 from llama_stack.apis.tools.rag_tool import (
     DefaultRAGQueryGeneratorConfig,
     LLMRAGQueryGeneratorConfig,
@@ -61,16 +61,16 @@ async def llm_rag_query_generator(
     messages = [interleaved_content_as_str(content)]

     template = Template(config.template)
-    content = template.render({"messages": messages})
+    rendered_content: str = template.render({"messages": messages})

     model = config.model
-    message = UserMessage(content=content)
-    response = await inference_api.chat_completion(
-        model_id=model,
+    message = OpenAIUserMessageParam(content=rendered_content)
+    response = await inference_api.openai_chat_completion(
+        model=model,
         messages=[message],
         stream=False,
     )

-    query = response.completion_message.content
+    query = response.choices[0].message.content

     return query