rasa-pro 3.9.18__py3-none-any.whl → 3.10.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (189)
  1. README.md +26 -57
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +26 -22
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +2 -0
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +108 -433
  12. rasa/cli/interactive.py +1 -0
  13. rasa/cli/llm_fine_tuning.py +395 -0
  14. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  15. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  16. rasa/cli/run.py +14 -13
  17. rasa/cli/scaffold.py +10 -8
  18. rasa/cli/train.py +8 -7
  19. rasa/cli/utils.py +15 -0
  20. rasa/constants.py +7 -1
  21. rasa/core/actions/action.py +98 -49
  22. rasa/core/actions/action_run_slot_rejections.py +4 -1
  23. rasa/core/actions/custom_action_executor.py +9 -6
  24. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  25. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  26. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  27. rasa/core/actions/http_custom_action_executor.py +6 -5
  28. rasa/core/agent.py +21 -17
  29. rasa/core/channels/__init__.py +2 -0
  30. rasa/core/channels/audiocodes.py +1 -16
  31. rasa/core/channels/inspector/dist/index.html +0 -2
  32. rasa/core/channels/inspector/index.html +0 -2
  33. rasa/core/channels/voice_aware/__init__.py +0 -0
  34. rasa/core/channels/voice_aware/jambonz.py +103 -0
  35. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  36. rasa/core/channels/voice_aware/utils.py +20 -0
  37. rasa/core/channels/voice_native/__init__.py +0 -0
  38. rasa/core/constants.py +6 -1
  39. rasa/core/featurizers/single_state_featurizer.py +1 -22
  40. rasa/core/featurizers/tracker_featurizers.py +18 -115
  41. rasa/core/information_retrieval/faiss.py +7 -4
  42. rasa/core/information_retrieval/information_retrieval.py +8 -0
  43. rasa/core/information_retrieval/milvus.py +9 -2
  44. rasa/core/information_retrieval/qdrant.py +1 -1
  45. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  46. rasa/core/nlg/summarize.py +4 -3
  47. rasa/core/policies/enterprise_search_policy.py +100 -44
  48. rasa/core/policies/flows/flow_executor.py +130 -94
  49. rasa/core/policies/intentless_policy.py +52 -28
  50. rasa/core/policies/ted_policy.py +33 -58
  51. rasa/core/policies/unexpected_intent_policy.py +7 -15
  52. rasa/core/processor.py +20 -53
  53. rasa/core/run.py +5 -4
  54. rasa/core/tracker_store.py +8 -4
  55. rasa/core/utils.py +45 -56
  56. rasa/dialogue_understanding/coexistence/llm_based_router.py +45 -12
  57. rasa/dialogue_understanding/commands/__init__.py +4 -0
  58. rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
  59. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  60. rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
  61. rasa/dialogue_understanding/commands/utils.py +38 -0
  62. rasa/dialogue_understanding/generator/constants.py +10 -3
  63. rasa/dialogue_understanding/generator/flow_retrieval.py +14 -5
  64. rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -2
  65. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +106 -87
  66. rasa/dialogue_understanding/generator/nlu_command_adapter.py +28 -6
  67. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +90 -37
  68. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
  69. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  70. rasa/dialogue_understanding/processor/command_processor.py +13 -14
  71. rasa/e2e_test/aggregate_test_stats_calculator.py +124 -0
  72. rasa/e2e_test/assertions.py +1181 -0
  73. rasa/e2e_test/assertions_schema.yml +106 -0
  74. rasa/e2e_test/constants.py +20 -0
  75. rasa/e2e_test/e2e_config.py +220 -0
  76. rasa/e2e_test/e2e_config_schema.yml +26 -0
  77. rasa/e2e_test/e2e_test_case.py +131 -8
  78. rasa/e2e_test/e2e_test_converter.py +363 -0
  79. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  80. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  81. rasa/e2e_test/e2e_test_result.py +26 -6
  82. rasa/e2e_test/e2e_test_runner.py +491 -72
  83. rasa/e2e_test/e2e_test_schema.yml +96 -0
  84. rasa/e2e_test/pykwalify_extensions.py +39 -0
  85. rasa/e2e_test/stub_custom_action.py +70 -0
  86. rasa/e2e_test/utils/__init__.py +0 -0
  87. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  88. rasa/e2e_test/utils/io.py +596 -0
  89. rasa/e2e_test/utils/validation.py +80 -0
  90. rasa/engine/recipes/default_components.py +0 -2
  91. rasa/engine/storage/local_model_storage.py +0 -1
  92. rasa/env.py +9 -0
  93. rasa/llm_fine_tuning/__init__.py +0 -0
  94. rasa/llm_fine_tuning/annotation_module.py +241 -0
  95. rasa/llm_fine_tuning/conversations.py +144 -0
  96. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  97. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  98. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  99. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  100. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  101. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  102. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  103. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  104. rasa/llm_fine_tuning/storage.py +174 -0
  105. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  106. rasa/model_training.py +48 -16
  107. rasa/nlu/classifiers/diet_classifier.py +25 -38
  108. rasa/nlu/classifiers/logistic_regression_classifier.py +9 -44
  109. rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
  110. rasa/nlu/extractors/crf_entity_extractor.py +50 -93
  111. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -78
  112. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
  113. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
  114. rasa/nlu/persistor.py +129 -32
  115. rasa/server.py +45 -10
  116. rasa/shared/constants.py +63 -15
  117. rasa/shared/core/domain.py +15 -12
  118. rasa/shared/core/events.py +28 -2
  119. rasa/shared/core/flows/flow.py +208 -13
  120. rasa/shared/core/flows/flow_path.py +84 -0
  121. rasa/shared/core/flows/flows_list.py +28 -10
  122. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  123. rasa/shared/core/flows/validation.py +112 -25
  124. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  125. rasa/shared/core/trackers.py +6 -0
  126. rasa/shared/core/training_data/visualization.html +2 -2
  127. rasa/shared/exceptions.py +4 -0
  128. rasa/shared/importers/importer.py +60 -11
  129. rasa/shared/importers/remote_importer.py +196 -0
  130. rasa/shared/nlu/constants.py +2 -0
  131. rasa/shared/nlu/training_data/features.py +2 -120
  132. rasa/shared/providers/_configs/__init__.py +0 -0
  133. rasa/shared/providers/_configs/azure_openai_client_config.py +181 -0
  134. rasa/shared/providers/_configs/client_config.py +57 -0
  135. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  136. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  137. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  138. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +171 -0
  139. rasa/shared/providers/_configs/utils.py +101 -0
  140. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  141. rasa/shared/providers/embedding/__init__.py +0 -0
  142. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +254 -0
  143. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  144. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  145. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  146. rasa/shared/providers/embedding/embedding_client.py +90 -0
  147. rasa/shared/providers/embedding/embedding_response.py +41 -0
  148. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  149. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  150. rasa/shared/providers/llm/__init__.py +0 -0
  151. rasa/shared/providers/llm/_base_litellm_client.py +227 -0
  152. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  153. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  154. rasa/shared/providers/llm/llm_client.py +76 -0
  155. rasa/shared/providers/llm/llm_response.py +50 -0
  156. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  157. rasa/shared/providers/llm/self_hosted_llm_client.py +169 -0
  158. rasa/shared/providers/mappings.py +75 -0
  159. rasa/shared/utils/cli.py +30 -0
  160. rasa/shared/utils/io.py +65 -3
  161. rasa/shared/utils/llm.py +223 -200
  162. rasa/shared/utils/yaml.py +122 -7
  163. rasa/studio/download.py +19 -13
  164. rasa/studio/train.py +2 -3
  165. rasa/studio/upload.py +2 -3
  166. rasa/telemetry.py +113 -58
  167. rasa/tracing/config.py +2 -3
  168. rasa/tracing/instrumentation/attribute_extractors.py +29 -17
  169. rasa/tracing/instrumentation/instrumentation.py +4 -47
  170. rasa/utils/common.py +18 -19
  171. rasa/utils/endpoints.py +7 -4
  172. rasa/utils/io.py +66 -0
  173. rasa/utils/json_utils.py +60 -0
  174. rasa/utils/licensing.py +9 -1
  175. rasa/utils/ml_utils.py +4 -2
  176. rasa/utils/tensorflow/model_data.py +193 -2
  177. rasa/validator.py +195 -1
  178. rasa/version.py +1 -1
  179. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/METADATA +47 -72
  180. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/RECORD +185 -121
  181. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  182. rasa/shared/providers/openai/clients.py +0 -43
  183. rasa/shared/providers/openai/session_handler.py +0 -110
  184. rasa/utils/tensorflow/feature_array.py +0 -366
  185. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  186. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  187. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/NOTICE +0 -0
  188. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/WHEEL +0 -0
  189. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/entry_points.txt +0 -0
rasa/core/utils.py CHANGED
@@ -1,23 +1,31 @@
1
- import json
2
1
  import logging
3
2
  import os
4
- from decimal import Decimal
5
3
  from pathlib import Path
6
- from typing import Any, Dict, Optional, Set, Text, Tuple, Union
4
+ from socket import SOCK_DGRAM, SOCK_STREAM
5
+ from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Text, Tuple, Union
7
6
 
8
7
  import numpy as np
8
+ from sanic import Sanic
9
9
 
10
+ import rasa.cli.utils as cli_utils
10
11
  import rasa.shared.utils.io
11
12
  from rasa.constants import DEFAULT_SANIC_WORKERS, ENV_SANIC_WORKERS
12
- from rasa.shared.constants import DEFAULT_ENDPOINTS_PATH, TCP_PROTOCOL
13
-
13
+ from rasa.core.constants import (
14
+ DOMAIN_GROUND_TRUTH_METADATA_KEY,
15
+ UTTER_SOURCE_METADATA_KEY,
16
+ ACTIVE_FLOW_METADATA_KEY,
17
+ STEP_ID_METADATA_KEY,
18
+ )
14
19
  from rasa.core.lock_store import LockStore, RedisLockStore, InMemoryLockStore
20
+ from rasa.shared.constants import DEFAULT_ENDPOINTS_PATH, TCP_PROTOCOL
21
+ from rasa.shared.core.trackers import DialogueStateTracker
15
22
  from rasa.utils.endpoints import EndpointConfig, read_endpoint_config
16
- from sanic import Sanic
17
- from socket import SOCK_DGRAM, SOCK_STREAM
18
- import rasa.cli.utils as cli_utils
19
23
  from rasa.utils.io import write_yaml
20
24
 
25
+ if TYPE_CHECKING:
26
+ from rasa.core.nlg import NaturalLanguageGenerator
27
+ from rasa.shared.core.domain import Domain
28
+
21
29
  logger = logging.getLogger(__name__)
22
30
 
23
31
 
@@ -229,54 +237,6 @@ def read_endpoints_from_path(
229
237
  return AvailableEndpoints.read_endpoints(endpoints_config_path)
230
238
 
231
239
 
232
- def replace_floats_with_decimals(obj: Any, round_digits: int = 9) -> Any:
233
- """Convert all instances in `obj` of `float` to `Decimal`.
234
-
235
- Args:
236
- obj: Input object.
237
- round_digits: Rounding precision of `Decimal` values.
238
-
239
- Returns:
240
- Input `obj` with all `float` types replaced by `Decimal`s rounded to
241
- `round_digits` decimal places.
242
- """
243
-
244
- def _float_to_rounded_decimal(s: Text) -> Decimal:
245
- return Decimal(s).quantize(Decimal(10) ** -round_digits)
246
-
247
- return json.loads(json.dumps(obj), parse_float=_float_to_rounded_decimal)
248
-
249
-
250
- class DecimalEncoder(json.JSONEncoder):
251
- """`json.JSONEncoder` that dumps `Decimal`s as `float`s."""
252
-
253
- def default(self, obj: Any) -> Any:
254
- """Get serializable object for `o`.
255
-
256
- Args:
257
- obj: Object to serialize.
258
-
259
- Returns:
260
- `obj` converted to `float` if `o` is a `Decimals`, else the base class
261
- `default()` method.
262
- """
263
- if isinstance(obj, Decimal):
264
- return float(obj)
265
- return super().default(obj)
266
-
267
-
268
- def replace_decimals_with_floats(obj: Any) -> Any:
269
- """Convert all instances in `obj` of `Decimal` to `float`.
270
-
271
- Args:
272
- obj: A `List` or `Dict` object.
273
-
274
- Returns:
275
- Input `obj` with all `Decimal` types replaced by `float`s.
276
- """
277
- return json.loads(json.dumps(obj, cls=DecimalEncoder))
278
-
279
-
280
240
  def _lock_store_is_multi_worker_compatible(
281
241
  lock_store: Union[EndpointConfig, LockStore, None],
282
242
  ) -> bool:
@@ -337,3 +297,32 @@ def number_of_sanic_workers(lock_store: Union[EndpointConfig, LockStore, None])
337
297
  f"configuration has been found."
338
298
  )
339
299
  return _log_and_get_default_number_of_workers()
300
+
301
+
302
+ def add_bot_utterance_metadata(
303
+ message: Dict[str, Any],
304
+ domain_response_name: str,
305
+ nlg: "NaturalLanguageGenerator",
306
+ domain: "Domain",
307
+ tracker: Optional[DialogueStateTracker],
308
+ ) -> Dict[str, Any]:
309
+ """Add metadata to the bot message."""
310
+ message["utter_action"] = domain_response_name
311
+
312
+ utter_source = message.get(UTTER_SOURCE_METADATA_KEY)
313
+ if utter_source is None:
314
+ utter_source = nlg.__class__.__name__
315
+ message[UTTER_SOURCE_METADATA_KEY] = utter_source
316
+
317
+ if tracker:
318
+ message[ACTIVE_FLOW_METADATA_KEY] = tracker.active_flow
319
+ message[STEP_ID_METADATA_KEY] = tracker.current_step_id
320
+
321
+ if utter_source in ["IntentlessPolicy", "ContextualResponseRephraser"]:
322
+ message[DOMAIN_GROUND_TRUTH_METADATA_KEY] = [
323
+ response.get("text")
324
+ for response in domain.responses.get(domain_response_name, [])
325
+ if response.get("text") is not None
326
+ ]
327
+
328
+ return message
rasa/dialogue_understanding/coexistence/llm_based_router.py CHANGED
@@ -1,10 +1,12 @@
1
1
  from __future__ import annotations
2
+
2
3
  import importlib
3
4
  from typing import Any, Dict, List, Optional
4
5
 
5
6
  import structlog
6
7
  from jinja2 import Template
7
8
 
9
+ import rasa.shared.utils.io
8
10
  from rasa.dialogue_understanding.coexistence.constants import (
9
11
  CALM_ENTRY,
10
12
  NLU_ENTRY,
@@ -18,25 +20,32 @@ from rasa.engine.graph import ExecutionContext, GraphComponent
18
20
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
19
21
  from rasa.engine.storage.resource import Resource
20
22
  from rasa.engine.storage.storage import ModelStorage
21
- from rasa.shared.constants import ROUTE_TO_CALM_SLOT
23
+ from rasa.shared.constants import (
24
+ ROUTE_TO_CALM_SLOT,
25
+ PROMPT_CONFIG_KEY,
26
+ PROVIDER_CONFIG_KEY,
27
+ MODEL_CONFIG_KEY,
28
+ OPENAI_PROVIDER,
29
+ TIMEOUT_CONFIG_KEY,
30
+ )
22
31
  from rasa.shared.core.trackers import DialogueStateTracker
23
32
  from rasa.shared.exceptions import InvalidConfigException, FileIOException
24
33
  from rasa.shared.nlu.constants import COMMANDS, TEXT
25
34
  from rasa.shared.nlu.training_data.message import Message
26
35
  from rasa.shared.nlu.training_data.training_data import TrainingData
27
- import rasa.shared.utils.io
28
36
  from rasa.shared.utils.llm import (
29
37
  DEFAULT_OPENAI_CHAT_MODEL_NAME,
30
38
  get_prompt_template,
31
39
  llm_factory,
40
+ try_instantiate_llm_client,
32
41
  )
42
+ from rasa.utils.log_utils import log_llm
33
43
 
34
44
  LLM_BASED_ROUTER_PROMPT_FILE_NAME = "llm_based_router_prompt.jinja2"
35
45
  DEFAULT_COMMAND_PROMPT_TEMPLATE = importlib.resources.read_text(
36
46
  "rasa.dialogue_understanding.coexistence", "router_template.jinja2"
37
47
  )
38
48
 
39
-
40
49
  # Token ids for gpt 3.5 and gpt 4 corresponding to space + capitalized Letter
41
50
  A_TO_C_TOKEN_IDS_CHATGPT = [
42
51
  362, # " A"
@@ -45,10 +54,10 @@ A_TO_C_TOKEN_IDS_CHATGPT = [
45
54
  ]
46
55
 
47
56
  DEFAULT_LLM_CONFIG = {
48
- "_type": "openai",
49
- "request_timeout": 7,
57
+ PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
58
+ MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
59
+ TIMEOUT_CONFIG_KEY: 7,
50
60
  "temperature": 0.0,
51
- "model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME,
52
61
  "max_tokens": 1,
53
62
  "logit_bias": {str(token_id): 100 for token_id in A_TO_C_TOKEN_IDS_CHATGPT},
54
63
  }
@@ -67,7 +76,7 @@ class LLMBasedRouter(GraphComponent):
67
76
  def get_default_config() -> Dict[str, Any]:
68
77
  """The component's default config (see parent class for full docstring)."""
69
78
  return {
70
- "prompt": None,
79
+ PROMPT_CONFIG_KEY: None,
71
80
  CALM_ENTRY: {STICKY: None},
72
81
  NLU_ENTRY: {
73
82
  NON_STICKY: "handles chitchat",
@@ -88,7 +97,7 @@ class LLMBasedRouter(GraphComponent):
88
97
  self.prompt_template = (
89
98
  prompt_template
90
99
  or get_prompt_template(
91
- config.get("prompt"),
100
+ config.get(PROMPT_CONFIG_KEY),
92
101
  DEFAULT_COMMAND_PROMPT_TEMPLATE,
93
102
  ).strip()
94
103
  )
@@ -120,6 +129,14 @@ class LLMBasedRouter(GraphComponent):
120
129
 
121
130
  def train(self, training_data: TrainingData) -> Resource:
122
131
  """Train the intent classifier on a data set."""
132
+ # Validate llm configuration
133
+ try_instantiate_llm_client(
134
+ self.config.get(LLM_CONFIG_KEY),
135
+ DEFAULT_LLM_CONFIG,
136
+ "llm_based_router.train",
137
+ "LLMBasedRouter",
138
+ )
139
+
123
140
  self.persist()
124
141
  return self._resource
125
142
 
@@ -188,12 +205,27 @@ class LLMBasedRouter(GraphComponent):
188
205
  route_session_to_calm = tracker.get_slot(ROUTE_TO_CALM_SLOT)
189
206
  if route_session_to_calm is None:
190
207
  prompt = self.render_template(message)
191
- structlogger.info("llm_based_router.prompt_rendered", prompt=prompt)
208
+ log_llm(
209
+ logger=structlogger,
210
+ log_module="LLMBasedRouter",
211
+ log_event="llm_based_router.prompt_rendered",
212
+ prompt=prompt,
213
+ )
192
214
  # generating answer
193
215
  answer = await self._generate_answer_using_llm(prompt)
194
- structlogger.info("llm_based_router.llm_answer", answer=answer)
216
+ log_llm(
217
+ logger=structlogger,
218
+ log_module="LLMBasedRouter",
219
+ log_event="llm_based_router.llm_answer",
220
+ answer=answer,
221
+ )
195
222
  commands = self.parse_answer(answer)
196
- structlogger.info("llm_based_router.predicated_commands", commands=commands)
223
+ log_llm(
224
+ logger=structlogger,
225
+ log_module="LLMBasedRouter",
226
+ log_event="llm_based_router.final_commands",
227
+ commands=commands,
228
+ )
197
229
  return commands
198
230
  elif route_session_to_calm is True:
199
231
  # don't set any commands so that a `LLMBasedCommandGenerator` is triggered
@@ -252,7 +284,8 @@ class LLMBasedRouter(GraphComponent):
252
284
  llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
253
285
 
254
286
  try:
255
- return await llm.apredict(prompt)
287
+ llm_response = await llm.acompletion(prompt)
288
+ return llm_response.choices[0]
256
289
  except Exception as e:
257
290
  # unfortunately, langchain does not wrap LLM exceptions which means
258
291
  # we have to catch all exceptions here
rasa/dialogue_understanding/commands/__init__.py CHANGED
@@ -28,6 +28,9 @@ from rasa.dialogue_understanding.commands.correct_slots_command import (
28
28
  )
29
29
  from rasa.dialogue_understanding.commands.noop_command import NoopCommand
30
30
  from rasa.dialogue_understanding.commands.change_flow_command import ChangeFlowCommand
31
+ from rasa.dialogue_understanding.commands.session_start_command import (
32
+ SessionStartCommand,
33
+ )
31
34
 
32
35
  __all__ = [
33
36
  "Command",
@@ -46,4 +49,5 @@ __all__ = [
46
49
  "ErrorCommand",
47
50
  "NoopCommand",
48
51
  "ChangeFlowCommand",
52
+ "SessionStartCommand",
49
53
  ]
rasa/dialogue_understanding/commands/change_flow_command.py CHANGED
@@ -36,9 +36,3 @@ class ChangeFlowCommand(Command):
36
36
  # the change flow command is not actually pushing anything to the tracker,
37
37
  # but it is predicted by the MultiStepLLMCommandGenerator and used internally
38
38
  return []
39
-
40
- def __eq__(self, other: Any) -> bool:
41
- return isinstance(other, ChangeFlowCommand)
42
-
43
- def __hash__(self) -> int:
44
- return hash(self.command())
rasa/dialogue_understanding/commands/session_start_command.py ADDED
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List
5
+ from rasa.dialogue_understanding.commands import Command
6
+ from rasa.dialogue_understanding.patterns.session_start import (
7
+ SessionStartPatternFlowStackFrame,
8
+ )
9
+ from rasa.shared.core.events import Event
10
+ from rasa.shared.core.flows import FlowsList
11
+ from rasa.shared.core.trackers import DialogueStateTracker
12
+
13
+
14
+ @dataclass
15
+ class SessionStartCommand(Command):
16
+ """A command to indicate the start of a session."""
17
+
18
+ @classmethod
19
+ def command(cls) -> str:
20
+ """Returns the command type."""
21
+ return "session start"
22
+
23
+ @classmethod
24
+ def from_dict(cls, data: Dict[str, Any]) -> SessionStartCommand:
25
+ """Converts the dictionary to a command.
26
+
27
+ Returns:
28
+ The converted dictionary.
29
+ """
30
+ return SessionStartCommand()
31
+
32
+ def run_command_on_tracker(
33
+ self,
34
+ tracker: DialogueStateTracker,
35
+ all_flows: FlowsList,
36
+ original_tracker: DialogueStateTracker,
37
+ ) -> List[Event]:
38
+ """Runs the command on the tracker.
39
+
40
+ Args:
41
+ tracker: The tracker to run the command on.
42
+ all_flows: All flows in the assistant.
43
+ original_tracker: The tracker before any command was executed.
44
+
45
+ Returns:
46
+ The events to apply to the tracker.
47
+ """
48
+ stack = tracker.stack
49
+ stack.push(SessionStartPatternFlowStackFrame())
50
+ return tracker.create_stack_updated_events(stack)
51
+
52
+ def __hash__(self) -> int:
53
+ return hash(self.command())
54
+
55
+ def __eq__(self, other: object) -> bool:
56
+ if not isinstance(other, SessionStartCommand):
57
+ return False
58
+
59
+ return True
rasa/dialogue_understanding/commands/set_slot_command.py CHANGED
@@ -127,11 +127,7 @@ class SetSlotCommand(Command):
127
127
  if (
128
128
  self.name not in slots_of_active_flow
129
129
  and self.name != ROUTE_TO_CALM_SLOT
130
- and self.extractor
131
- in {
132
- SetSlotExtractor.LLM.value,
133
- SetSlotExtractor.COMMAND_PAYLOAD_READER.value,
134
- }
130
+ and self.extractor == SetSlotExtractor.LLM.value
135
131
  ):
136
132
  # Get the other predicted flows from the most recent message on the tracker.
137
133
  predicted_flows = get_flows_predicted_to_start_from_tracker(tracker)
rasa/dialogue_understanding/commands/utils.py ADDED
@@ -0,0 +1,38 @@
1
+ from typing import Dict, Type
2
+
3
+ from rasa.dialogue_understanding.commands import (
4
+ CancelFlowCommand,
5
+ CannotHandleCommand,
6
+ ChitChatAnswerCommand,
7
+ Command,
8
+ HumanHandoffCommand,
9
+ KnowledgeAnswerCommand,
10
+ SessionStartCommand,
11
+ SkipQuestionCommand,
12
+ )
13
+ from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
14
+ from rasa.dialogue_understanding.patterns.cannot_handle import (
15
+ CannotHandlePatternFlowStackFrame,
16
+ )
17
+ from rasa.dialogue_understanding.patterns.chitchat import ChitchatPatternFlowStackFrame
18
+ from rasa.dialogue_understanding.patterns.human_handoff import (
19
+ HumanHandoffPatternFlowStackFrame,
20
+ )
21
+ from rasa.dialogue_understanding.patterns.search import SearchPatternFlowStackFrame
22
+ from rasa.dialogue_understanding.patterns.session_start import (
23
+ SessionStartPatternFlowStackFrame,
24
+ )
25
+ from rasa.dialogue_understanding.patterns.skip_question import (
26
+ SkipQuestionPatternFlowStackFrame,
27
+ )
28
+
29
+
30
+ triggerable_pattern_to_command_class: Dict[str, Type[Command]] = {
31
+ SessionStartPatternFlowStackFrame.flow_id: SessionStartCommand,
32
+ CancelPatternFlowStackFrame.flow_id: CancelFlowCommand,
33
+ ChitchatPatternFlowStackFrame.flow_id: ChitChatAnswerCommand,
34
+ HumanHandoffPatternFlowStackFrame.flow_id: HumanHandoffCommand,
35
+ SearchPatternFlowStackFrame.flow_id: KnowledgeAnswerCommand,
36
+ SkipQuestionPatternFlowStackFrame.flow_id: SkipQuestionCommand,
37
+ CannotHandlePatternFlowStackFrame.flow_id: CannotHandleCommand,
38
+ }
rasa/dialogue_understanding/generator/constants.py CHANGED
@@ -1,14 +1,20 @@
1
+ from rasa.shared.constants import (
2
+ PROVIDER_CONFIG_KEY,
3
+ OPENAI_PROVIDER,
4
+ MODEL_CONFIG_KEY,
5
+ TIMEOUT_CONFIG_KEY,
6
+ )
1
7
  from rasa.shared.utils.llm import (
2
8
  DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
3
9
  DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
4
10
  )
5
11
 
6
12
  DEFAULT_LLM_CONFIG = {
7
- "_type": "openai",
8
- "request_timeout": 7,
13
+ PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
14
+ MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
9
15
  "temperature": 0.0,
10
- "model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
11
16
  "max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
17
+ TIMEOUT_CONFIG_KEY: 7,
12
18
  }
13
19
 
14
20
  LLM_CONFIG_KEY = "llm"
@@ -16,3 +22,4 @@ USER_INPUT_CONFIG_KEY = "user_input"
16
22
 
17
23
  FLOW_RETRIEVAL_KEY = "flow_retrieval"
18
24
  FLOW_RETRIEVAL_ACTIVE_KEY = "active"
25
+ FLOW_RETRIEVAL_EMBEDDINGS_CONFIG_KEY = "embeddings"
rasa/dialogue_understanding/generator/flow_retrieval.py CHANGED
@@ -25,16 +25,24 @@ import structlog
25
25
  from jinja2 import Template
26
26
  from langchain.docstore.document import Document
27
27
  from langchain.schema.embeddings import Embeddings
28
- from langchain.vectorstores.faiss import FAISS
29
- from langchain.vectorstores.utils import DistanceStrategy
28
+ from langchain_community.vectorstores.faiss import FAISS
29
+ from langchain_community.vectorstores.utils import DistanceStrategy
30
30
  from rasa.engine.storage.resource import Resource
31
31
  from rasa.engine.storage.storage import ModelStorage
32
+ from rasa.shared.constants import (
33
+ EMBEDDINGS_CONFIG_KEY,
34
+ PROVIDER_CONFIG_KEY,
35
+ OPENAI_PROVIDER,
36
+ )
32
37
  from rasa.shared.core.domain import Domain
33
38
  from rasa.shared.core.flows import FlowsList
34
39
  from rasa.shared.core.trackers import DialogueStateTracker
35
40
  from rasa.shared.nlu.constants import TEXT, FLOWS_FROM_SEMANTIC_SEARCH
36
41
  from rasa.shared.nlu.training_data.message import Message
37
42
  from rasa.shared.exceptions import ProviderClientAPIException
43
+ from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
44
+ _LangchainEmbeddingClientAdapter,
45
+ )
38
46
  from rasa.shared.utils.llm import (
39
47
  tracker_as_readable_transcript,
40
48
  embedder_factory,
@@ -48,9 +56,8 @@ DEFAULT_FLOW_DOCUMENT_TEMPLATE = importlib.resources.read_text(
48
56
  "rasa.dialogue_understanding.generator", "flow_document_template.jinja2"
49
57
  )
50
58
 
51
- EMBEDDINGS_CONFIG_KEY = "embeddings"
52
59
  DEFAULT_EMBEDDINGS_CONFIG = {
53
- "_type": "openai",
60
+ PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
54
61
  "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
55
62
  }
56
63
 
@@ -154,6 +161,7 @@ class FlowRetrieval:
154
161
  folder_path=model_path,
155
162
  embeddings=embeddings,
156
163
  distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
164
+ allow_dangerous_deserialization=True,
157
165
  )
158
166
  except Exception as e:
159
167
  structlogger.warning(
@@ -170,9 +178,10 @@ class FlowRetrieval:
170
178
  Returns:
171
179
  The embedder.
172
180
  """
173
- return embedder_factory(
181
+ client = embedder_factory(
174
182
  config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
175
183
  )
184
+ return _LangchainEmbeddingClientAdapter(client)
176
185
 
177
186
  def persist(self) -> None:
178
187
  self._persist_vector_store()
rasa/dialogue_understanding/generator/llm_based_command_generator.py CHANGED
@@ -32,8 +32,9 @@ from rasa.shared.nlu.constants import FLOWS_IN_PROMPT
32
32
  from rasa.shared.nlu.training_data.message import Message
33
33
  from rasa.shared.nlu.training_data.training_data import TrainingData
34
34
  from rasa.shared.utils.llm import (
35
- llm_factory,
36
35
  allowed_values_for_slot,
36
+ llm_factory,
37
+ try_instantiate_llm_client,
37
38
  )
38
39
  from rasa.utils.log_utils import log_llm
39
40
 
@@ -167,6 +168,14 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
167
168
  """Train the llm based command generator. Stores all flows into a vector
168
169
  store.
169
170
  """
171
+ # Validate llm configuration
172
+ try_instantiate_llm_client(
173
+ self.config.get(LLM_CONFIG_KEY),
174
+ DEFAULT_LLM_CONFIG,
175
+ "llm_based_command_generator.train",
176
+ "LLMBasedCommandGenerator",
177
+ )
178
+
170
179
  # flow retrieval is populated with only user-defined flows
171
180
  try:
172
181
  if self.flow_retrieval is not None and not flows.is_empty():
@@ -286,7 +295,8 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
286
295
  """
287
296
  llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
288
297
  try:
289
- return await llm.apredict(prompt)
298
+ llm_response = await llm.acompletion(prompt)
299
+ return llm_response.choices[0]
290
300
  except Exception as e:
291
301
  # unfortunately, langchain does not wrap LLM exceptions which means
292
302
  # we have to catch all exceptions here