rasa-pro: rasa_pro-3.9.18-py3-none-any.whl → rasa_pro-3.10.16-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (183)
  1. README.md +0 -374
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +27 -23
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +11 -3
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +104 -431
  12. rasa/cli/evaluate.py +1 -1
  13. rasa/cli/interactive.py +1 -0
  14. rasa/cli/llm_fine_tuning.py +398 -0
  15. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  16. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  17. rasa/cli/run.py +15 -14
  18. rasa/cli/scaffold.py +10 -8
  19. rasa/cli/studio/studio.py +35 -5
  20. rasa/cli/train.py +56 -8
  21. rasa/cli/utils.py +22 -5
  22. rasa/cli/x.py +1 -1
  23. rasa/constants.py +7 -1
  24. rasa/core/actions/action.py +98 -49
  25. rasa/core/actions/action_run_slot_rejections.py +4 -1
  26. rasa/core/actions/custom_action_executor.py +9 -6
  27. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  29. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  30. rasa/core/actions/http_custom_action_executor.py +6 -5
  31. rasa/core/agent.py +21 -17
  32. rasa/core/channels/__init__.py +2 -0
  33. rasa/core/channels/audiocodes.py +1 -16
  34. rasa/core/channels/voice_aware/__init__.py +0 -0
  35. rasa/core/channels/voice_aware/jambonz.py +103 -0
  36. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  37. rasa/core/channels/voice_aware/utils.py +20 -0
  38. rasa/core/channels/voice_native/__init__.py +0 -0
  39. rasa/core/constants.py +6 -1
  40. rasa/core/information_retrieval/faiss.py +7 -4
  41. rasa/core/information_retrieval/information_retrieval.py +8 -0
  42. rasa/core/information_retrieval/milvus.py +9 -2
  43. rasa/core/information_retrieval/qdrant.py +1 -1
  44. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  45. rasa/core/nlg/summarize.py +4 -3
  46. rasa/core/policies/enterprise_search_policy.py +113 -45
  47. rasa/core/policies/flows/flow_executor.py +122 -76
  48. rasa/core/policies/intentless_policy.py +83 -29
  49. rasa/core/processor.py +72 -54
  50. rasa/core/run.py +5 -4
  51. rasa/core/tracker_store.py +8 -4
  52. rasa/core/training/interactive.py +1 -1
  53. rasa/core/utils.py +56 -57
  54. rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
  55. rasa/dialogue_understanding/commands/__init__.py +6 -0
  56. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  57. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  58. rasa/dialogue_understanding/commands/utils.py +40 -0
  59. rasa/dialogue_understanding/generator/constants.py +10 -3
  60. rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
  61. rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
  62. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
  63. rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
  64. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
  65. rasa/dialogue_understanding/patterns/restart.py +37 -0
  66. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  67. rasa/dialogue_understanding/processor/command_processor.py +16 -3
  68. rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
  69. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  70. rasa/e2e_test/assertions.py +1223 -0
  71. rasa/e2e_test/assertions_schema.yml +106 -0
  72. rasa/e2e_test/constants.py +20 -0
  73. rasa/e2e_test/e2e_config.py +220 -0
  74. rasa/e2e_test/e2e_config_schema.yml +26 -0
  75. rasa/e2e_test/e2e_test_case.py +131 -8
  76. rasa/e2e_test/e2e_test_converter.py +363 -0
  77. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  78. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  79. rasa/e2e_test/e2e_test_result.py +26 -6
  80. rasa/e2e_test/e2e_test_runner.py +493 -71
  81. rasa/e2e_test/e2e_test_schema.yml +96 -0
  82. rasa/e2e_test/pykwalify_extensions.py +39 -0
  83. rasa/e2e_test/stub_custom_action.py +70 -0
  84. rasa/e2e_test/utils/__init__.py +0 -0
  85. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  86. rasa/e2e_test/utils/io.py +598 -0
  87. rasa/e2e_test/utils/validation.py +80 -0
  88. rasa/engine/graph.py +9 -3
  89. rasa/engine/recipes/default_components.py +0 -2
  90. rasa/engine/recipes/default_recipe.py +10 -2
  91. rasa/engine/storage/local_model_storage.py +40 -12
  92. rasa/engine/validation.py +78 -1
  93. rasa/env.py +9 -0
  94. rasa/graph_components/providers/story_graph_provider.py +59 -6
  95. rasa/llm_fine_tuning/__init__.py +0 -0
  96. rasa/llm_fine_tuning/annotation_module.py +241 -0
  97. rasa/llm_fine_tuning/conversations.py +144 -0
  98. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  99. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  100. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  101. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  102. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  103. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  104. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  105. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  106. rasa/llm_fine_tuning/storage.py +174 -0
  107. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  108. rasa/model_training.py +56 -16
  109. rasa/nlu/persistor.py +157 -36
  110. rasa/server.py +45 -10
  111. rasa/shared/constants.py +76 -16
  112. rasa/shared/core/domain.py +27 -19
  113. rasa/shared/core/events.py +28 -2
  114. rasa/shared/core/flows/flow.py +208 -13
  115. rasa/shared/core/flows/flow_path.py +84 -0
  116. rasa/shared/core/flows/flows_list.py +33 -11
  117. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  118. rasa/shared/core/flows/validation.py +112 -25
  119. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  120. rasa/shared/core/trackers.py +6 -0
  121. rasa/shared/core/training_data/structures.py +20 -0
  122. rasa/shared/core/training_data/visualization.html +2 -2
  123. rasa/shared/exceptions.py +4 -0
  124. rasa/shared/importers/importer.py +64 -16
  125. rasa/shared/nlu/constants.py +2 -0
  126. rasa/shared/providers/_configs/__init__.py +0 -0
  127. rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
  128. rasa/shared/providers/_configs/client_config.py +57 -0
  129. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  130. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  131. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  132. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
  133. rasa/shared/providers/_configs/utils.py +101 -0
  134. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  135. rasa/shared/providers/embedding/__init__.py +0 -0
  136. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
  137. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  138. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  139. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  140. rasa/shared/providers/embedding/embedding_client.py +90 -0
  141. rasa/shared/providers/embedding/embedding_response.py +41 -0
  142. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  143. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  144. rasa/shared/providers/llm/__init__.py +0 -0
  145. rasa/shared/providers/llm/_base_litellm_client.py +251 -0
  146. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  147. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  148. rasa/shared/providers/llm/llm_client.py +76 -0
  149. rasa/shared/providers/llm/llm_response.py +50 -0
  150. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  151. rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
  152. rasa/shared/providers/mappings.py +75 -0
  153. rasa/shared/utils/cli.py +30 -0
  154. rasa/shared/utils/io.py +65 -2
  155. rasa/shared/utils/llm.py +246 -200
  156. rasa/shared/utils/yaml.py +121 -15
  157. rasa/studio/auth.py +6 -4
  158. rasa/studio/config.py +13 -4
  159. rasa/studio/constants.py +1 -0
  160. rasa/studio/data_handler.py +10 -3
  161. rasa/studio/download.py +19 -13
  162. rasa/studio/train.py +2 -3
  163. rasa/studio/upload.py +19 -11
  164. rasa/telemetry.py +113 -58
  165. rasa/tracing/instrumentation/attribute_extractors.py +32 -17
  166. rasa/utils/common.py +18 -19
  167. rasa/utils/endpoints.py +7 -4
  168. rasa/utils/json_utils.py +60 -0
  169. rasa/utils/licensing.py +9 -1
  170. rasa/utils/ml_utils.py +4 -2
  171. rasa/validator.py +213 -3
  172. rasa/version.py +1 -1
  173. rasa_pro-3.10.16.dist-info/METADATA +196 -0
  174. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
  175. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  176. rasa/shared/providers/openai/clients.py +0 -43
  177. rasa/shared/providers/openai/session_handler.py +0 -110
  178. rasa_pro-3.9.18.dist-info/METADATA +0 -563
  179. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  180. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  181. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
  182. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
  183. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
rasa/core/utils.py CHANGED
@@ -1,23 +1,31 @@
1
- import json
2
1
  import logging
3
2
  import os
4
- from decimal import Decimal
5
3
  from pathlib import Path
6
- from typing import Any, Dict, Optional, Set, Text, Tuple, Union
4
+ from socket import SOCK_DGRAM, SOCK_STREAM
5
+ from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Text, Tuple, Union
7
6
 
8
7
  import numpy as np
8
+ from sanic import Sanic
9
9
 
10
+ import rasa.cli.utils as cli_utils
10
11
  import rasa.shared.utils.io
11
12
  from rasa.constants import DEFAULT_SANIC_WORKERS, ENV_SANIC_WORKERS
12
- from rasa.shared.constants import DEFAULT_ENDPOINTS_PATH, TCP_PROTOCOL
13
-
13
+ from rasa.core.constants import (
14
+ DOMAIN_GROUND_TRUTH_METADATA_KEY,
15
+ UTTER_SOURCE_METADATA_KEY,
16
+ ACTIVE_FLOW_METADATA_KEY,
17
+ STEP_ID_METADATA_KEY,
18
+ )
14
19
  from rasa.core.lock_store import LockStore, RedisLockStore, InMemoryLockStore
20
+ from rasa.shared.constants import DEFAULT_ENDPOINTS_PATH, TCP_PROTOCOL
21
+ from rasa.shared.core.trackers import DialogueStateTracker
15
22
  from rasa.utils.endpoints import EndpointConfig, read_endpoint_config
16
- from sanic import Sanic
17
- from socket import SOCK_DGRAM, SOCK_STREAM
18
- import rasa.cli.utils as cli_utils
19
23
  from rasa.utils.io import write_yaml
20
24
 
25
+ if TYPE_CHECKING:
26
+ from rasa.core.nlg import NaturalLanguageGenerator
27
+ from rasa.shared.core.domain import Domain
28
+
21
29
  logger = logging.getLogger(__name__)
22
30
 
23
31
 
@@ -163,6 +171,8 @@ def is_limit_reached(num_messages: int, limit: Optional[int]) -> bool:
163
171
  class AvailableEndpoints:
164
172
  """Collection of configured endpoints."""
165
173
 
174
+ _instance = None
175
+
166
176
  @classmethod
167
177
  def read_endpoints(cls, endpoint_file: Text) -> "AvailableEndpoints":
168
178
  """Read the different endpoints from a yaml file."""
@@ -209,6 +219,14 @@ class AvailableEndpoints:
209
219
  self.event_broker = event_broker
210
220
  self.vector_store = vector_store
211
221
 
222
+ @classmethod
223
+ def get_instance(cls, endpoint_file: Optional[Text] = None) -> "AvailableEndpoints":
224
+ """Get the singleton instance of AvailableEndpoints."""
225
+ # Ensure that the instance is initialized only once.
226
+ if cls._instance is None:
227
+ cls._instance = cls.read_endpoints(endpoint_file)
228
+ return cls._instance
229
+
212
230
 
213
231
  def read_endpoints_from_path(
214
232
  endpoints_path: Optional[Union[Path, Text]] = None,
@@ -226,55 +244,7 @@ def read_endpoints_from_path(
226
244
  endpoints_config_path = cli_utils.get_validated_path(
227
245
  endpoints_path, "endpoints", DEFAULT_ENDPOINTS_PATH, True
228
246
  )
229
- return AvailableEndpoints.read_endpoints(endpoints_config_path)
230
-
231
-
232
- def replace_floats_with_decimals(obj: Any, round_digits: int = 9) -> Any:
233
- """Convert all instances in `obj` of `float` to `Decimal`.
234
-
235
- Args:
236
- obj: Input object.
237
- round_digits: Rounding precision of `Decimal` values.
238
-
239
- Returns:
240
- Input `obj` with all `float` types replaced by `Decimal`s rounded to
241
- `round_digits` decimal places.
242
- """
243
-
244
- def _float_to_rounded_decimal(s: Text) -> Decimal:
245
- return Decimal(s).quantize(Decimal(10) ** -round_digits)
246
-
247
- return json.loads(json.dumps(obj), parse_float=_float_to_rounded_decimal)
248
-
249
-
250
- class DecimalEncoder(json.JSONEncoder):
251
- """`json.JSONEncoder` that dumps `Decimal`s as `float`s."""
252
-
253
- def default(self, obj: Any) -> Any:
254
- """Get serializable object for `o`.
255
-
256
- Args:
257
- obj: Object to serialize.
258
-
259
- Returns:
260
- `obj` converted to `float` if `o` is a `Decimals`, else the base class
261
- `default()` method.
262
- """
263
- if isinstance(obj, Decimal):
264
- return float(obj)
265
- return super().default(obj)
266
-
267
-
268
- def replace_decimals_with_floats(obj: Any) -> Any:
269
- """Convert all instances in `obj` of `Decimal` to `float`.
270
-
271
- Args:
272
- obj: A `List` or `Dict` object.
273
-
274
- Returns:
275
- Input `obj` with all `Decimal` types replaced by `float`s.
276
- """
277
- return json.loads(json.dumps(obj, cls=DecimalEncoder))
247
+ return AvailableEndpoints.get_instance(endpoints_config_path)
278
248
 
279
249
 
280
250
  def _lock_store_is_multi_worker_compatible(
@@ -337,3 +307,32 @@ def number_of_sanic_workers(lock_store: Union[EndpointConfig, LockStore, None])
337
307
  f"configuration has been found."
338
308
  )
339
309
  return _log_and_get_default_number_of_workers()
310
+
311
+
312
+ def add_bot_utterance_metadata(
313
+ message: Dict[str, Any],
314
+ domain_response_name: str,
315
+ nlg: "NaturalLanguageGenerator",
316
+ domain: "Domain",
317
+ tracker: Optional[DialogueStateTracker],
318
+ ) -> Dict[str, Any]:
319
+ """Add metadata to the bot message."""
320
+ message["utter_action"] = domain_response_name
321
+
322
+ utter_source = message.get(UTTER_SOURCE_METADATA_KEY)
323
+ if utter_source is None:
324
+ utter_source = nlg.__class__.__name__
325
+ message[UTTER_SOURCE_METADATA_KEY] = utter_source
326
+
327
+ if tracker:
328
+ message[ACTIVE_FLOW_METADATA_KEY] = tracker.active_flow
329
+ message[STEP_ID_METADATA_KEY] = tracker.current_step_id
330
+
331
+ if utter_source in ["IntentlessPolicy", "ContextualResponseRephraser"]:
332
+ message[DOMAIN_GROUND_TRUTH_METADATA_KEY] = [
333
+ response.get("text")
334
+ for response in domain.responses.get(domain_response_name, [])
335
+ if response.get("text") is not None
336
+ ]
337
+
338
+ return message
@@ -1,10 +1,12 @@
1
1
  from __future__ import annotations
2
+
2
3
  import importlib
3
4
  from typing import Any, Dict, List, Optional
4
5
 
5
6
  import structlog
6
7
  from jinja2 import Template
7
8
 
9
+ import rasa.shared.utils.io
8
10
  from rasa.dialogue_understanding.coexistence.constants import (
9
11
  CALM_ENTRY,
10
12
  NLU_ENTRY,
@@ -18,25 +20,32 @@ from rasa.engine.graph import ExecutionContext, GraphComponent
18
20
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
19
21
  from rasa.engine.storage.resource import Resource
20
22
  from rasa.engine.storage.storage import ModelStorage
21
- from rasa.shared.constants import ROUTE_TO_CALM_SLOT
23
+ from rasa.shared.constants import (
24
+ ROUTE_TO_CALM_SLOT,
25
+ PROMPT_CONFIG_KEY,
26
+ PROVIDER_CONFIG_KEY,
27
+ MODEL_CONFIG_KEY,
28
+ OPENAI_PROVIDER,
29
+ TIMEOUT_CONFIG_KEY,
30
+ )
22
31
  from rasa.shared.core.trackers import DialogueStateTracker
23
32
  from rasa.shared.exceptions import InvalidConfigException, FileIOException
24
33
  from rasa.shared.nlu.constants import COMMANDS, TEXT
25
34
  from rasa.shared.nlu.training_data.message import Message
26
35
  from rasa.shared.nlu.training_data.training_data import TrainingData
27
- import rasa.shared.utils.io
28
36
  from rasa.shared.utils.llm import (
29
37
  DEFAULT_OPENAI_CHAT_MODEL_NAME,
30
38
  get_prompt_template,
31
39
  llm_factory,
40
+ try_instantiate_llm_client,
32
41
  )
42
+ from rasa.utils.log_utils import log_llm
33
43
 
34
44
  LLM_BASED_ROUTER_PROMPT_FILE_NAME = "llm_based_router_prompt.jinja2"
35
45
  DEFAULT_COMMAND_PROMPT_TEMPLATE = importlib.resources.read_text(
36
46
  "rasa.dialogue_understanding.coexistence", "router_template.jinja2"
37
47
  )
38
48
 
39
-
40
49
  # Token ids for gpt 3.5 and gpt 4 corresponding to space + capitalized Letter
41
50
  A_TO_C_TOKEN_IDS_CHATGPT = [
42
51
  362, # " A"
@@ -45,10 +54,10 @@ A_TO_C_TOKEN_IDS_CHATGPT = [
45
54
  ]
46
55
 
47
56
  DEFAULT_LLM_CONFIG = {
48
- "_type": "openai",
49
- "request_timeout": 7,
57
+ PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
58
+ MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
59
+ TIMEOUT_CONFIG_KEY: 7,
50
60
  "temperature": 0.0,
51
- "model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME,
52
61
  "max_tokens": 1,
53
62
  "logit_bias": {str(token_id): 100 for token_id in A_TO_C_TOKEN_IDS_CHATGPT},
54
63
  }
@@ -67,7 +76,7 @@ class LLMBasedRouter(GraphComponent):
67
76
  def get_default_config() -> Dict[str, Any]:
68
77
  """The component's default config (see parent class for full docstring)."""
69
78
  return {
70
- "prompt": None,
79
+ PROMPT_CONFIG_KEY: None,
71
80
  CALM_ENTRY: {STICKY: None},
72
81
  NLU_ENTRY: {
73
82
  NON_STICKY: "handles chitchat",
@@ -88,7 +97,7 @@ class LLMBasedRouter(GraphComponent):
88
97
  self.prompt_template = (
89
98
  prompt_template
90
99
  or get_prompt_template(
91
- config.get("prompt"),
100
+ config.get(PROMPT_CONFIG_KEY),
92
101
  DEFAULT_COMMAND_PROMPT_TEMPLATE,
93
102
  ).strip()
94
103
  )
@@ -120,6 +129,14 @@ class LLMBasedRouter(GraphComponent):
120
129
 
121
130
  def train(self, training_data: TrainingData) -> Resource:
122
131
  """Train the intent classifier on a data set."""
132
+ # Validate llm configuration
133
+ try_instantiate_llm_client(
134
+ self.config.get(LLM_CONFIG_KEY),
135
+ DEFAULT_LLM_CONFIG,
136
+ "llm_based_router.train",
137
+ "LLMBasedRouter",
138
+ )
139
+
123
140
  self.persist()
124
141
  return self._resource
125
142
 
@@ -144,7 +161,14 @@ class LLMBasedRouter(GraphComponent):
144
161
  "llm_based_router.load.failed", error=e, resource=resource.name
145
162
  )
146
163
 
147
- return cls(config, model_storage, resource, prompt_template=prompt_template)
164
+ router = cls(config, model_storage, resource, prompt_template=prompt_template)
165
+ try_instantiate_llm_client(
166
+ router.config.get(LLM_CONFIG_KEY),
167
+ DEFAULT_LLM_CONFIG,
168
+ "llm_based_router.load",
169
+ LLMBasedRouter.__name__,
170
+ )
171
+ return router
148
172
 
149
173
  @classmethod
150
174
  def create(
@@ -188,12 +212,27 @@ class LLMBasedRouter(GraphComponent):
188
212
  route_session_to_calm = tracker.get_slot(ROUTE_TO_CALM_SLOT)
189
213
  if route_session_to_calm is None:
190
214
  prompt = self.render_template(message)
191
- structlogger.info("llm_based_router.prompt_rendered", prompt=prompt)
215
+ log_llm(
216
+ logger=structlogger,
217
+ log_module="LLMBasedRouter",
218
+ log_event="llm_based_router.prompt_rendered",
219
+ prompt=prompt,
220
+ )
192
221
  # generating answer
193
222
  answer = await self._generate_answer_using_llm(prompt)
194
- structlogger.info("llm_based_router.llm_answer", answer=answer)
223
+ log_llm(
224
+ logger=structlogger,
225
+ log_module="LLMBasedRouter",
226
+ log_event="llm_based_router.llm_answer",
227
+ answer=answer,
228
+ )
195
229
  commands = self.parse_answer(answer)
196
- structlogger.info("llm_based_router.predicated_commands", commands=commands)
230
+ log_llm(
231
+ logger=structlogger,
232
+ log_module="LLMBasedRouter",
233
+ log_event="llm_based_router.final_commands",
234
+ commands=commands,
235
+ )
197
236
  return commands
198
237
  elif route_session_to_calm is True:
199
238
  # don't set any commands so that a `LLMBasedCommandGenerator` is triggered
@@ -252,7 +291,8 @@ class LLMBasedRouter(GraphComponent):
252
291
  llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
253
292
 
254
293
  try:
255
- return await llm.apredict(prompt)
294
+ llm_response = await llm.acompletion(prompt)
295
+ return llm_response.choices[0]
256
296
  except Exception as e:
257
297
  # unfortunately, langchain does not wrap LLM exceptions which means
258
298
  # we have to catch all exceptions here
@@ -9,6 +9,7 @@ from rasa.dialogue_understanding.commands.knowledge_answer_command import (
9
9
  from rasa.dialogue_understanding.commands.chit_chat_answer_command import (
10
10
  ChitChatAnswerCommand,
11
11
  )
12
+ from rasa.dialogue_understanding.commands.restart_command import RestartCommand
12
13
  from rasa.dialogue_understanding.commands.skip_question_command import (
13
14
  SkipQuestionCommand,
14
15
  )
@@ -28,6 +29,9 @@ from rasa.dialogue_understanding.commands.correct_slots_command import (
28
29
  )
29
30
  from rasa.dialogue_understanding.commands.noop_command import NoopCommand
30
31
  from rasa.dialogue_understanding.commands.change_flow_command import ChangeFlowCommand
32
+ from rasa.dialogue_understanding.commands.session_start_command import (
33
+ SessionStartCommand,
34
+ )
31
35
 
32
36
  __all__ = [
33
37
  "Command",
@@ -46,4 +50,6 @@ __all__ = [
46
50
  "ErrorCommand",
47
51
  "NoopCommand",
48
52
  "ChangeFlowCommand",
53
+ "SessionStartCommand",
54
+ "RestartCommand",
49
55
  ]
@@ -0,0 +1,58 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List
5
+
6
+ from rasa.dialogue_understanding.commands import Command
7
+ from rasa.dialogue_understanding.patterns.restart import RestartPatternFlowStackFrame
8
+ from rasa.shared.core.events import Event
9
+ from rasa.shared.core.flows import FlowsList
10
+ from rasa.shared.core.trackers import DialogueStateTracker
11
+
12
+
13
+ @dataclass
14
+ class RestartCommand(Command):
15
+ """A command to restart a session."""
16
+
17
+ @classmethod
18
+ def command(cls) -> str:
19
+ """Returns the command type."""
20
+ return "restart"
21
+
22
+ @classmethod
23
+ def from_dict(cls, data: Dict[str, Any]) -> RestartCommand:
24
+ """Converts the dictionary to a command.
25
+
26
+ Returns:
27
+ The converted dictionary.
28
+ """
29
+ return RestartCommand()
30
+
31
+ def run_command_on_tracker(
32
+ self,
33
+ tracker: DialogueStateTracker,
34
+ all_flows: FlowsList,
35
+ original_tracker: DialogueStateTracker,
36
+ ) -> List[Event]:
37
+ """Runs the command on the tracker.
38
+
39
+ Args:
40
+ tracker: The tracker to run the command on.
41
+ all_flows: All flows in the assistant.
42
+ original_tracker: The tracker before any command was executed.
43
+
44
+ Returns:
45
+ The events to apply to the tracker.
46
+ """
47
+ stack = tracker.stack
48
+ stack.push(RestartPatternFlowStackFrame())
49
+ return tracker.create_stack_updated_events(stack)
50
+
51
+ def __hash__(self) -> int:
52
+ return hash(self.command())
53
+
54
+ def __eq__(self, other: object) -> bool:
55
+ if not isinstance(other, RestartCommand):
56
+ return False
57
+
58
+ return True
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List
5
+ from rasa.dialogue_understanding.commands import Command
6
+ from rasa.dialogue_understanding.patterns.session_start import (
7
+ SessionStartPatternFlowStackFrame,
8
+ )
9
+ from rasa.shared.core.events import Event
10
+ from rasa.shared.core.flows import FlowsList
11
+ from rasa.shared.core.trackers import DialogueStateTracker
12
+
13
+
14
+ @dataclass
15
+ class SessionStartCommand(Command):
16
+ """A command to indicate the start of a session."""
17
+
18
+ @classmethod
19
+ def command(cls) -> str:
20
+ """Returns the command type."""
21
+ return "session start"
22
+
23
+ @classmethod
24
+ def from_dict(cls, data: Dict[str, Any]) -> SessionStartCommand:
25
+ """Converts the dictionary to a command.
26
+
27
+ Returns:
28
+ The converted dictionary.
29
+ """
30
+ return SessionStartCommand()
31
+
32
+ def run_command_on_tracker(
33
+ self,
34
+ tracker: DialogueStateTracker,
35
+ all_flows: FlowsList,
36
+ original_tracker: DialogueStateTracker,
37
+ ) -> List[Event]:
38
+ """Runs the command on the tracker.
39
+
40
+ Args:
41
+ tracker: The tracker to run the command on.
42
+ all_flows: All flows in the assistant.
43
+ original_tracker: The tracker before any command was executed.
44
+
45
+ Returns:
46
+ The events to apply to the tracker.
47
+ """
48
+ stack = tracker.stack
49
+ stack.push(SessionStartPatternFlowStackFrame())
50
+ return tracker.create_stack_updated_events(stack)
51
+
52
+ def __hash__(self) -> int:
53
+ return hash(self.command())
54
+
55
+ def __eq__(self, other: object) -> bool:
56
+ if not isinstance(other, SessionStartCommand):
57
+ return False
58
+
59
+ return True
@@ -0,0 +1,40 @@
1
+ from typing import Dict, Type
2
+
3
+ from rasa.dialogue_understanding.commands import (
4
+ CancelFlowCommand,
5
+ CannotHandleCommand,
6
+ ChitChatAnswerCommand,
7
+ Command,
8
+ HumanHandoffCommand,
9
+ KnowledgeAnswerCommand,
10
+ SessionStartCommand,
11
+ SkipQuestionCommand,
12
+ RestartCommand,
13
+ )
14
+ from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
15
+ from rasa.dialogue_understanding.patterns.cannot_handle import (
16
+ CannotHandlePatternFlowStackFrame,
17
+ )
18
+ from rasa.dialogue_understanding.patterns.chitchat import ChitchatPatternFlowStackFrame
19
+ from rasa.dialogue_understanding.patterns.human_handoff import (
20
+ HumanHandoffPatternFlowStackFrame,
21
+ )
22
+ from rasa.dialogue_understanding.patterns.restart import RestartPatternFlowStackFrame
23
+ from rasa.dialogue_understanding.patterns.search import SearchPatternFlowStackFrame
24
+ from rasa.dialogue_understanding.patterns.session_start import (
25
+ SessionStartPatternFlowStackFrame,
26
+ )
27
+ from rasa.dialogue_understanding.patterns.skip_question import (
28
+ SkipQuestionPatternFlowStackFrame,
29
+ )
30
+
31
+ triggerable_pattern_to_command_class: Dict[str, Type[Command]] = {
32
+ SessionStartPatternFlowStackFrame.flow_id: SessionStartCommand,
33
+ CancelPatternFlowStackFrame.flow_id: CancelFlowCommand,
34
+ ChitchatPatternFlowStackFrame.flow_id: ChitChatAnswerCommand,
35
+ HumanHandoffPatternFlowStackFrame.flow_id: HumanHandoffCommand,
36
+ SearchPatternFlowStackFrame.flow_id: KnowledgeAnswerCommand,
37
+ SkipQuestionPatternFlowStackFrame.flow_id: SkipQuestionCommand,
38
+ CannotHandlePatternFlowStackFrame.flow_id: CannotHandleCommand,
39
+ RestartPatternFlowStackFrame.flow_id: RestartCommand,
40
+ }
@@ -1,14 +1,20 @@
1
+ from rasa.shared.constants import (
2
+ PROVIDER_CONFIG_KEY,
3
+ OPENAI_PROVIDER,
4
+ MODEL_CONFIG_KEY,
5
+ TIMEOUT_CONFIG_KEY,
6
+ )
1
7
  from rasa.shared.utils.llm import (
2
8
  DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
3
9
  DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
4
10
  )
5
11
 
6
12
  DEFAULT_LLM_CONFIG = {
7
- "_type": "openai",
8
- "request_timeout": 7,
13
+ PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
14
+ MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
9
15
  "temperature": 0.0,
10
- "model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
11
16
  "max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
17
+ TIMEOUT_CONFIG_KEY: 7,
12
18
  }
13
19
 
14
20
  LLM_CONFIG_KEY = "llm"
@@ -16,3 +22,4 @@ USER_INPUT_CONFIG_KEY = "user_input"
16
22
 
17
23
  FLOW_RETRIEVAL_KEY = "flow_retrieval"
18
24
  FLOW_RETRIEVAL_ACTIVE_KEY = "active"
25
+ FLOW_RETRIEVAL_EMBEDDINGS_CONFIG_KEY = "embeddings"
@@ -25,16 +25,24 @@ import structlog
25
25
  from jinja2 import Template
26
26
  from langchain.docstore.document import Document
27
27
  from langchain.schema.embeddings import Embeddings
28
- from langchain.vectorstores.faiss import FAISS
29
- from langchain.vectorstores.utils import DistanceStrategy
28
+ from langchain_community.vectorstores.faiss import FAISS
29
+ from langchain_community.vectorstores.utils import DistanceStrategy
30
30
  from rasa.engine.storage.resource import Resource
31
31
  from rasa.engine.storage.storage import ModelStorage
32
+ from rasa.shared.constants import (
33
+ EMBEDDINGS_CONFIG_KEY,
34
+ PROVIDER_CONFIG_KEY,
35
+ OPENAI_PROVIDER,
36
+ )
32
37
  from rasa.shared.core.domain import Domain
33
38
  from rasa.shared.core.flows import FlowsList
34
39
  from rasa.shared.core.trackers import DialogueStateTracker
35
40
  from rasa.shared.nlu.constants import TEXT, FLOWS_FROM_SEMANTIC_SEARCH
36
41
  from rasa.shared.nlu.training_data.message import Message
37
42
  from rasa.shared.exceptions import ProviderClientAPIException
43
+ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
44
+ _LangchainEmbeddingClientAdapter,
45
+ )
38
46
  from rasa.shared.utils.llm import (
39
47
  tracker_as_readable_transcript,
40
48
  embedder_factory,
@@ -42,15 +50,15 @@ from rasa.shared.utils.llm import (
42
50
  USER,
43
51
  get_prompt_template,
44
52
  allowed_values_for_slot,
53
+ try_instantiate_embedder,
45
54
  )
46
55
 
47
56
  DEFAULT_FLOW_DOCUMENT_TEMPLATE = importlib.resources.read_text(
48
57
  "rasa.dialogue_understanding.generator", "flow_document_template.jinja2"
49
58
  )
50
59
 
51
- EMBEDDINGS_CONFIG_KEY = "embeddings"
52
60
  DEFAULT_EMBEDDINGS_CONFIG = {
53
- "_type": "openai",
61
+ PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
54
62
  "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
55
63
  }
56
64
 
@@ -135,6 +143,12 @@ class FlowRetrieval:
135
143
  """Load flow retrieval with previously populated FAISS vector store."""
136
144
  # initialize base flow retrieval
137
145
  flow_retrieval = FlowRetrieval(config, model_storage, resource)
146
+ try_instantiate_embedder(
147
+ flow_retrieval.config.get(EMBEDDINGS_CONFIG_KEY),
148
+ DEFAULT_EMBEDDINGS_CONFIG,
149
+ "flow_retrieval.load",
150
+ FlowRetrieval.__name__,
151
+ )
138
152
  # load vector store
139
153
  vector_store = cls._load_vector_store(
140
154
  flow_retrieval.config, model_storage, resource
@@ -154,6 +168,7 @@ class FlowRetrieval:
154
168
  folder_path=model_path,
155
169
  embeddings=embeddings,
156
170
  distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
171
+ allow_dangerous_deserialization=True,
157
172
  )
158
173
  except Exception as e:
159
174
  structlogger.warning(
@@ -170,9 +185,10 @@ class FlowRetrieval:
170
185
  Returns:
171
186
  The embedder.
172
187
  """
173
- return embedder_factory(
188
+ client = embedder_factory(
174
189
  config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
175
190
  )
191
+ return _LangchainEmbeddingClientAdapter(client)
176
192
 
177
193
  def persist(self) -> None:
178
194
  self._persist_vector_store()
@@ -32,8 +32,9 @@ from rasa.shared.nlu.constants import FLOWS_IN_PROMPT
32
32
  from rasa.shared.nlu.training_data.message import Message
33
33
  from rasa.shared.nlu.training_data.training_data import TrainingData
34
34
  from rasa.shared.utils.llm import (
35
- llm_factory,
36
35
  allowed_values_for_slot,
36
+ llm_factory,
37
+ try_instantiate_llm_client,
37
38
  )
38
39
  from rasa.utils.log_utils import log_llm
39
40
 
@@ -167,6 +168,14 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
167
168
  """Train the llm based command generator. Stores all flows into a vector
168
169
  store.
169
170
  """
171
+ # Validate llm configuration
172
+ try_instantiate_llm_client(
173
+ self.config.get(LLM_CONFIG_KEY),
174
+ DEFAULT_LLM_CONFIG,
175
+ "llm_based_command_generator.train",
176
+ "LLMBasedCommandGenerator",
177
+ )
178
+
170
179
  # flow retrieval is populated with only user-defined flows
171
180
  try:
172
181
  if self.flow_retrieval is not None and not flows.is_empty():
@@ -174,7 +183,7 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
174
183
  except Exception as e:
175
184
  structlogger.error(
176
185
  "llm_based_command_generator.train.failed",
177
- event_info=("Flow retrieval store isinaccessible."),
186
+ event_info="Flow retrieval store is inaccessible.",
178
187
  error=e,
179
188
  )
180
189
  raise
@@ -286,7 +295,8 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
286
295
  """
287
296
  llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
288
297
  try:
289
- return await llm.apredict(prompt)
298
+ llm_response = await llm.acompletion(prompt)
299
+ return llm_response.choices[0]
290
300
  except Exception as e:
291
301
  # unfortunately, langchain does not wrap LLM exceptions which means
292
302
  # we have to catch all exceptions here