rasa-pro 3.12.0rc2-py3-none-any.whl → 3.12.0rc3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rasa-pro might be problematic.

Files changed (43)
  1. rasa/cli/dialogue_understanding_test.py +5 -8
  2. rasa/cli/llm_fine_tuning.py +47 -12
  3. rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
  4. rasa/core/channels/voice_stream/audiocodes.py +19 -6
  5. rasa/core/channels/voice_stream/call_state.py +3 -9
  6. rasa/core/channels/voice_stream/genesys.py +40 -55
  7. rasa/core/channels/voice_stream/voice_channel.py +61 -39
  8. rasa/core/tracker_store.py +123 -34
  9. rasa/dialogue_understanding/commands/set_slot_command.py +1 -0
  10. rasa/dialogue_understanding/commands/utils.py +1 -4
  11. rasa/dialogue_understanding/generator/command_parser.py +41 -0
  12. rasa/dialogue_understanding/generator/constants.py +7 -2
  13. rasa/dialogue_understanding/generator/llm_based_command_generator.py +9 -2
  14. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +29 -48
  15. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_fallback_other_models_template.jinja2 +57 -0
  16. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +23 -50
  17. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +76 -24
  18. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +32 -18
  19. rasa/dialogue_understanding/processor/command_processor.py +39 -19
  20. rasa/dialogue_understanding/stack/utils.py +11 -6
  21. rasa/engine/language.py +67 -25
  22. rasa/llm_fine_tuning/conversations.py +3 -31
  23. rasa/llm_fine_tuning/llm_data_preparation_module.py +5 -3
  24. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +18 -13
  25. rasa/llm_fine_tuning/paraphrasing_module.py +6 -2
  26. rasa/llm_fine_tuning/train_test_split_module.py +27 -27
  27. rasa/llm_fine_tuning/utils.py +7 -0
  28. rasa/shared/constants.py +4 -0
  29. rasa/shared/core/domain.py +2 -0
  30. rasa/shared/providers/_configs/azure_entra_id_config.py +8 -8
  31. rasa/shared/providers/llm/litellm_router_llm_client.py +1 -0
  32. rasa/shared/providers/router/_base_litellm_router_client.py +38 -7
  33. rasa/shared/utils/llm.py +69 -13
  34. rasa/telemetry.py +13 -3
  35. rasa/tracing/instrumentation/attribute_extractors.py +2 -5
  36. rasa/validator.py +2 -2
  37. rasa/version.py +1 -1
  38. {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/METADATA +1 -1
  39. {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/RECORD +42 -41
  40. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_default.jinja2 +0 -68
  41. {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/NOTICE +0 -0
  42. {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/WHEEL +0 -0
  43. {rasa_pro-3.12.0rc2.dist-info → rasa_pro-3.12.0rc3.dist-info}/entry_points.txt +0 -0
rasa/llm_fine_tuning/llm_data_preparation_module.py CHANGED
@@ -4,8 +4,10 @@ from typing import Any, Dict, List, Optional
 import structlog
 from tqdm import tqdm
 
+from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
 from rasa.llm_fine_tuning.conversations import Conversation, ConversationStep
 from rasa.llm_fine_tuning.storage import StorageContext
+from rasa.llm_fine_tuning.utils import commands_as_string
 
 LLM_DATA_PREPARATION_MODULE_STORAGE_LOCATION = "3_llm_finetune_data/llm_ft_data.jsonl"
 
@@ -15,7 +17,7 @@ structlogger = structlog.get_logger()
 @dataclass
 class LLMDataExample:
     prompt: str
-    output: str
+    output: List[PromptCommand]
     original_test_name: str
     original_user_utterance: str
     rephrased_user_utterance: str
@@ -23,7 +25,7 @@ class LLMDataExample:
     def as_dict(self) -> Dict[str, Any]:
         return {
             "prompt": self.prompt,
-            "output": self.output,
+            "output": commands_as_string(self.output),
             "original_test_name": self.original_test_name,
             "original_user_utterance": self.original_user_utterance,
             "rephrased_user_utterance": self.rephrased_user_utterance,
@@ -38,7 +40,7 @@ def _create_data_point(
 ) -> LLMDataExample:
     return LLMDataExample(
         prompt,
-        step.commands_as_string(),
+        step.llm_commands,
         conversation.get_full_name(),
         step.original_test_step.text,
         rephrased_user_message,
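Taken together, these hunks change LLMDataExample.output from a pre-rendered string to a list of structured PromptCommand objects, rendered back to text only at serialization time in as_dict. A minimal sketch of the new flow, using an illustrative stand-in class rather than the real rasa PromptCommand:

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class FakeSetSlotCommand:
        """Illustrative stand-in for a rasa PromptCommand implementation."""

        slot: str
        value: str

        def to_dsl(self) -> str:
            # PromptCommand implementations render themselves to the command DSL.
            return f"SetSlot({self.slot}, {self.value})"


    def commands_as_string(commands: List[FakeSetSlotCommand], delimiter: str = "\n") -> str:
        # Same shape as the new rasa.llm_fine_tuning.utils helper.
        return delimiter.join(command.to_dsl() for command in commands)


    # Keeping structured objects until export lets downstream stages inspect
    # command *types* instead of substring-matching a rendered string.
    output = [FakeSetSlotCommand("city", "Berlin"), FakeSetSlotCommand("date", "today")]
    print(commands_as_string(output, delimiter=", "))
    # SetSlot(city, Berlin), SetSlot(date, today)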
rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py CHANGED
@@ -1,18 +1,18 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Type
 
 import structlog
 
 from rasa.dialogue_understanding.commands import Command, SetSlotCommand
-from rasa.dialogue_understanding.generator import SingleStepLLMCommandGenerator
+from rasa.dialogue_understanding.generator.llm_based_command_generator import (
+    LLMBasedCommandGenerator,
+)
 from rasa.llm_fine_tuning.conversations import Conversation, ConversationStep
 from rasa.llm_fine_tuning.paraphrasing.rephrased_user_message import (
     RephrasedUserMessage,
 )
 from rasa.shared.core.flows import FlowsList
 from rasa.shared.exceptions import ProviderClientAPIException
-from rasa.shared.utils.llm import (
-    llm_factory,
-)
+from rasa.shared.utils.llm import llm_factory
 
 structlogger = structlog.get_logger()
 
@@ -26,6 +26,7 @@ class RephraseValidator:
         self,
         rephrasings: List[RephrasedUserMessage],
         conversation: Conversation,
+        llm_command_generator: Type[LLMBasedCommandGenerator],
     ) -> List[RephrasedUserMessage]:
         """Split rephrased user messages into passing and failing.
 
@@ -38,6 +39,7 @@ class RephraseValidator:
         Args:
             rephrasings: The rephrased user messages.
             conversation: The conversation.
+            llm_command_generator: A LLM based command generator class.
 
         Returns:
             A list of rephrased user messages including the passing and failing
@@ -49,7 +51,9 @@ class RephraseValidator:
             current_rephrasings = rephrasings[i]
 
             for rephrase in current_rephrasings.rephrasings:
-                if await self._validate_rephrase_is_passing(rephrase, step):
+                if await self._validate_rephrase_is_passing(
+                    rephrase, step, llm_command_generator
+                ):
                     current_rephrasings.passed_rephrasings.append(rephrase)
                 else:
                     current_rephrasings.failed_rephrasings.append(rephrase)
@@ -60,25 +64,26 @@ class RephraseValidator:
         self,
         rephrase: str,
         step: ConversationStep,
+        llm_command_generator: Type[LLMBasedCommandGenerator],
     ) -> bool:
         prompt = self._update_prompt(
             rephrase, step.original_test_step.text, step.llm_prompt
         )
 
-        action_list = await self._invoke_llm(prompt)
+        action_list = await self._invoke_llm(
+            prompt, llm_command_generator.get_default_llm_config()
+        )
 
         commands_from_original_utterance = step.llm_commands
-        commands_from_rephrased_utterance = (
-            SingleStepLLMCommandGenerator.parse_commands(action_list, None, self.flows)
+        commands_from_rephrased_utterance = llm_command_generator.parse_commands(  # type: ignore
+            action_list, None, self.flows
         )
         return self._check_commands_match(
             commands_from_original_utterance, commands_from_rephrased_utterance
         )
 
-    async def _invoke_llm(self, prompt: str) -> str:
-        from rasa.dialogue_understanding.generator.constants import DEFAULT_LLM_CONFIG
-
-        llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
+    async def _invoke_llm(self, prompt: str, default_llm_config: Dict[str, Any]) -> str:
+        llm = llm_factory(self.llm_config, default_llm_config)
 
         try:
            llm_response = await llm.acompletion(prompt)
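The validator is now generator-agnostic: the concrete command generator class is injected, and both the default LLM config and the command parsing come from that class rather than a hard-coded SingleStepLLMCommandGenerator and global DEFAULT_LLM_CONFIG. A reduced sketch of this injection pattern, with a stub standing in for a real LLMBasedCommandGenerator subclass (config values are illustrative):

    from typing import Any, Dict, List, Type


    class StubCommandGenerator:
        """Stand-in for an LLMBasedCommandGenerator subclass."""

        @classmethod
        def get_default_llm_config(cls) -> Dict[str, Any]:
            # Illustrative default; each real generator ships its own.
            return {"provider": "openai", "model": "gpt-4o-2024-11-20"}

        @classmethod
        def parse_commands(cls, actions: str, message: Any, flows: Any) -> List[str]:
            # The real classmethod parses the LLM's action list into Command objects.
            return [line.strip() for line in actions.splitlines() if line.strip()]


    def validate_with(generator: Type[StubCommandGenerator], action_list: str) -> List[str]:
        # Whichever generator class the pipeline runs supplies both behaviors.
        default_config = generator.get_default_llm_config()
        assert "model" in default_config  # used when invoking the validation LLM
        return generator.parse_commands(action_list, None, None)


    print(validate_with(StubCommandGenerator, "StartFlow(transfer_money)\n"))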
rasa/llm_fine_tuning/paraphrasing_module.py CHANGED
@@ -1,8 +1,11 @@
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Type
 
 import structlog
 from tqdm import tqdm
 
+from rasa.dialogue_understanding.generator.llm_based_command_generator import (
+    LLMBasedCommandGenerator,
+)
 from rasa.llm_fine_tuning.conversations import Conversation
 from rasa.llm_fine_tuning.paraphrasing.conversation_rephraser import (
     ConversationRephraser,
@@ -25,6 +28,7 @@ async def create_paraphrased_conversations(
     rephrase_config: Dict[str, Any],
     num_rephrases: int,
     flows: FlowsList,
+    llm_command_generator: Type[LLMBasedCommandGenerator],
     llm_command_generator_config: Dict[str, Any],
     storage_context: StorageContext,
 ) -> Tuple[List[Conversation], Dict[str, Any]]:
@@ -71,7 +75,7 @@ async def create_paraphrased_conversations(
             rephrasings = _filter_rephrasings(rephrasings, conversations[i])
             # check if the rephrasings are still producing the same commands
             rephrasings = await validator.validate_rephrasings(
-                rephrasings, current_conversation
+                rephrasings, current_conversation, llm_command_generator
             )
         except ProviderClientAPIException as e:
             structlogger.error(
rasa/llm_fine_tuning/train_test_split_module.py CHANGED
@@ -1,27 +1,18 @@
 import random
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Any, Dict, List, Protocol, Set, Tuple
+from typing import Any, Dict, List, Protocol, Set, Tuple, Type
 
 import structlog
 
+from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
 from rasa.e2e_test.e2e_test_case import TestSuite
 from rasa.llm_fine_tuning.llm_data_preparation_module import LLMDataExample
 from rasa.llm_fine_tuning.storage import StorageContext
+from rasa.llm_fine_tuning.utils import commands_as_string
 
 TRAIN_TEST_MODULE_STORAGE_LOCATION = "4_train_test_split"
 
-SUPPORTED_COMMANDS = [
-    "SetSlot",
-    "StartFlow",
-    "CancelFlow",
-    "ChitChat",
-    "SkipQuestion",
-    "SearchAndReply",
-    "HumanHandoff",
-    "Clarify",
-]
-
 INSTRUCTION_DATA_FORMAT = "instruction"
 CONVERSATIONAL_DATA_FORMAT = "conversational"
 
@@ -77,17 +68,19 @@ class ConversationalDataFormat(DataExampleFormat):
         }
 
 
-def _get_command_types_covered_by_llm_data_point(commands: LLMDataExample) -> Set[str]:
+def _get_command_types_covered_by_llm_data_point(
+    data_point: LLMDataExample,
+) -> Set[Type[PromptCommand]]:
     """Get the command types covered by the LLM data point.
 
     This function returns the set of command types from the output present in a
-    LLMDataExample object. Eg: The function returns {'SetSlot', 'StartFlow'} when the
-    LLMDataExample.output is 'SetSlot(slot, abc), SetSlot(slot, cde), StartFlow(xyz)'.
+    LLMDataExample object. Eg: The function returns {'SetSlotCommand',
+    'StartFlowCommand'} when the LLMDataExample.output is 'SetSlotCommand(slot, abc),
+    SetSlotCommand(slot, cde), StartFlowCommand(xyz)'.
     """
     commands_covered = set()
-    for command in SUPPORTED_COMMANDS:
-        if command in commands.output:
-            commands_covered.add(command)
+    for command in data_point.output:
+        commands_covered.add(command.__class__)
     return commands_covered
 
 
@@ -146,14 +139,18 @@ def _get_minimum_test_case_groups_to_cover_all_commands(
         {
             "test_case_name": "t1",
             "data_examples": [],
-            "commands": {"SetSlot", "CancelFlow"}
+            "commands": {"SetSlotCommand", "CancelFlowCommand"}
         },
-        {"test_case_name": "t2", "data_examples": [], "commands": {"CancelFlow"}},
-        {"test_case_name": "t3", "data_examples": [], "commands": {"StartFlow"}},
+        {
+            "test_case_name": "t2",
+            "data_examples": [],
+            "commands": {"CancelFlowCommand"}
+        },
+        {"test_case_name": "t3", "data_examples": [], "commands": {"StartFlowCommand"}},
         {
             "test_case_name": "t4",
             "data_examples": [],
-            "commands": {"SetSlot", "StartFlow"}
+            "commands": {"SetSlotCommand", "StartFlowCommand"}
         },
     ]
 
@@ -166,7 +163,7 @@ def _get_minimum_test_case_groups_to_cover_all_commands(
         command for test_group in grouped_data for command in test_group[KEY_COMMANDS]
     )
     selected_test_cases = []
-    covered_commands: Set[str] = set()
+    covered_commands: Set[Type[PromptCommand]] = set()
 
     while covered_commands != all_commands:
         # Find the test case group that covers the most number of uncovered commands
@@ -187,7 +184,7 @@ def _get_minimum_test_case_groups_to_cover_all_commands(
 
     structlogger.info(
         "llm_fine_tuning.train_test_split_module.command_coverage_in_train_dataset",
-        covered_commands=covered_commands,
+        covered_commands=[command.__name__ for command in covered_commands],
     )
     return selected_test_cases
 
@@ -205,7 +202,10 @@ def _get_finetuning_data_in_instruction_data_format(
     data: List[Dict[str, Any]],
 ) -> List[DataExampleFormat]:
     return [
-        InstructionDataFormat(llm_data_example.prompt, llm_data_example.output)
+        InstructionDataFormat(
+            llm_data_example.prompt,
+            commands_as_string(llm_data_example.output),
+        )
         for test_group in data
         for llm_data_example in test_group[KEY_DATA_EXAMPLES]
     ]
@@ -232,7 +232,7 @@ def _get_finetuning_data_in_conversational_data_format(
                 [
                     ConversationalMessageDataFormat("user", llm_data_example.prompt),
                     ConversationalMessageDataFormat(
-                        "assistant", llm_data_example.output
+                        "assistant", commands_as_string(llm_data_example.output)
                     ),
                 ]
             )
@@ -271,7 +271,7 @@ def _check_and_log_missing_validation_dataset_command_coverage(
     structlogger.warning(
         "llm_fine_tuning.train_test_split_module.missing_commands_in_validation_dat"
        "aset",
-        missing_commands=missing_commands,
+        missing_commands=[command.__name__ for command in missing_commands],
     )

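The covered_commands loop that these hunks touch implements a greedy minimum set cover over command types: repeatedly pick the test case group that covers the most still-uncovered commands. A self-contained sketch of that strategy, simplified to strings in place of PromptCommand classes:

    from typing import Any, Dict, List, Set


    def greedy_cover(grouped_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        # Select test case groups until every command type seen anywhere is covered.
        all_commands: Set[str] = {
            command for group in grouped_data for command in group["commands"]
        }
        selected: List[Dict[str, Any]] = []
        covered: Set[str] = set()
        remaining = list(grouped_data)
        while covered != all_commands:
            # Greedy step: take the group contributing the most uncovered commands.
            best = max(remaining, key=lambda g: len(g["commands"] - covered))
            covered |= best["commands"]
            selected.append(best)
            remaining.remove(best)
        return selected


    groups = [
        {"test_case_name": "t1", "commands": {"SetSlotCommand", "CancelFlowCommand"}},
        {"test_case_name": "t2", "commands": {"CancelFlowCommand"}},
        {"test_case_name": "t3", "commands": {"StartFlowCommand"}},
        {"test_case_name": "t4", "commands": {"SetSlotCommand", "StartFlowCommand"}},
    ]
    print([g["test_case_name"] for g in greedy_cover(groups)])  # ['t1', 't3']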
 
rasa/llm_fine_tuning/utils.py ADDED
@@ -0,0 +1,7 @@
+from typing import List
+
+from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
+
+
+def commands_as_string(commands: List[PromptCommand], delimiter: str = "\n") -> str:
+    return delimiter.join([command.to_dsl() for command in commands])
rasa/shared/constants.py CHANGED
@@ -194,6 +194,9 @@ PROVIDER_CONFIG_KEY = "provider"
 REQUEST_TIMEOUT_CONFIG_KEY = "request_timeout"  # deprecated
 TIMEOUT_CONFIG_KEY = "timeout"
 
+TEMPERATURE_CONFIG_KEY = "temperature"
+MAX_TOKENS_CONFIG_KEY = "max_tokens"
+
 DEPLOYMENT_NAME_CONFIG_KEY = "deployment_name"
 DEPLOYMENT_CONFIG_KEY = "deployment"
 EMBEDDINGS_CONFIG_KEY = "embeddings"
@@ -264,6 +267,7 @@ LITELLM_SSL_CERTIFICATE_ENV_VAR = "SSL_CERTIFICATE"
 
 OPENAI_PROVIDER = "openai"
 AZURE_OPENAI_PROVIDER = "azure"
+ANTHROPIC_PROVIDER = "anthropic"
 SELF_HOSTED_PROVIDER = "self-hosted"
 HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER = "huggingface_local"
 RASA_PROVIDER = "rasa"
rasa/shared/core/domain.py CHANGED
@@ -52,6 +52,7 @@ from rasa.shared.core.constants import (
     SlotMappingType,
 )
 from rasa.shared.core.events import SlotSet, UserUttered
+from rasa.shared.core.flows.constants import KEY_TRANSLATION
 from rasa.shared.core.slots import (
     AnySlot,
     CategoricalSlot,
@@ -117,6 +118,7 @@ RESPONSE_KEYS_TO_INTERPOLATE = [
     KEY_RESPONSES_BUTTONS,
     KEY_RESPONSES_ATTACHMENT,
     KEY_RESPONSES_QUICK_REPLIES,
+    KEY_TRANSLATION,
 ]
 
 ALL_DOMAIN_KEYS = [
rasa/shared/providers/_configs/azure_entra_id_config.py CHANGED
@@ -8,7 +8,7 @@ from functools import lru_cache
 from typing import Any, Callable, Dict, List, Optional, Set, Type
 
 import structlog
-from azure.core.credentials import TokenProvider
+from azure.core.credentials import TokenCredential
 from azure.identity import (
     CertificateCredential,
     ClientSecretCredential,
@@ -77,7 +77,7 @@ class AzureEntraIDTokenProviderConfig(abc.ABC):
     """Interface for Azure Entra ID OAuth credential configuration."""
 
     @abc.abstractmethod
-    def create_azure_token_provider(self) -> TokenProvider:
+    def create_azure_token_provider(self) -> TokenCredential:
         """Create an Azure Entra ID token provider."""
         ...
 
@@ -159,7 +159,7 @@ class AzureEntraIDClientCredentialsConfig(AzureEntraIDTokenProviderConfig, BaseM
         ),
     )
 
-    def create_azure_token_provider(self) -> TokenProvider:
+    def create_azure_token_provider(self) -> TokenCredential:
         """Create a ClientSecretCredential for Azure Entra ID."""
         return create_azure_entra_id_client_credentials(
             client_id=self.client_id,
@@ -286,7 +286,7 @@ class AzureEntraIDClientCertificateConfig(AzureEntraIDTokenProviderConfig, BaseM
         ),
     )
 
-    def create_azure_token_provider(self) -> TokenProvider:
+    def create_azure_token_provider(self) -> TokenCredential:
         """Creates a CertificateCredential for Azure Entra ID."""
         return create_azure_entra_id_certificate_credentials(
             client_id=self.client_id,
@@ -369,7 +369,7 @@ class AzureEntraIDDefaultCredentialsConfig(AzureEntraIDTokenProviderConfig, Base
         """
         return cls(authority_host=config.pop(AZURE_AUTHORITY_FIELD, None))
 
-    def create_azure_token_provider(self) -> TokenProvider:
+    def create_azure_token_provider(self) -> TokenCredential:
         """Creates a DefaultAzureCredential."""
         return create_azure_entra_id_default_credentials(
             authority_host=self.authority_host
@@ -530,12 +530,12 @@ class AzureEntraIDOAuthConfig(OAuth, BaseModel):
         azure_oauth_class = AzureEntraIDOAuthConfig._get_azure_oauth_by_type(oauth_type)
         return azure_oauth_class.from_dict(oauth_config)
 
-    def _create_azure_credential(
+    def create_azure_credential(
         self,
-    ) -> TokenProvider:
+    ) -> TokenCredential:
         """Create an Azure Entra ID client which can be used to get a bearer token."""
         return self.azure_entra_id_token_provider_config.create_azure_token_provider()
 
     def get_bearer_token(self) -> str:
         """Returns a bearer token."""
-        return self._create_azure_credential().get_token(*self.scopes).token  # type: ignore
+        return self.create_azure_credential().get_token(*self.scopes).token
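The TokenProvider → TokenCredential rename aligns the annotations with the actual azure-core protocol: the azure.identity credential classes (ClientSecretCredential, CertificateCredential, DefaultAzureCredential) implement TokenCredential, whose get_token(*scopes) returns an AccessToken, which is also why the `# type: ignore` on get_bearer_token can be dropped. A minimal sketch of the bearer-token flow, with placeholder tenant/app values:

    from azure.core.credentials import TokenCredential
    from azure.identity import ClientSecretCredential


    def get_bearer_token(credential: TokenCredential, *scopes: str) -> str:
        # Mirrors AzureEntraIDOAuthConfig.get_bearer_token: any azure.identity
        # credential satisfies the TokenCredential protocol.
        return credential.get_token(*scopes).token


    # Placeholder values; in rasa these come from the oauth config block.
    credential = ClientSecretCredential(
        tenant_id="<tenant-id>",
        client_id="<client-id>",
        client_secret="<client-secret>",
    )
    token = get_bearer_token(credential, "https://cognitiveservices.azure.com/.default")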
rasa/shared/providers/llm/litellm_router_llm_client.py CHANGED
@@ -198,6 +198,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
         """Returns the completion arguments for invoking a call through
         LiteLLM's completion functions.
         """
+
         return {
             **self._litellm_extra_parameters,
             LITE_LLM_MODEL_FIELD: self.model_group_id,
rasa/shared/providers/router/_base_litellm_router_client.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+from copy import deepcopy
 from typing import Any, Dict, List
 
 import structlog
@@ -18,6 +19,7 @@ from rasa.shared.constants import (
     USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
 )
 from rasa.shared.exceptions import ProviderClientValidationError
+from rasa.shared.providers._configs.azure_entra_id_config import AzureEntraIDOAuthConfig
 from rasa.shared.providers._configs.litellm_router_client_config import (
     LiteLLMRouterClientConfig,
 )
@@ -61,12 +63,8 @@ class _BaseLiteLLMRouterClient:
         self._extra_parameters = kwargs or {}
         self.additional_client_setup()
         try:
-            resolved_model_configurations = (
-                self._resolve_env_vars_in_model_configurations()
-            )
-            self._router_client = Router(
-                model_list=resolved_model_configurations, **router_settings
-            )
+            # We instantiate a router client here to validate the configuration.
+            self._router_client = self._create_router_client()
         except Exception as e:
             event_info = "Cannot instantiate a router client."
             structlogger.error(
@@ -145,6 +143,14 @@ class _BaseLiteLLMRouterClient:
     @property
     def router_client(self) -> Router:
         """Returns the instantiated LiteLLM Router client."""
+        # In case oauth is used, due to a bug in LiteLLM,
+        # azure_ad_token_provider is not working as expected.
+        # To work around this, we create a new client every
+        # time we need to make a call which will
+        # ensure that the token is always fresh.
+        # GitHub issue for LiteLLM: https://github.com/BerriAI/litellm/issues/4417
+        if self._has_oauth():
+            return self._create_router_client()
         return self._router_client
 
     @property
@@ -175,11 +181,36 @@ class _BaseLiteLLMRouterClient:
             **self._litellm_extra_parameters,
         }
 
+    def _create_router_client(self) -> Router:
+        resolved_model_configurations = self._resolve_env_vars_in_model_configurations()
+        return Router(model_list=resolved_model_configurations, **self.router_settings)
+
+    def _has_oauth(self) -> bool:
+        for model_configuration in self.model_configurations:
+            if model_configuration.get("litellm_params", {}).get("oauth", None):
+                return True
+        return False
+
     def _resolve_env_vars_in_model_configurations(self) -> List:
         model_configuration_with_resolved_keys = []
         for model_configuration in self.model_configurations:
             resolved_model_configuration = resolve_environment_variables(
-                model_configuration
+                deepcopy(model_configuration)
             )
+
+            if not isinstance(resolved_model_configuration, dict):
+                continue
+
+            lite_llm_params = resolved_model_configuration.get("litellm_params", {})
+            if lite_llm_params.get("oauth", None):
+                oauth_config_dict = lite_llm_params.pop("oauth")
+                oauth_config = AzureEntraIDOAuthConfig.from_dict(oauth_config_dict)
+                credential = oauth_config.create_azure_credential()
+                # token_provider = get_bearer_token_provider(
+                #     credential, *oauth_config.scopes
+                # )
+                resolved_model_configuration["litellm_params"]["azure_ad_token"] = (
+                    credential.get_token(*oauth_config.scopes).token
+                )
             model_configuration_with_resolved_keys.append(resolved_model_configuration)
         return model_configuration_with_resolved_keys
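The router_client property change trades construction cost for token freshness: whenever any model entry carries an oauth block, a fresh Router is built per access, so the azure_ad_token resolved during construction can never go stale between calls. A condensed, framework-free sketch of the pattern (the factory stands in for the real Router construction):

    import time
    from typing import Any, Callable, Dict, List


    class FreshClientOnOAuth:
        """Sketch of the workaround for litellm issue #4417: rebuild the client
        per call when OAuth is configured, so the injected token is freshly
        minted rather than the one cached at startup."""

        def __init__(self, model_list: List[Dict[str, Any]], factory: Callable[[], Any]):
            self._model_list = model_list
            self._factory = factory
            self._client = factory()  # also validates the configuration eagerly

        def _has_oauth(self) -> bool:
            return any(
                m.get("litellm_params", {}).get("oauth") for m in self._model_list
            )

        @property
        def client(self) -> Any:
            # Static credentials: reuse the validated client.
            # OAuth: rebuild, which re-runs token acquisition.
            return self._factory() if self._has_oauth() else self._client


    clients = FreshClientOnOAuth(
        [{"litellm_params": {"oauth": {"type": "client_credentials"}}}],
        factory=lambda: f"router-built-at-{time.time()}",
    )
    print(clients.client)  # a new "router" on each access while oauth is present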
rasa/shared/utils/llm.py CHANGED
@@ -667,38 +667,94 @@ def get_prompt_template(
     """
     try:
         if jinja_file_path is not None:
-            return rasa.shared.utils.io.read_file(jinja_file_path)
+            prompt_template = rasa.shared.utils.io.read_file(jinja_file_path)
+            structlogger.info(
+                "utils.llm.get_prompt_template.custom_prompt_template_read_successfull",
+                event_info=(
+                    f"Custom prompt template read successfully from "
+                    f"`{jinja_file_path}`."
+                ),
+                prompt_file_path=jinja_file_path,
+            )
+            return prompt_template
     except (FileIOException, FileNotFoundException):
         structlogger.warning(
-            "Failed to read custom prompt template. Using default template instead.",
-            jinja_file_path=jinja_file_path,
+            "utils.llm.get_prompt_template.failed_to_read_custom_prompt_template",
+            event_info=(
+                "Failed to read custom prompt template. Using default template instead."
+            ),
         )
     return default_prompt_template
 
 
 def get_default_prompt_template_based_on_model(
-    config: Dict[str, Any],
+    llm_config: Dict[str, Any],
     model_prompt_mapping: Dict[str, Any],
+    default_prompt_path: str,
     fallback_prompt_path: str,
 ) -> Text:
     """Returns the default prompt template based on the model name.
 
     Args:
-        config: The model config.
+        llm_config: The model config.
         model_prompt_mapping: The mapping of model name to prompt template.
-        fallback_prompt_path: The fallback prompt path.
+        default_prompt_path: The default prompt path of the component.
+        fallback_prompt_path: The fallback prompt path for all other models
+            that do not have a mapping in the model_prompt_mapping.
 
     Returns:
         The default prompt template.
     """
-    _config = deepcopy(config)
-    if MODELS_CONFIG_KEY in _config:
-        _config = _config[MODELS_CONFIG_KEY][0]
-    provider = _config.get(PROVIDER_CONFIG_KEY)
-    model = _config.get(MODEL_CONFIG_KEY, "")
+    _llm_config = deepcopy(llm_config)
+    if MODELS_CONFIG_KEY in _llm_config:
+        _llm_config = _llm_config[MODELS_CONFIG_KEY][0]
+    provider = _llm_config.get(PROVIDER_CONFIG_KEY)
+    model = _llm_config.get(MODEL_CONFIG_KEY)
+    if not model:
+        # If the model is not defined, we default to the default prompt template.
+        structlogger.info(
+            "utils.llm.get_default_prompt_template_based_on_model.using_default_prompt_template",
+            event_info=(
+                f"Model not defined in the config. Default prompt template read from"
+                f" - `{default_prompt_path}`."
+            ),
+            default_prompt_path=default_prompt_path,
+        )
+        return importlib.resources.read_text(
+            DEFAULT_PROMPT_PACKAGE_NAME, default_prompt_path
+        )
+
     model_name = model if provider and provider in model else f"{provider}/{model}"
-    prompt_file_path = model_prompt_mapping.get(model_name, fallback_prompt_path)
-    return importlib.resources.read_text(DEFAULT_PROMPT_PACKAGE_NAME, prompt_file_path)
+    if prompt_file_path := model_prompt_mapping.get(model_name):
+        # If the model is found in the mapping, we use the model-specific prompt
+        # template.
+        structlogger.info(
+            "utils.llm.get_default_prompt_template_based_on_model.using_model_specific_prompt_template",
+            event_info=(
+                f"Using model-specific default prompt template. Default prompt "
+                f"template read from - `{prompt_file_path}`."
+            ),
+            default_prompt_path=prompt_file_path,
+            model_name=model_name,
+        )
+        return importlib.resources.read_text(
+            DEFAULT_PROMPT_PACKAGE_NAME, prompt_file_path
+        )
+
+    # If the model is not found in the mapping, we default to the fallback prompt
+    # template.
+    structlogger.info(
+        "utils.llm.get_default_prompt_template_based_on_model.using_fallback_prompt_template",
+        event_info=(
+            f"Model not found in the model prompt mapping. Fallback prompt template "
+            f"read from - `{fallback_prompt_path}`."
+        ),
+        fallback_prompt_path=fallback_prompt_path,
+        model_name=model_name,
+    )
+    return importlib.resources.read_text(
+        DEFAULT_PROMPT_PACKAGE_NAME, fallback_prompt_path
+    )
 
 
 def allowed_values_for_slot(slot: Slot) -> Union[str, None]:
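The rewritten helper resolves a prompt template in three explicit tiers: no model configured → the component's default template; model present in the mapping → the model-specific template; anything else → the fallback template. A small standalone sketch of that precedence, with illustrative file names:

    from typing import Any, Dict, Optional


    def resolve_prompt_path(
        llm_config: Dict[str, Any],
        model_prompt_mapping: Dict[str, str],
        default_prompt_path: str,
        fallback_prompt_path: str,
    ) -> str:
        # Tier 1: no model configured at all -> the component's default template.
        provider = llm_config.get("provider")
        model: Optional[str] = llm_config.get("model")
        if not model:
            return default_prompt_path
        # Tier 2: a model-specific template, keyed by "provider/model".
        model_name = model if provider and provider in model else f"{provider}/{model}"
        if prompt_file_path := model_prompt_mapping.get(model_name):
            return prompt_file_path
        # Tier 3: every other model falls through to the fallback template.
        return fallback_prompt_path


    mapping = {
        "openai/gpt-4o-2024-11-20": "command_prompt_v2_gpt_4o_2024_11_20_template.jinja2",
    }
    print(resolve_prompt_path({}, mapping, "default.jinja2", "fallback.jinja2"))
    print(resolve_prompt_path({"provider": "openai", "model": "gpt-4o-2024-11-20"},
                              mapping, "default.jinja2", "fallback.jinja2"))
    print(resolve_prompt_path({"provider": "openai", "model": "gpt-3.5-turbo"},
                              mapping, "default.jinja2", "fallback.jinja2"))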
rasa/telemetry.py CHANGED
@@ -15,7 +15,7 @@ from collections import defaultdict
 from datetime import datetime
 from functools import wraps
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Text, Tuple
+from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Type, cast
 
 import importlib_resources
 import requests
@@ -1126,12 +1126,12 @@ def _get_llm_command_generator_config(config: Dict[str, Any]) -> Optional[Dict]:
     """
     from rasa.dialogue_understanding.generator import (
         CompactLLMCommandGenerator,
+        LLMBasedCommandGenerator,
         LLMCommandGenerator,
         MultiStepLLMCommandGenerator,
         SingleStepLLMCommandGenerator,
     )
     from rasa.dialogue_understanding.generator.constants import (
-        DEFAULT_LLM_CONFIG,
         FLOW_RETRIEVAL_KEY,
         LLM_CONFIG_KEY,
     )
@@ -1162,6 +1162,12 @@ def _get_llm_command_generator_config(config: Dict[str, Any]) -> Optional[Dict]:
 
     def extract_llm_command_generator_llm_client_settings(component: Dict) -> Dict:
         """Extracts settings related to LLM command generator."""
+        component_class_lookup = {
+            LLMCommandGenerator.__name__: LLMCommandGenerator,
+            SingleStepLLMCommandGenerator.__name__: SingleStepLLMCommandGenerator,
+            MultiStepLLMCommandGenerator.__name__: MultiStepLLMCommandGenerator,
+            CompactLLMCommandGenerator.__name__: CompactLLMCommandGenerator,
+        }
         llm_config = component.get(LLM_CONFIG_KEY, {})
         # Config at this stage is not yet resolved, so read from `model_group`
         llm_model_group_id = llm_config.get(MODEL_GROUP_CONFIG_KEY)
@@ -1169,7 +1175,11 @@ def _get_llm_command_generator_config(config: Dict[str, Any]) -> Optional[Dict]:
             MODEL_NAME_CONFIG_KEY
         )
         if llm_model_group_id is None and llm_model_name is None:
-            llm_model_name = DEFAULT_LLM_CONFIG[MODEL_CONFIG_KEY]
+            component_clz = cast(
+                Type[LLMBasedCommandGenerator],
+                component_class_lookup[component["name"]],
+            )
+            llm_model_name = component_clz.get_default_llm_config()[MODEL_CONFIG_KEY]
 
         custom_prompt_used = (
             PROMPT_CONFIG_KEY in component or PROMPT_TEMPLATE_CONFIG_KEY in component
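With DEFAULT_LLM_CONFIG gone, the default model that telemetry reports now depends on which generator class is configured, looked up by class name. A reduced sketch of the lookup pattern, with illustrative default configs:

    from typing import Any, Dict, Type


    class CompactGenerator:
        @classmethod
        def get_default_llm_config(cls) -> Dict[str, Any]:
            # Illustrative values; the real defaults live in each rasa generator.
            return {"provider": "anthropic", "model": "claude-3-5-sonnet-20240620"}


    class SingleStepGenerator:
        @classmethod
        def get_default_llm_config(cls) -> Dict[str, Any]:
            return {"provider": "openai", "model": "gpt-4o-2024-11-20"}


    lookup: Dict[str, Type] = {
        cls.__name__: cls for cls in (CompactGenerator, SingleStepGenerator)
    }


    def default_model_for(component: Dict[str, Any]) -> str:
        # Mirrors extract_llm_command_generator_llm_client_settings: when neither
        # a model group nor an explicit model is configured, report the default
        # of the configured generator class instead of a global constant.
        return lookup[component["name"]].get_default_llm_config()["model"]


    print(default_model_for({"name": "CompactGenerator"}))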
rasa/tracing/instrumentation/attribute_extractors.py CHANGED
@@ -58,9 +58,7 @@ from rasa.shared.core.trackers import DialogueStateTracker
 from rasa.shared.core.training_data.structures import StoryGraph
 from rasa.shared.importers.importer import TrainingDataImporter
 from rasa.shared.nlu.constants import INTENT_NAME_KEY, SET_SLOT_COMMAND
-from rasa.shared.utils.llm import (
-    combine_custom_and_default_config,
-)
+from rasa.shared.utils.llm import combine_custom_and_default_config
 from rasa.tracing.constants import (
     PROMPT_TOKEN_LENGTH_ATTRIBUTE_NAME,
     REQUEST_BODY_SIZE_IN_BYTES_ATTRIBUTE_NAME,
@@ -375,14 +373,13 @@ def extract_attrs_for_llm_based_command_generator(
     self: "LLMBasedCommandGenerator",
     prompt: str,
 ) -> Dict[str, Any]:
-    from rasa.dialogue_understanding.generator.constants import DEFAULT_LLM_CONFIG
     from rasa.dialogue_understanding.generator.flow_retrieval import (
         DEFAULT_EMBEDDINGS_CONFIG,
     )
 
     attributes = extract_llm_config(
         self,
-        default_llm_config=DEFAULT_LLM_CONFIG,
+        default_llm_config=self.get_default_llm_config(),
         default_embeddings_config=DEFAULT_EMBEDDINGS_CONFIG,
     )