rasa-pro 3.13.0.dev2__py3-none-any.whl → 3.13.0.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (173)
  1. rasa/__main__.py +3 -1
  2. rasa/cli/inspect.py +8 -4
  3. rasa/cli/project_templates/default/config.yml +5 -32
  4. rasa/cli/project_templates/{calm → default}/e2e_tests/cancelations/user_cancels_during_a_correction.yml +1 -1
  5. rasa/cli/project_templates/{calm → default}/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +1 -1
  6. rasa/cli/project_templates/{calm → default}/e2e_tests/corrections/user_corrects_contact_handle.yml +1 -1
  7. rasa/cli/project_templates/{calm → default}/e2e_tests/corrections/user_corrects_contact_name.yml +1 -1
  8. rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +1 -1
  9. rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_lists_contacts.yml +1 -1
  10. rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_removes_contact.yml +1 -1
  11. rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_removes_contact_from_list.yml +1 -1
  12. rasa/cli/project_templates/default/endpoints.yml +18 -2
  13. rasa/cli/run.py +10 -6
  14. rasa/cli/scaffold.py +3 -4
  15. rasa/cli/studio/download.py +1 -1
  16. rasa/cli/studio/upload.py +0 -6
  17. rasa/cli/utils.py +7 -0
  18. rasa/core/channels/channel.py +93 -0
  19. rasa/core/channels/inspector/dist/assets/{arc-c7691751.js → arc-9f75cc3b.js} +1 -1
  20. rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-ab99dff7.js → blockDiagram-38ab4fdb-7f34db23.js} +1 -1
  21. rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-08c35a6b.js → c4Diagram-3d4e48cf-948bab2c.js} +1 -1
  22. rasa/core/channels/inspector/dist/assets/channel-dfa68278.js +1 -0
  23. rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-9e9c71c9.js → classDiagram-70f12bd4-53b0dd0e.js} +1 -1
  24. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-15e7e2bf.js → classDiagram-v2-f2320105-fdf789e7.js} +1 -1
  25. rasa/core/channels/inspector/dist/assets/clone-edb7f119.js +1 -0
  26. rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-9c105cb1.js → createText-2e5e7dd3-87c4ece5.js} +1 -1
  27. rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-77e89e48.js → edges-e0da2a9e-5a8b0749.js} +1 -1
  28. rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-7a011646.js → erDiagram-9861fffd-66da90e2.js} +1 -1
  29. rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-b6f105ac.js → flowDb-956e92f1-10044f05.js} +1 -1
  30. rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-ce4f18c2.js → flowDiagram-66a62f08-f338f66a.js} +1 -1
  31. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-65e7c670.js +1 -0
  32. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-cb5f6da4.js → flowchart-elk-definition-4a651766-b13140aa.js} +1 -1
  33. rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-e4d19e28.js → ganttDiagram-c361ad54-f2b4a55a.js} +1 -1
  34. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-727b1c33.js → gitGraphDiagram-72cf32ee-dedc298d.js} +1 -1
  35. rasa/core/channels/inspector/dist/assets/{graph-6e2ab9a7.js → graph-4ede11ff.js} +1 -1
  36. rasa/core/channels/inspector/dist/assets/{index-3862675e-84ec700f.js → index-3862675e-65549d37.js} +1 -1
  37. rasa/core/channels/inspector/dist/assets/{index-098a1a24.js → index-3a23e736.js} +142 -129
  38. rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-78dda442.js → infoDiagram-f8f76790-65439671.js} +1 -1
  39. rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-f1cc6dd1.js → journeyDiagram-49397b02-56d03d98.js} +1 -1
  40. rasa/core/channels/inspector/dist/assets/{layout-d98dcd0c.js → layout-dd48f7f4.js} +1 -1
  41. rasa/core/channels/inspector/dist/assets/{line-838e3d82.js → line-1569ad2c.js} +1 -1
  42. rasa/core/channels/inspector/dist/assets/{linear-eae72406.js → linear-48bf4935.js} +1 -1
  43. rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-c96fd84b.js → mindmap-definition-fc14e90a-688504c1.js} +1 -1
  44. rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-c936d4e2.js → pieDiagram-8a3498a8-78b6d7e6.js} +1 -1
  45. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-b338eb8f.js → quadrantDiagram-120e2f19-048b84b3.js} +1 -1
  46. rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-c6b6c0d5.js → requirementDiagram-deff3bca-dd67f107.js} +1 -1
  47. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-b9372e19.js → sankeyDiagram-04a897e0-8128436e.js} +1 -1
  48. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-479e0a3f.js → sequenceDiagram-704730f1-1a0d1461.js} +1 -1
  49. rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-fd26eebc.js → stateDiagram-587899a1-46d388ed.js} +1 -1
  50. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-3233e0ae.js → stateDiagram-v2-d93cdb3a-ea42951a.js} +1 -1
  51. rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-1fdd392b.js → styles-6aaf32cf-7427ed0c.js} +1 -1
  52. rasa/core/channels/inspector/dist/assets/{styles-9a916d00-6d7bfa1b.js → styles-9a916d00-ff5e5a16.js} +1 -1
  53. rasa/core/channels/inspector/dist/assets/{styles-c10674c1-f86aab11.js → styles-c10674c1-7b3680cf.js} +1 -1
  54. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-e3e49d7a.js → svgDrawCommon-08f97a94-f860f2ad.js} +1 -1
  55. rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-6fe08b4d.js → timeline-definition-85554ec2-2eebf0c8.js} +1 -1
  56. rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-c2e06fd6.js → xychartDiagram-e933f94c-5d7f4e96.js} +1 -1
  57. rasa/core/channels/inspector/dist/index.html +1 -1
  58. rasa/core/channels/inspector/src/App.tsx +3 -2
  59. rasa/core/channels/inspector/src/components/Chat.tsx +23 -2
  60. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +2 -5
  61. rasa/core/channels/inspector/src/helpers/conversation.ts +16 -0
  62. rasa/core/channels/inspector/src/types.ts +1 -1
  63. rasa/core/channels/voice_ready/audiocodes.py +41 -15
  64. rasa/core/channels/voice_ready/jambonz.py +25 -5
  65. rasa/core/channels/voice_ready/jambonz_protocol.py +4 -0
  66. rasa/core/channels/voice_ready/twilio_voice.py +48 -1
  67. rasa/core/channels/voice_stream/tts/azure.py +11 -2
  68. rasa/core/channels/voice_stream/twilio_media_streams.py +101 -26
  69. rasa/core/channels/voice_stream/voice_channel.py +28 -2
  70. rasa/core/concurrent_lock_store.py +24 -10
  71. rasa/core/information_retrieval/faiss.py +7 -68
  72. rasa/core/information_retrieval/information_retrieval.py +2 -40
  73. rasa/core/information_retrieval/milvus.py +2 -7
  74. rasa/core/information_retrieval/qdrant.py +2 -7
  75. rasa/core/lock_store.py +151 -60
  76. rasa/core/nlg/contextual_response_rephraser.py +3 -0
  77. rasa/core/policies/enterprise_search_policy.py +310 -61
  78. rasa/core/policies/intentless_policy.py +3 -0
  79. rasa/dialogue_understanding/coexistence/llm_based_router.py +8 -0
  80. rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
  81. rasa/dialogue_understanding/generator/command_parser.py +1 -1
  82. rasa/dialogue_understanding/generator/flow_retrieval.py +1 -4
  83. rasa/dialogue_understanding/generator/llm_based_command_generator.py +1 -2
  84. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +13 -0
  85. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
  86. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +1 -1
  87. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +2 -24
  88. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +22 -17
  89. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +27 -12
  90. rasa/dialogue_understanding_test/du_test_case.py +16 -8
  91. rasa/dialogue_understanding_test/io.py +8 -13
  92. rasa/e2e_test/utils/validation.py +3 -3
  93. rasa/engine/recipes/default_components.py +0 -2
  94. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +3 -0
  95. rasa/plugin.py +0 -3
  96. rasa/shared/constants.py +1 -0
  97. rasa/shared/core/domain.py +165 -11
  98. rasa/shared/core/flows/flow.py +155 -131
  99. rasa/shared/core/flows/flow_step.py +19 -3
  100. rasa/shared/core/flows/flow_step_links.py +15 -0
  101. rasa/shared/core/flows/flow_step_sequence.py +6 -0
  102. rasa/shared/core/flows/nlu_trigger.py +13 -0
  103. rasa/shared/core/flows/steps/action.py +7 -4
  104. rasa/shared/core/flows/steps/call.py +11 -4
  105. rasa/shared/core/flows/steps/collect.py +27 -6
  106. rasa/shared/core/flows/steps/internal.py +6 -1
  107. rasa/shared/core/flows/steps/link.py +7 -4
  108. rasa/shared/core/flows/steps/no_operation.py +7 -4
  109. rasa/shared/core/flows/steps/set_slots.py +8 -4
  110. rasa/shared/core/flows/yaml_flows_io.py +106 -5
  111. rasa/shared/importers/importer.py +8 -0
  112. rasa/shared/providers/_utils.py +83 -0
  113. rasa/shared/providers/llm/_base_litellm_client.py +6 -3
  114. rasa/shared/providers/llm/azure_openai_llm_client.py +6 -68
  115. rasa/shared/providers/router/_base_litellm_router_client.py +53 -1
  116. rasa/shared/utils/common.py +42 -0
  117. rasa/shared/utils/constants.py +3 -0
  118. rasa/shared/utils/llm.py +70 -24
  119. rasa/studio/download/domains.py +49 -0
  120. rasa/studio/download/download.py +439 -0
  121. rasa/studio/download/flows.py +359 -0
  122. rasa/studio/results_logger.py +6 -1
  123. rasa/studio/upload.py +69 -5
  124. rasa/tracing/instrumentation/attribute_extractors.py +7 -10
  125. rasa/tracing/instrumentation/instrumentation.py +12 -12
  126. rasa/utils/common.py +36 -0
  127. rasa/utils/endpoints.py +22 -1
  128. rasa/utils/licensing.py +1 -1
  129. rasa/validator.py +1 -2
  130. rasa/version.py +1 -1
  131. {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev5.dist-info}/METADATA +7 -7
  132. {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev5.dist-info}/RECORD +149 -166
  133. rasa/cli/project_templates/calm/config.yml +0 -10
  134. rasa/cli/project_templates/calm/credentials.yml +0 -33
  135. rasa/cli/project_templates/calm/endpoints.yml +0 -58
  136. rasa/cli/project_templates/default/actions/actions.py +0 -27
  137. rasa/cli/project_templates/default/data/nlu.yml +0 -91
  138. rasa/cli/project_templates/default/data/rules.yml +0 -13
  139. rasa/cli/project_templates/default/data/stories.yml +0 -30
  140. rasa/cli/project_templates/default/domain.yml +0 -34
  141. rasa/cli/project_templates/default/tests/test_stories.yml +0 -91
  142. rasa/core/channels/inspector/dist/assets/channel-11268142.js +0 -1
  143. rasa/core/channels/inspector/dist/assets/clone-ff7f2ce7.js +0 -1
  144. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-cba7ae20.js +0 -1
  145. rasa/document_retrieval/__init__.py +0 -0
  146. rasa/document_retrieval/constants.py +0 -32
  147. rasa/document_retrieval/document_post_processor.py +0 -351
  148. rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
  149. rasa/document_retrieval/document_retriever.py +0 -333
  150. rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
  151. rasa/document_retrieval/knowledge_base_connectors/api_connector.py +0 -39
  152. rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +0 -34
  153. rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +0 -226
  154. rasa/document_retrieval/query_rewriter.py +0 -234
  155. rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +0 -8
  156. rasa/studio/download.py +0 -489
  157. /rasa/cli/project_templates/{calm → default}/actions/action_template.py +0 -0
  158. /rasa/cli/project_templates/{calm → default}/actions/add_contact.py +0 -0
  159. /rasa/cli/project_templates/{calm → default}/actions/db.py +0 -0
  160. /rasa/cli/project_templates/{calm → default}/actions/list_contacts.py +0 -0
  161. /rasa/cli/project_templates/{calm → default}/actions/remove_contact.py +0 -0
  162. /rasa/cli/project_templates/{calm → default}/data/flows/add_contact.yml +0 -0
  163. /rasa/cli/project_templates/{calm → default}/data/flows/list_contacts.yml +0 -0
  164. /rasa/cli/project_templates/{calm → default}/data/flows/remove_contact.yml +0 -0
  165. /rasa/cli/project_templates/{calm → default}/db/contacts.json +0 -0
  166. /rasa/cli/project_templates/{calm → default}/domain/add_contact.yml +0 -0
  167. /rasa/cli/project_templates/{calm → default}/domain/list_contacts.yml +0 -0
  168. /rasa/cli/project_templates/{calm → default}/domain/remove_contact.yml +0 -0
  169. /rasa/cli/project_templates/{calm → default}/domain/shared.yml +0 -0
  170. /rasa/{cli/project_templates/calm/actions → studio/download}/__init__.py +0 -0
  171. {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev5.dist-info}/NOTICE +0 -0
  172. {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev5.dist-info}/WHEEL +0 -0
  173. {rasa_pro-3.13.0.dev2.dist-info → rasa_pro-3.13.0.dev5.dist-info}/entry_points.txt +0 -0
@@ -1,351 +0,0 @@
1
- import asyncio
2
- import importlib.resources
3
- from enum import Enum
4
- from functools import lru_cache
5
- from typing import Any, Dict, Optional, Text
6
-
7
- import structlog
8
- from jinja2 import Template
9
-
10
- import rasa.shared.utils.io
11
- from rasa.core.information_retrieval import SearchResult, SearchResultList
12
- from rasa.dialogue_understanding.utils import add_prompt_to_message_parse_data
13
- from rasa.engine.storage.resource import Resource
14
- from rasa.engine.storage.storage import ModelStorage
15
- from rasa.shared.constants import (
16
- LLM_CONFIG_KEY,
17
- MODEL_CONFIG_KEY,
18
- OPENAI_PROVIDER,
19
- PROMPT_TEMPLATE_CONFIG_KEY,
20
- PROVIDER_CONFIG_KEY,
21
- TEXT,
22
- TIMEOUT_CONFIG_KEY,
23
- )
24
- from rasa.shared.core.trackers import DialogueStateTracker
25
- from rasa.shared.exceptions import FileIOException, ProviderClientAPIException
26
- from rasa.shared.nlu.training_data.message import Message
27
- from rasa.shared.providers.llm.llm_client import LLMClient
28
- from rasa.shared.providers.llm.llm_response import LLMResponse
29
- from rasa.shared.utils.health_check.health_check import perform_llm_health_check
30
- from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
31
- from rasa.shared.utils.llm import (
32
- DEFAULT_OPENAI_GENERATE_MODEL_NAME,
33
- DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
34
- get_prompt_template,
35
- llm_factory,
36
- resolve_model_client_config,
37
- tracker_as_readable_transcript,
38
- )
39
-
40
# Keys looked up in the post processor's configuration dictionary.
TYPE_CONFIG_KEY = "type"
EMBEDDING_MODEL_KEY = "embedding_model_name"

# Sentence-transformers model used when the config does not name one.
DEFAULT_EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# File name under which the prompt template is written to / read from the
# model resource.
# NOTE(review): the extension is spelled "jina2" (not "jinja2"). persist() and
# load() both go through this constant, so the spelling is self-consistent; it
# is kept unchanged here so previously persisted resources remain readable.
DOCUMENT_POST_PROCESSOR_PROMPT_FILE_NAME = (
    "document_post_processor_prompt_template.jina2"
)

# LLM settings applied when the component config does not override them.
DEFAULT_LLM_CONFIG: Dict[str, Any] = {
    PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
    MODEL_CONFIG_KEY: DEFAULT_OPENAI_GENERATE_MODEL_NAME,
    "temperature": 0.3,
    "max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
    TIMEOUT_CONFIG_KEY: 5,
}

# Prompt template bundled with the package; read once at import time.
DEFAULT_DOCUMENT_POST_PROCESSOR_PROMPT_TEMPLATE: str = importlib.resources.read_text(
    "rasa.document_retrieval",
    "document_post_processor_prompt_template.jinja2",
)

structlogger = structlog.get_logger()
60
-
61
-
62
class PostProcessingType(Enum):
    """Strategies the document post processor can apply to search results."""

    # Return the retrieved documents unchanged.
    PLAIN = "PLAIN"
    # Produce one LLM summary over all retrieved documents.
    AGGREGATED_SUMMARY = "AGGREGATED_SUMMARY"
    # Produce one LLM summary per retrieved document.
    INDIVIDUAL_SUMMARIES = "INDIVIDUAL_SUMMARIES"
    # Ask the LLM whether the documents are relevant to the query.
    BINARY_LLM = "BINARY_LLM"
    # Decide relevance via embedding similarity instead of an LLM.
    BINARY_EMBEDDING_MODEL = "BINARY_EMBEDDING_MODEL"
    # Let the LLM generate the final answer from the documents.
    FINAL_ANSWER = "FINAL_ANSWER"

    def __str__(self) -> str:
        """Render the member as its plain string value."""
        return self.value
72
-
73
-
74
class DocumentPostProcessor(LLMHealthCheckMixin):
    """Post-process documents retrieved from a knowledge base.

    The ``type`` entry of the config selects one of the ``PostProcessingType``
    strategies; ``process_documents`` dispatches on it.
    """

    @classmethod
    def get_default_config(cls) -> Dict[str, Any]:
        """The default config for the document post processor."""
        return {
            TYPE_CONFIG_KEY: PostProcessingType.PLAIN,
            LLM_CONFIG_KEY: DEFAULT_LLM_CONFIG,
            PROMPT_TEMPLATE_CONFIG_KEY: DEFAULT_DOCUMENT_POST_PROCESSOR_PROMPT_TEMPLATE,
        }

    def __init__(
        self,
        config: Dict[str, Any],
        model_storage: ModelStorage,
        resource: Resource,
        prompt_template: Optional[str] = None,
    ):
        """Create the post processor.

        Args:
            config: Component config; merged on top of the defaults.
            model_storage: Storage used by ``persist``.
            resource: Resource the prompt template is persisted under.
            prompt_template: Already loaded template text. When empty or
                ``None`` the template is resolved from ``config``.
        """
        self.config = {**self.get_default_config(), **config}
        self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
            self.config.get(LLM_CONFIG_KEY), DocumentPostProcessor.__name__
        )
        # Resolve from the raw ``config`` (not the merged one): the merged
        # default entry holds template *content*, not a path.
        self.prompt_template = prompt_template or get_prompt_template(
            config.get(PROMPT_TEMPLATE_CONFIG_KEY),
            DEFAULT_DOCUMENT_POST_PROCESSOR_PROMPT_TEMPLATE,
        )

        self._model_storage = model_storage
        self._resource = resource

    @classmethod
    def load(
        cls,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        **kwargs: Any,
    ) -> "DocumentPostProcessor":
        """Load document post processor.

        Runs the LLM health check, then tries to read the persisted prompt
        template. A missing or unreadable template is logged as a warning and
        the instance falls back to resolving the template from ``config``.
        """
        llm_config = resolve_model_client_config(config.get(LLM_CONFIG_KEY, {}))
        perform_llm_health_check(
            llm_config,
            DEFAULT_LLM_CONFIG,
            "document_post_processor.load",
            DocumentPostProcessor.__name__,
        )

        # load prompt template
        prompt_template = None
        try:
            with model_storage.read_from(resource) as path:
                prompt_template = rasa.shared.utils.io.read_file(
                    path / DOCUMENT_POST_PROCESSOR_PROMPT_FILE_NAME
                )
        except (FileNotFoundError, FileIOException) as e:
            structlogger.warning(
                "document_post_processor.load_prompt_template.failed",
                error=e,
                resource=resource.name,
            )

        # Use ``cls`` so subclasses load as themselves.
        return cls(config, model_storage, resource, prompt_template)

    def persist(self) -> None:
        """Write the prompt template into the model resource."""
        with self._model_storage.write_to(self._resource) as path:
            rasa.shared.utils.io.write_text_file(
                self.prompt_template, path / DOCUMENT_POST_PROCESSOR_PROMPT_FILE_NAME
            )

    async def process_documents(
        self,
        message: Message,
        search_query: str,
        documents: SearchResultList,
        tracker: DialogueStateTracker,
    ) -> SearchResultList:
        """Apply the configured post-processing strategy to ``documents``.

        Args:
            message: The user message being processed.
            search_query: The query that was sent to the knowledge base.
            documents: The retrieved documents.
            tracker: The conversation tracker.

        Returns:
            The post-processed search results.

        Raises:
            ValueError: If the configured type is not a known strategy.
        """
        # The config may hold either the enum member (as the default config
        # does) or its string value; ``str`` maps both to the value.
        processing_type = str(self.config.get(TYPE_CONFIG_KEY))

        # Strategies that need no LLM are handled before a client is built.
        if processing_type == PostProcessingType.PLAIN.value:
            return documents

        if processing_type == PostProcessingType.BINARY_EMBEDDING_MODEL.value:
            return (
                await self._check_documents_relevance_to_user_query_using_modern_bert(
                    search_query,
                    documents,
                )
            )

        if processing_type not in {member.value for member in PostProcessingType}:
            raise ValueError(f"Invalid postprocessing type: {processing_type}")

        llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)

        if processing_type == PostProcessingType.AGGREGATED_SUMMARY.value:
            return await self._create_aggregated_summary(documents, llm)

        if processing_type == PostProcessingType.INDIVIDUAL_SUMMARIES.value:
            return await self._create_individual_summaries(documents, llm)

        if processing_type == PostProcessingType.BINARY_LLM.value:
            return await self._check_documents_relevance_to_user_query(
                message, search_query, documents, llm, tracker
            )

        # Only FINAL_ANSWER remains after the membership check above.
        return await self._generate_final_answer(message, documents, llm, tracker)

    @staticmethod
    @lru_cache(maxsize=32)
    def _compile_cached(template: str) -> Template:
        """Compile ``template``; cached on the text only, not on an instance."""
        return Template(template)

    def compile_template(self, template: str) -> Template:
        """Compile the prompt template.

        Compiling the template is an expensive operation,
        so we cache the result.
        """
        return self._compile_cached(template)

    def render_prompt(self, data: Dict) -> str:
        """Render the instance's prompt template with ``data``."""
        # TODO: This should probably be fixed, as the default prompt template is empty
        #  If there are default templates for summarization they should be created,
        #  and ideally be initialized based on the processing type.
        # Use the template resolved in ``__init__`` (possibly loaded from the
        # model resource) instead of re-resolving it on every call.
        return self.compile_template(self.prompt_template).render(**data)

    async def _invoke_llm(self, prompt: str, llm: LLMClient) -> Optional[LLMResponse]:
        """Send ``prompt`` to ``llm``.

        Raises:
            ProviderClientAPIException: Wrapping any exception from the call.
        """
        try:
            return await llm.acompletion(prompt)
        except Exception as e:
            # The client does not wrap provider errors, so every exception is
            # caught here and re-raised as a single, known type.
            structlogger.error("document_post_processor.llm.error", error=e)
            raise ProviderClientAPIException(
                message="LLM call exception", original_exception=e
            ) from e

    async def _create_aggregated_summary(
        self, documents: SearchResultList, llm: LLMClient
    ) -> SearchResultList:
        """Summarise all documents with a single LLM call."""
        prompt = self.render_prompt(
            {"retrieval_results": [doc.text for doc in documents.results]}
        )

        llm_response = await self._invoke_llm(prompt, llm)
        aggregated_summary = LLMResponse.ensure_llm_response(llm_response)

        aggregated_result = SearchResult(
            text=aggregated_summary.choices[0], metadata={}
        )

        return SearchResultList(results=[aggregated_result], metadata={})

    async def _create_individual_summaries(
        self, documents: SearchResultList, llm: LLMClient
    ) -> SearchResultList:
        """Summarise every document with its own, concurrently issued LLM call."""
        # The rendered prompt is sent as-is. It must not go through
        # ``str.format``: braces in the document text would raise.
        prompts = [
            self.render_prompt({"retrieval_results": doc.text})
            for doc in documents.results
        ]

        llm_responses = await asyncio.gather(
            *(self._invoke_llm(prompt, llm) for prompt in prompts)
        )
        summarized_contents = [
            LLMResponse.ensure_llm_response(summary) for summary in llm_responses
        ]

        results = [
            SearchResult(text=summary.choices[0], metadata={})
            for summary in summarized_contents
        ]
        return SearchResultList(results=results, metadata={})

    async def _check_documents_relevance_to_user_query(
        self,
        message: Message,
        search_query: str,
        documents: SearchResultList,
        llm: LLMClient,
        tracker: DialogueStateTracker,
    ) -> SearchResultList:
        """Ask the LLM about document relevance; return its first choice.

        The prompt and LLM response are also attached to ``message``.
        """
        # If no documents were retrieved from the vector store, the
        # documents seem to be irrelevant. Respond with "NO".
        if not documents.results:
            return SearchResultList(
                results=[SearchResult(text="NO", metadata={})],
                metadata={},
            )

        prompt_data = {
            "search_query": search_query,
            "relevant_documents": documents,
            "conversation": tracker_as_readable_transcript(tracker, max_turns=10),
        }

        prompt = self.render_prompt(prompt_data)

        llm_response = await self._invoke_llm(prompt, llm)
        documents_relevance = LLMResponse.ensure_llm_response(llm_response)

        aggregated_result = SearchResult(
            text=documents_relevance.choices[0],
            metadata={},
        )

        add_prompt_to_message_parse_data(
            message=message,
            component_name=self.__class__.__name__,
            prompt_name="document_post_processor",
            user_prompt=prompt,
            llm_response=llm_response,
        )
        structlogger.debug(
            "document_post_processor._check_documents_relevance_to_user_query",
            prompt=prompt,
            documents=[d.text for d in documents.results],
            llm_response=llm_response,
        )

        return SearchResultList(results=[aggregated_result], metadata={})

    async def _check_documents_relevance_to_user_query_using_modern_bert(
        self,
        search_query: str,
        documents: SearchResultList,
        threshold: float = 0.5,
    ) -> SearchResultList:
        """Return "YES" if any document's similarity to the query exceeds
        ``threshold``, otherwise "NO".
        """
        # Nothing retrieved: answer "NO" without loading the embedding model,
        # mirroring the LLM-based relevance check.
        if not documents.results:
            return SearchResultList(
                results=[SearchResult(text="NO", metadata={})],
                metadata={},
            )

        # Heavy optional dependencies are imported only when this strategy runs.
        import torch
        from sentence_transformers import SentenceTransformer

        # Load the model once per instance instead of on every call.
        model = getattr(self, "model", None)
        if model is None:
            # NOTE(review): trust_remote_code=True executes code shipped with
            # the model repository; make sure the configured model is trusted.
            model = SentenceTransformer(
                self.config.get(EMBEDDING_MODEL_KEY, DEFAULT_EMBEDDING_MODEL_NAME),
                trust_remote_code=True,
            )
            self.model = model

        query_embeddings = model.encode(["search_query: " + search_query])
        doc_embeddings = model.encode(
            ["search_document: " + doc.text for doc in documents.results]
        )

        similarities = model.similarity(query_embeddings, doc_embeddings)

        is_any_doc_relevant = torch.any(similarities > threshold).item()

        return SearchResultList(
            results=[
                SearchResult(text="YES" if is_any_doc_relevant else "NO", metadata={})
            ],
            metadata={},
        )

    async def _generate_final_answer(
        self,
        message: Message,
        documents: SearchResultList,
        llm: LLMClient,
        tracker: DialogueStateTracker,
    ) -> SearchResultList:
        """Let the LLM generate an answer; empty text if there is no response."""
        prompt_data = {
            "current_conversation": tracker_as_readable_transcript(tracker),
            "relevant_documents": documents,
            "user_message": message.get(TEXT),
        }
        prompt = self.render_prompt(prompt_data)
        response = await self._invoke_llm(prompt, llm)
        response_text = response.choices[0] if response else ""
        search_result = SearchResult(text=response_text, metadata={})
        return SearchResultList(
            results=[search_result],
            metadata={},
        )
@@ -1,333 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import datetime
4
- import time
5
- import uuid
6
- from enum import Enum
7
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text
8
-
9
- import structlog
10
-
11
- from rasa.core.information_retrieval.faiss import FAISS_Store
12
- from rasa.dialogue_understanding.utils import add_prompt_to_message_parse_data
13
- from rasa.document_retrieval.constants import (
14
- CONNECTOR_CONFIG_KEY,
15
- DEFAULT_K,
16
- DEFAULT_THRESHOLD,
17
- K_CONFIG_KEY,
18
- POST_PROCESSED_DOCUMENTS_KEY,
19
- POST_PROCESSING_CONFIG_KEY,
20
- QUERY_REWRITING_CONFIG_KEY,
21
- RETRIEVED_DOCUMENTS_KEY,
22
- SEARCH_QUERY_KEY,
23
- SOURCE_PROPERTY,
24
- THRESHOLD_CONFIG_KEY,
25
- USE_LLM_PROPERTY,
26
- VECTOR_STORE_CONFIG_KEY,
27
- VECTOR_STORE_TYPE_CONFIG_KEY,
28
- )
29
- from rasa.document_retrieval.document_post_processor import DocumentPostProcessor
30
- from rasa.document_retrieval.knowledge_base_connectors.api_connector import APIConnector
31
- from rasa.document_retrieval.knowledge_base_connectors.knowledge_base_connector import (
32
- KnowledgeBaseConnector,
33
- )
34
- from rasa.document_retrieval.knowledge_base_connectors.vector_store_connector import (
35
- DEFAULT_EMBEDDINGS_CONFIG,
36
- VectorStoreConnector,
37
- VectorStoreType,
38
- )
39
- from rasa.document_retrieval.query_rewriter import QueryRewriter
40
- from rasa.engine.graph import ExecutionContext, GraphComponent
41
- from rasa.engine.recipes.default_recipe import DefaultV1Recipe
42
- from rasa.engine.storage.resource import Resource
43
- from rasa.engine.storage.storage import ModelStorage
44
- from rasa.shared.constants import (
45
- EMBEDDINGS_CONFIG_KEY,
46
- )
47
- from rasa.shared.core.trackers import DialogueStateTracker
48
- from rasa.shared.nlu.training_data.message import Message
49
- from rasa.shared.nlu.training_data.training_data import TrainingData
50
- from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
51
- _LangchainEmbeddingClientAdapter,
52
- )
53
- from rasa.shared.providers.llm.llm_response import LLMResponse
54
- from rasa.shared.utils.llm import (
55
- embedder_factory,
56
- resolve_model_client_config,
57
- )
58
-
59
- if TYPE_CHECKING:
60
- from langchain.schema.embeddings import Embeddings
61
-
62
- structlogger = structlog.get_logger()
63
-
64
-
65
class ConnectorType(Enum):
    """Kinds of knowledge base connector the document retriever can load."""

    # Loaded through APIConnector.
    API = "API"
    # Loaded through VectorStoreConnector.
    VECTOR_STORE = "VECTOR_STORE"

    def __str__(self) -> str:
        """Render the member as its plain string value."""
        return self.value
-
72
-
73
- @DefaultV1Recipe.register(
74
- [
75
- DefaultV1Recipe.ComponentType.COEXISTENCE_ROUTER,
76
- ],
77
- is_trainable=True,
78
- )
79
- class DocumentRetriever(GraphComponent):
80
@staticmethod
def get_default_config() -> Dict[str, Any]:
    """Return the component's default configuration.

    Defaults: a FAISS vector store accessed through the vector store
    connector, with the default threshold, ``k`` and embeddings config.
    """
    vector_store_defaults = {
        VECTOR_STORE_TYPE_CONFIG_KEY: VectorStoreType.FAISS.value,
    }
    return {
        THRESHOLD_CONFIG_KEY: DEFAULT_THRESHOLD,
        K_CONFIG_KEY: DEFAULT_K,
        CONNECTOR_CONFIG_KEY: ConnectorType.VECTOR_STORE.value,
        EMBEDDINGS_CONFIG_KEY: DEFAULT_EMBEDDINGS_CONFIG,
        VECTOR_STORE_CONFIG_KEY: vector_store_defaults,
    }
-
93
def __init__(
    self,
    config: Dict[str, Any],
    model_storage: ModelStorage,
    resource: Resource,
    query_rewriter: Optional[QueryRewriter] = None,
    document_post_processor: Optional[DocumentPostProcessor] = None,
    knowledge_base_connector: Optional[KnowledgeBaseConnector] = None,
) -> None:
    """Set up the retriever and its collaborators.

    Args:
        config: Component config; merged on top of the defaults.
        model_storage: Storage the component persists to.
        resource: Resource the component persists under.
        query_rewriter: Ready-made rewriter; built from config if not given.
        document_post_processor: Ready-made post processor; built from
            config if not given.
        knowledge_base_connector: Ready-made connector; otherwise created via
            ``initialize_knowledge_base_connector``.
    """
    merged_config = {**self.get_default_config(), **config}
    merged_config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
        merged_config.get(EMBEDDINGS_CONFIG_KEY), DocumentRetriever.__name__
    )
    self.config = merged_config

    self._model_storage = model_storage
    self._resource = resource

    # Query rewriting and post processing fall back to the pass-through
    # "PLAIN" type when the config does not mention them.
    rewriting_cfg = config.get(QUERY_REWRITING_CONFIG_KEY, {"type": "PLAIN"})
    post_processing_cfg = config.get(POST_PROCESSING_CONFIG_KEY, {"type": "PLAIN"})

    if not query_rewriter:
        query_rewriter = QueryRewriter(rewriting_cfg, model_storage, resource)
    self.query_rewriter = query_rewriter

    if not document_post_processor:
        document_post_processor = DocumentPostProcessor(
            post_processing_cfg, model_storage, resource
        )
    self.document_post_processor = document_post_processor

    if not knowledge_base_connector:
        knowledge_base_connector = self.initialize_knowledge_base_connector()
    self.knowledge_base_connector = knowledge_base_connector

    # NOTE(review): assigned only after the connector has been initialised;
    # confirm initialize_knowledge_base_connector does not read it.
    self.use_llm = self.config.get(USE_LLM_PROPERTY, False)
-
130
def persist(self) -> None:
    """Persist the query rewriter, then the document post processor."""
    for part in (self.query_rewriter, self.document_post_processor):
        part.persist()
-
135
@classmethod
def _create_plain_embedder(cls, config: Dict[Text, Any]) -> "Embeddings":
    """Build an embedder from the embeddings section of ``config``.

    The passed ``config`` is left unmodified.

    Returns:
        The embedding client wrapped in the langchain adapter.
    """
    embeddings_config = resolve_model_client_config(
        config.get(EMBEDDINGS_CONFIG_KEY), DocumentRetriever.__name__
    )
    embedding_client = embedder_factory(embeddings_config, DEFAULT_EMBEDDINGS_CONFIG)
    return _LangchainEmbeddingClientAdapter(embedding_client)
-
154
def train(self, training_data: TrainingData) -> Resource:
    """Build the FAISS index (if FAISS is configured) and persist.

    ``training_data`` is not read by this method. For any other vector store
    type only the sub-components are persisted.

    Returns:
        The resource this component was persisted under.
    """
    vector_store_config = self.config.get(VECTOR_STORE_CONFIG_KEY, {})
    is_faiss = (
        vector_store_config.get(VECTOR_STORE_TYPE_CONFIG_KEY)
        == VectorStoreType.FAISS.value
    )

    if is_faiss:
        structlogger.info("document_retriever.train.faiss")
        embedder = self._create_plain_embedder(self.config)
        with self._model_storage.write_to(self._resource) as index_dir:
            self.vector_store = FAISS_Store(
                docs_folder=vector_store_config.get(SOURCE_PROPERTY),
                embeddings=embedder,
                index_path=index_dir,
                create_index=True,
                use_llm=self.use_llm,
            )

    self.persist()
    return self._resource
-
175
@classmethod
def load(
    cls,
    config: Dict[str, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
    **kwargs: Any,
) -> "DocumentRetriever":
    """Load a trained component together with its sub-components.

    The query rewriter, the document post processor and the knowledge base
    connector are loaded from the same model storage and resource, then
    handed to the constructor.

    Raises:
        ValueError: If the configured connector type is neither the vector
            store connector nor the API connector.
    """
    # Query rewriting and post processing fall back to the "PLAIN" type
    # (i.e. they are disabled) when they are not configured.
    plain_config = {"type": "PLAIN"}
    rewriter_config = config.get(QUERY_REWRITING_CONFIG_KEY, plain_config)
    post_processor_config = config.get(POST_PROCESSING_CONFIG_KEY, {"type": "PLAIN"})

    query_rewriter = QueryRewriter.load(rewriter_config, model_storage, resource)
    document_post_processor = DocumentPostProcessor.load(
        post_processor_config, model_storage, resource
    )

    # Pick the connector implementation first, then load it in one place.
    connector_type = config.get(CONNECTOR_CONFIG_KEY)
    if connector_type == ConnectorType.VECTOR_STORE.value:
        connector_class = VectorStoreConnector
    elif connector_type == ConnectorType.API.value:
        connector_class = APIConnector
    else:
        raise ValueError(f"Invalid knowledge base connector: {connector_type}")

    knowledge_base_connector: KnowledgeBaseConnector = connector_class.load(
        config, model_storage, resource
    )

    return cls(
        config,
        model_storage,
        resource,
        query_rewriter,
        document_post_processor,
        knowledge_base_connector,
    )
224
-
225
@classmethod
def create(
    cls,
    config: Dict[str, Any],
    model_storage: ModelStorage,
    resource: Resource,
    execution_context: ExecutionContext,
) -> DocumentRetriever:
    """Build a new component instance from its configuration.

    Only the configuration, model storage and resource are forwarded to the
    constructor; the execution context is accepted but not used here.
    """
    new_component = cls(config, model_storage, resource)
    return new_component
235
-
236
def initialize_knowledge_base_connector(self) -> KnowledgeBaseConnector:
    """Instantiate the knowledge base connector selected in the configuration.

    Returns:
        A vector store connector (built from the config, model storage and
        resource) or an API connector (built from the config only),
        depending on the configured connector type.

    Raises:
        ValueError: If the configured connector type is not supported.
    """
    connector_type = self.config.get(CONNECTOR_CONFIG_KEY)

    if connector_type == ConnectorType.VECTOR_STORE.value:
        return VectorStoreConnector(
            self.config,
            self._model_storage,
            self._resource,
        )
    if connector_type == ConnectorType.API.value:
        return APIConnector(self.config)

    # Report the configured value. The message previously interpolated the
    # builtin `type`, which printed "<class 'type'>" instead of the offending
    # connector name; `load` already reports `connector_type`.
    raise ValueError(f"Invalid knowledge base connector: {connector_type}")
249
-
250
async def process(
    self,
    messages: List[Message],
    tracker: Optional[DialogueStateTracker] = None,
) -> List[Message]:
    """Retrieve and post-process documents for each message.

    For every message the prepared search query, the retrieved documents and
    the post-processed documents are stored on the message (with
    ``add_to_output=True``). When the connector returns no search result,
    both document entries are set to an empty list and the message is not
    processed further.

    Args:
        messages: Messages to annotate; they are modified in place.
        tracker: Optional dialogue state tracker forwarded to the query
            rewriter, the connector and the post processor.

    Returns:
        The same list of messages that was passed in.
    """
    self.knowledge_base_connector.connect_or_raise()

    # `.get` keeps a configuration without these keys from raising a
    # KeyError; missing or falsy values fall back to the defaults, which is
    # what the `or` fallback was written for. Both values are loop-invariant.
    top_k = self.config.get(K_CONFIG_KEY) or DEFAULT_K
    score_threshold = self.config.get(THRESHOLD_CONFIG_KEY) or DEFAULT_THRESHOLD

    for message in messages:
        # Monotonic clock: the latency must not be affected by wall-clock
        # adjustments that happen while the retrieval is running.
        started_at = time.perf_counter()

        # Prepare search query
        search_query = await self.query_rewriter.prepare_search_query(
            message, tracker
        )
        message.set(
            SEARCH_QUERY_KEY,
            search_query,
            add_to_output=True,
        )

        # Retrieve documents
        search_result = await self.knowledge_base_connector.retrieve_documents(
            search_query,
            top_k,
            score_threshold,
            tracker,
        )

        if search_result is None:
            message.set(
                RETRIEVED_DOCUMENTS_KEY,
                [],
                add_to_output=True,
            )
            message.set(
                POST_PROCESSED_DOCUMENTS_KEY,
                [],
                add_to_output=True,
            )
            continue

        # Serialize once, so the message, the log entry and the prompt all
        # report the same snapshot of the retrieved documents.
        retrieved_documents = search_result.to_dict()
        message.set(
            RETRIEVED_DOCUMENTS_KEY,
            retrieved_documents,
            add_to_output=True,
        )

        # Post process documents
        final_search_result = await self.document_post_processor.process_documents(
            message, search_query, search_result, tracker
        )
        post_processed_documents = final_search_result.to_dict()
        message.set(
            POST_PROCESSED_DOCUMENTS_KEY,
            post_processed_documents,
            add_to_output=True,
        )

        structlogger.debug(
            "document_retriever.process",
            search_query=search_query,
            search_result=retrieved_documents,
            final_search_result=post_processed_documents,
        )

        elapsed_seconds = time.perf_counter() - started_at
        # No LLM is called here; the retrieval outcome is recorded in the
        # prompt-logging structure with a placeholder prompt.
        add_prompt_to_message_parse_data(
            message,
            DocumentRetriever.__name__,
            "document_retriever_process",
            user_prompt="Dummy prompt for document retriever process.",
            llm_response=LLMResponse(
                id=str(uuid.uuid4()),
                choices=[
                    f"search_query: {search_query}\n"
                    f"retrieved_documents: {retrieved_documents}\n"
                    f"post_processed_documents: {post_processed_documents}",
                ],
                created=int(datetime.datetime.now().timestamp()),
                latency=elapsed_seconds,
            ),
        )

    return messages