rasa-pro 3.13.0.dev1__py3-none-any.whl → 3.13.0.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic; the original page links to further details.

Files changed (58):
  1. rasa/core/actions/action.py +0 -6
  2. rasa/core/channels/voice_ready/audiocodes.py +52 -17
  3. rasa/core/channels/voice_stream/audiocodes.py +53 -9
  4. rasa/core/channels/voice_stream/genesys.py +146 -16
  5. rasa/core/information_retrieval/faiss.py +6 -1
  6. rasa/core/information_retrieval/information_retrieval.py +40 -2
  7. rasa/core/information_retrieval/milvus.py +7 -2
  8. rasa/core/information_retrieval/qdrant.py +7 -2
  9. rasa/core/policies/enterprise_search_policy.py +61 -301
  10. rasa/core/policies/flows/flow_executor.py +3 -38
  11. rasa/core/processor.py +27 -6
  12. rasa/core/utils.py +53 -0
  13. rasa/dialogue_understanding/commands/cancel_flow_command.py +4 -59
  14. rasa/dialogue_understanding/commands/start_flow_command.py +0 -41
  15. rasa/dialogue_understanding/generator/command_generator.py +67 -0
  16. rasa/dialogue_understanding/generator/command_parser.py +1 -1
  17. rasa/dialogue_understanding/generator/llm_based_command_generator.py +4 -13
  18. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
  19. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +20 -1
  20. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +7 -0
  21. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -61
  22. rasa/dialogue_understanding/processor/command_processor.py +7 -65
  23. rasa/dialogue_understanding/stack/utils.py +0 -38
  24. rasa/dialogue_understanding_test/io.py +13 -8
  25. rasa/document_retrieval/__init__.py +0 -0
  26. rasa/document_retrieval/constants.py +32 -0
  27. rasa/document_retrieval/document_post_processor.py +351 -0
  28. rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
  29. rasa/document_retrieval/document_retriever.py +333 -0
  30. rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
  31. rasa/document_retrieval/knowledge_base_connectors/api_connector.py +39 -0
  32. rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +34 -0
  33. rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +226 -0
  34. rasa/document_retrieval/query_rewriter.py +234 -0
  35. rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +8 -0
  36. rasa/engine/recipes/default_components.py +2 -0
  37. rasa/shared/core/constants.py +0 -8
  38. rasa/shared/core/domain.py +12 -3
  39. rasa/shared/core/flows/flow.py +0 -17
  40. rasa/shared/core/flows/flows_yaml_schema.json +3 -38
  41. rasa/shared/core/flows/steps/collect.py +5 -18
  42. rasa/shared/core/flows/utils.py +1 -16
  43. rasa/shared/core/slot_mappings.py +11 -5
  44. rasa/shared/nlu/constants.py +0 -1
  45. rasa/shared/utils/common.py +11 -1
  46. rasa/shared/utils/llm.py +1 -1
  47. rasa/tracing/instrumentation/attribute_extractors.py +10 -7
  48. rasa/tracing/instrumentation/instrumentation.py +12 -12
  49. rasa/validator.py +1 -123
  50. rasa/version.py +1 -1
  51. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/METADATA +1 -1
  52. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/RECORD +55 -47
  53. rasa/core/actions/action_handle_digressions.py +0 -164
  54. rasa/dialogue_understanding/commands/handle_digressions_command.py +0 -144
  55. rasa/dialogue_understanding/patterns/handle_digressions.py +0 -81
  56. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/NOTICE +0 -0
  57. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/WHEEL +0 -0
  58. {rasa_pro-3.13.0.dev1.dist-info → rasa_pro-3.13.0.dev2.dist-info}/entry_points.txt +0 -0
@@ -108,10 +108,6 @@ logger = logging.getLogger(__name__)
108
108
  def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["Action"]:
109
109
  """List default actions."""
110
110
  from rasa.core.actions.action_clean_stack import ActionCleanStack
111
- from rasa.core.actions.action_handle_digressions import (
112
- ActionBlockDigressions,
113
- ActionContinueDigression,
114
- )
115
111
  from rasa.core.actions.action_hangup import ActionHangup
116
112
  from rasa.core.actions.action_repeat_bot_messages import ActionRepeatBotMessages
117
113
  from rasa.core.actions.action_run_slot_rejections import ActionRunSlotRejections
@@ -146,8 +142,6 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
146
142
  ActionResetRouting(),
147
143
  ActionHangup(),
148
144
  ActionRepeatBotMessages(),
149
- ActionBlockDigressions(),
150
- ActionContinueDigression(),
151
145
  ]
152
146
 
153
147
 
@@ -1,5 +1,6 @@
1
1
  import asyncio
2
2
  import copy
3
+ import hmac
3
4
  import json
4
5
  import uuid
5
6
  from collections import defaultdict
@@ -114,11 +115,21 @@ class Conversation:
114
115
  async def handle_activities(
115
116
  self,
116
117
  message: Dict[Text, Any],
118
+ input_channel_name: str,
117
119
  output_channel: OutputChannel,
118
120
  on_new_message: Callable[[UserMessage], Awaitable[Any]],
119
121
  ) -> None:
120
122
  """Handle activities sent by Audiocodes."""
121
123
  structlogger.debug("audiocodes.handle.activities")
124
+ if input_channel_name == "":
125
+ structlogger.warning(
126
+ "audiocodes.handle.activities.empty_input_channel_name",
127
+ event_info=(
128
+ f"Audiocodes input channel name is empty "
129
+ f"for conversation {self.conversation_id}"
130
+ ),
131
+ )
132
+
122
133
  for activity in message["activities"]:
123
134
  text = None
124
135
  if activity[ACTIVITY_ID_KEY] in self.activity_ids:
@@ -142,6 +153,7 @@ class Conversation:
142
153
  metadata = self.get_metadata(activity)
143
154
  user_msg = UserMessage(
144
155
  text=text,
156
+ input_channel=input_channel_name,
145
157
  output_channel=output_channel,
146
158
  sender_id=self.conversation_id,
147
159
  metadata=metadata,
@@ -245,8 +257,13 @@ class AudiocodesInput(InputChannel):
245
257
 
246
258
  def _check_token(self, token: Optional[Text]) -> None:
247
259
  if not token:
260
+ structlogger.error("audiocodes.token_not_provided")
248
261
  raise HttpUnauthorized("Authentication token required.")
249
262
 
263
+ if not hmac.compare_digest(str(token), str(self.token)):
264
+ structlogger.error("audiocodes.invalid_token", invalid_token=token)
265
+ raise HttpUnauthorized("Invalid authentication token.")
266
+
250
267
  def _get_conversation(
251
268
  self, token: Optional[Text], conversation_id: Text
252
269
  ) -> Conversation:
@@ -388,7 +405,12 @@ class AudiocodesInput(InputChannel):
388
405
  # start a background task to handle activities
389
406
  self._create_task(
390
407
  conversation_id,
391
- conversation.handle_activities(request.json, ac_output, on_new_message),
408
+ conversation.handle_activities(
409
+ request.json,
410
+ input_channel_name=self.name(),
411
+ output_channel=ac_output,
412
+ on_new_message=on_new_message,
413
+ ),
392
414
  )
393
415
  return response.json(response_json)
394
416
 
@@ -401,23 +423,9 @@ class AudiocodesInput(InputChannel):
401
423
  Example of payload:
402
424
  {"conversation": <conversation_id>, "reason": Optional[Text]}.
403
425
  """
404
- self._get_conversation(request.token, conversation_id)
405
- reason = {"reason": request.json.get("reason")}
406
- await on_new_message(
407
- UserMessage(
408
- text=f"{INTENT_MESSAGE_PREFIX}session_end",
409
- output_channel=None,
410
- sender_id=conversation_id,
411
- metadata=reason,
412
- )
413
- )
414
- del self.conversations[conversation_id]
415
- structlogger.debug(
416
- "audiocodes.disconnect",
417
- conversation=conversation_id,
418
- request=request.json,
426
+ return await self._handle_disconnect(
427
+ request, conversation_id, on_new_message
419
428
  )
420
- return response.json({})
421
429
 
422
430
  @ac_webhook.route("/conversation/<conversation_id>/keepalive", methods=["POST"])
423
431
  async def keepalive(request: Request, conversation_id: Text) -> HTTPResponse:
@@ -432,6 +440,32 @@ class AudiocodesInput(InputChannel):
432
440
 
433
441
  return ac_webhook
434
442
 
443
+ async def _handle_disconnect(
444
+ self,
445
+ request: Request,
446
+ conversation_id: Text,
447
+ on_new_message: Callable[[UserMessage], Awaitable[Any]],
448
+ ) -> HTTPResponse:
449
+ """Triggered when the call is disconnected."""
450
+ self._get_conversation(request.token, conversation_id)
451
+ reason = {"reason": request.json.get("reason")}
452
+ await on_new_message(
453
+ UserMessage(
454
+ text=f"{INTENT_MESSAGE_PREFIX}session_end",
455
+ output_channel=None,
456
+ input_channel=self.name(),
457
+ sender_id=conversation_id,
458
+ metadata=reason,
459
+ )
460
+ )
461
+ del self.conversations[conversation_id]
462
+ structlogger.debug(
463
+ "audiocodes.disconnect",
464
+ conversation=conversation_id,
465
+ request=request.json,
466
+ )
467
+ return response.json({})
468
+
435
469
 
436
470
  class AudiocodesOutput(OutputChannel):
437
471
  @classmethod
@@ -439,6 +473,7 @@ class AudiocodesOutput(OutputChannel):
439
473
  return CHANNEL_NAME
440
474
 
441
475
  def __init__(self) -> None:
476
+ super().__init__()
442
477
  self.messages: List[Dict] = []
443
478
 
444
479
  async def add_message(self, message: Dict) -> None:
@@ -1,5 +1,6 @@
1
1
  import asyncio
2
2
  import base64
3
+ import hmac
3
4
  import json
4
5
  from typing import Any, Awaitable, Callable, Dict, Optional, Text
5
6
 
@@ -103,6 +104,7 @@ class AudiocodesVoiceInputChannel(VoiceInputChannel):
103
104
 
104
105
  def __init__(
105
106
  self,
107
+ token: Optional[Text],
106
108
  server_url: str,
107
109
  asr_config: Dict,
108
110
  tts_config: Dict,
@@ -110,6 +112,22 @@ class AudiocodesVoiceInputChannel(VoiceInputChannel):
110
112
  ):
111
113
  mark_as_beta_feature("Audiocodes (audiocodes_stream) Channel")
112
114
  super().__init__(server_url, asr_config, tts_config, monitor_silence)
115
+ self.token = token
116
+
117
+ @classmethod
118
+ def from_credentials(
119
+ cls, credentials: Optional[Dict[str, Any]]
120
+ ) -> VoiceInputChannel:
121
+ if not credentials:
122
+ raise ValueError("No credentials given for Audiocodes voice channel.")
123
+
124
+ return cls(
125
+ token=credentials.get("token"),
126
+ server_url=credentials["server_url"],
127
+ asr_config=credentials["asr"],
128
+ tts_config=credentials["tts"],
129
+ monitor_silence=credentials.get("monitor_silence", False),
130
+ )
113
131
 
114
132
  def channel_bytes_to_rasa_audio_bytes(self, input_bytes: bytes) -> RasaAudioBytes:
115
133
  return RasaAudioBytes(base64.b64decode(input_bytes))
@@ -135,6 +153,13 @@ class AudiocodesVoiceInputChannel(VoiceInputChannel):
135
153
  )
136
154
  if activity["name"] == "start":
137
155
  return map_call_params(activity["parameters"])
156
+ elif data["type"] == "connection.validate":
157
+ # not part of call flow; only sent when integration is created
158
+ logger.info(
159
+ "audiocodes_stream.collect_call_parameters.connection.validate",
160
+ event_info="received request to validate integration",
161
+ )
162
+ self._send_validated(channel_websocket, data)
138
163
  else:
139
164
  logger.warning("audiocodes_stream.unknown_message", data=data)
140
165
  return None
@@ -158,7 +183,7 @@ class AudiocodesVoiceInputChannel(VoiceInputChannel):
158
183
  elif activity["name"] == "playFinished":
159
184
  logger.debug("audiocodes_stream.playFinished", data=activity)
160
185
  if call_state.should_hangup:
161
- logger.info("audiocodes.hangup")
186
+ logger.info("audiocodes_stream.hangup")
162
187
  self._send_hangup(ws, data)
163
188
  # the conversation should continue until
164
189
  # we receive a end message from audiocodes
@@ -180,11 +205,10 @@ class AudiocodesVoiceInputChannel(VoiceInputChannel):
180
205
  elif data["type"] == "session.end":
181
206
  logger.debug("audiocodes_stream.end", data=data)
182
207
  return EndConversationAction()
183
- elif data["type"] == "connection.validate":
184
- # not part of call flow; only sent when integration is created
185
- self._send_validated(ws, data)
186
208
  else:
187
- logger.warning("audiocodes_stream.unknown_message", data=data)
209
+ logger.warning(
210
+ "audiocodes_stream.map_input_message.unknown_message", data=data
211
+ )
188
212
 
189
213
  return ContinueConversationAction()
190
214
 
@@ -254,6 +278,17 @@ class AudiocodesVoiceInputChannel(VoiceInputChannel):
254
278
  self.tts_cache,
255
279
  )
256
280
 
281
+ def _is_token_valid(self, token: Optional[Text]) -> bool:
282
+ # If no token is set, always return True
283
+ if not self.token:
284
+ return True
285
+
286
+ # Token is required, but not provided
287
+ if not token:
288
+ return False
289
+
290
+ return hmac.compare_digest(str(self.token), str(token))
291
+
257
292
  def blueprint(
258
293
  self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
259
294
  ) -> Blueprint:
@@ -266,17 +301,26 @@ class AudiocodesVoiceInputChannel(VoiceInputChannel):
266
301
 
267
302
  @blueprint.websocket("/websocket") # type: ignore
268
303
  async def receive(request: Request, ws: Websocket) -> None:
269
- # TODO: validate API key header
270
- logger.info("audiocodes.receive", message="Starting audio streaming")
304
+ if not self._is_token_valid(request.token):
305
+ logger.error(
306
+ "audiocodes_stream.invalid_token",
307
+ invalid_token=request.token,
308
+ )
309
+ await ws.close(code=1008, reason="Invalid token")
310
+ return
311
+
312
+ logger.info(
313
+ "audiocodes_stream.receive", event_info="Started websocket connection"
314
+ )
271
315
  try:
272
316
  await self.run_audio_streaming(on_new_message, ws)
273
317
  except Exception as e:
274
318
  logger.exception(
275
- "audiocodes.receive",
319
+ "audiocodes_stream.receive",
276
320
  message="Error during audio streaming",
277
321
  error=e,
278
322
  )
279
- # return 500 error
323
+ await ws.close(code=1011, reason="Error during audio streaming")
280
324
  raise
281
325
 
282
326
  return blueprint
@@ -1,4 +1,7 @@
1
1
  import asyncio
2
+ import base64
3
+ import hashlib
4
+ import hmac
2
5
  import json
3
6
  from typing import Any, Awaitable, Callable, Dict, Optional, Text
4
7
 
@@ -45,6 +48,7 @@ in the documentation but observed in their example app
45
48
  https://github.com/GenesysCloudBlueprints/audioconnector-server-reference-implementation
46
49
  """
47
50
  MAXIMUM_BINARY_MESSAGE_SIZE = 64000 # 64KB
51
+ HEADER_API_KEY = "X-Api-Key"
48
52
  logger = structlog.get_logger(__name__)
49
53
 
50
54
 
@@ -86,8 +90,31 @@ class GenesysInputChannel(VoiceInputChannel):
86
90
  def name(cls) -> str:
87
91
  return "genesys"
88
92
 
89
- def __init__(self, *args: Any, **kwargs: Any) -> None:
93
+ def __init__(
94
+ self, api_key: Text, client_secret: Optional[Text], *args: Any, **kwargs: Any
95
+ ) -> None:
90
96
  super().__init__(*args, **kwargs)
97
+ self.api_key = api_key
98
+ self.client_secret = client_secret
99
+
100
+ @classmethod
101
+ def from_credentials(
102
+ cls, credentials: Optional[Dict[str, Any]]
103
+ ) -> VoiceInputChannel:
104
+ if not credentials:
105
+ raise ValueError("No credentials given for Genesys voice channel.")
106
+
107
+ if not credentials.get("api_key"):
108
+ raise ValueError("No API key given for Genesys voice channel (api_key).")
109
+
110
+ return cls(
111
+ api_key=credentials["api_key"],
112
+ client_secret=credentials.get("client_secret"),
113
+ server_url=credentials["server_url"],
114
+ asr_config=credentials["asr"],
115
+ tts_config=credentials["tts"],
116
+ monitor_silence=credentials.get("monitor_silence", False),
117
+ )
91
118
 
92
119
  def _ensure_channel_data_initialized(self) -> None:
93
120
  """Initialize Genesys-specific channel data if not already present.
@@ -273,6 +300,93 @@ class GenesysInputChannel(VoiceInputChannel):
273
300
  logger.debug("genesys.disconnect", message=message)
274
301
  _schedule_ws_task(ws.send(json.dumps(message)))
275
302
 
303
+ def _calculate_signature(self, request: Request) -> str:
304
+ """Calculate the signature using request data."""
305
+ org_id = request.headers.get("Audiohook-Organization-Id")
306
+ session_id = request.headers.get("Audiohook-Session-Id")
307
+ correlation_id = request.headers.get("Audiohook-Correlation-Id")
308
+ api_key = request.headers.get(HEADER_API_KEY)
309
+
310
+ # order of components is important!
311
+ components = [
312
+ ("@request-target", "/webhooks/genesys/websocket"),
313
+ ("audiohook-session-id", session_id),
314
+ ("audiohook-organization-id", org_id),
315
+ ("audiohook-correlation-id", correlation_id),
316
+ (HEADER_API_KEY.lower(), api_key),
317
+ ("@authority", self.server_url),
318
+ ]
319
+
320
+ # Create signature base string
321
+ signing_string = ""
322
+ for name, value in components:
323
+ signing_string += f'"{name}": {value}\n'
324
+
325
+ # Add @signature-params
326
+ signature_input = request.headers["Signature-Input"]
327
+ _, params_str = signature_input.split("=", 1)
328
+ signing_string += f'"@signature-params": {params_str}'
329
+
330
+ # Calculate the HMAC signature
331
+ key_bytes = base64.b64decode(self.client_secret)
332
+ signature = hmac.new(
333
+ key_bytes, signing_string.encode("utf-8"), hashlib.sha256
334
+ ).digest()
335
+ return base64.b64encode(signature).decode("utf-8")
336
+
337
+ async def _verify_signature(self, request: Request) -> bool:
338
+ """Verify the HTTP message signature from Genesys."""
339
+ if not self.client_secret:
340
+ logger.info(
341
+ "genesys.verify_signature.no_client_secret",
342
+ event_info="Signature verification skipped",
343
+ )
344
+ return True # Skip verification if no client secret
345
+
346
+ signature = request.headers.get("Signature")
347
+ signature_input = request.headers.get("Signature-Input")
348
+ if not signature or not signature_input:
349
+ logger.error("genesys.signature.missing_signature_header")
350
+ return False
351
+
352
+ try:
353
+ actual_signature = signature.split("=", 1)[1].strip(':"')
354
+ expected_signature = self._calculate_signature(request)
355
+ return hmac.compare_digest(
356
+ expected_signature.encode("utf-8"), actual_signature.encode("utf-8")
357
+ )
358
+ except Exception as e:
359
+ logger.exception("genesys.signature.verification_error", error=e)
360
+ return False
361
+
362
+ def _ensure_required_headers(self, request: Request) -> bool:
363
+ """Ensure required headers are present in the request."""
364
+ required_headers = [
365
+ "Audiohook-Organization-Id",
366
+ "Audiohook-Correlation-Id",
367
+ "Audiohook-Session-Id",
368
+ HEADER_API_KEY,
369
+ ]
370
+
371
+ missing_headers = [
372
+ header for header in required_headers if header not in request.headers
373
+ ]
374
+
375
+ if missing_headers:
376
+ logger.error(
377
+ "genesys.missing_required_headers",
378
+ missing_headers=missing_headers,
379
+ )
380
+ return False
381
+ return True
382
+
383
+ def _ensure_api_key(self, request: Request) -> bool:
384
+ """Ensure the API key is present in the request."""
385
+ api_key = request.headers.get(HEADER_API_KEY)
386
+ if not hmac.compare_digest(str(self.api_key), str(api_key)):
387
+ return False
388
+ return True
389
+
276
390
  def blueprint(
277
391
  self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
278
392
  ) -> Blueprint:
@@ -289,23 +403,39 @@ class GenesysInputChannel(VoiceInputChannel):
289
403
  "genesys.receive",
290
404
  audiohook_session_id=request.headers.get("audiohook-session-id"),
291
405
  )
292
- # validate required headers
293
- required_headers = [
294
- "audiohook-organization-id",
295
- "audiohook-correlation-id",
296
- "audiohook-session-id",
297
- "x-api-key",
298
- ]
299
-
300
- for header in required_headers:
301
- if header not in request.headers:
302
- await ws.close(1008, f"Missing required header: {header}")
303
- return
304
-
305
- # TODO: validate API key header
406
+
407
+ # verify signature
408
+ if not await self._verify_signature(request):
409
+ logger.error("genesys.receive.invalid_signature")
410
+ await ws.close(code=1008, reason="Invalid signature")
411
+ return
412
+
413
+ # ensure required headers are present
414
+ if not self._ensure_required_headers(request):
415
+ await ws.close(code=1002, reason="Missing required headers")
416
+ return
417
+
418
+ # ensure API key is correct
419
+ if not self._ensure_api_key(request):
420
+ logger.error(
421
+ "genesys.receive.invalid_api_key",
422
+ invalid_api_key=request.headers.get(HEADER_API_KEY),
423
+ )
424
+ await ws.close(code=1008, reason="Invalid API key")
425
+ return
426
+
306
427
  # process audio streaming
307
428
  logger.info("genesys.receive", message="Starting audio streaming")
308
- await self.run_audio_streaming(on_new_message, ws)
429
+ try:
430
+ await self.run_audio_streaming(on_new_message, ws)
431
+ except Exception as e:
432
+ logger.exception(
433
+ "genesys.receive",
434
+ message="Error during audio streaming",
435
+ error=e,
436
+ )
437
+ await ws.close(code=1011, reason="Error during audio streaming")
438
+ raise
309
439
 
310
440
  return blueprint
311
441
 
@@ -169,10 +169,15 @@ class FAISS_Store(InformationRetrieval):
169
169
  pass
170
170
 
171
171
  async def search(
172
- self, query: Text, tracker_state: Dict[str, Any], threshold: float = 0.0
172
+ self,
173
+ query: Text,
174
+ tracker_state: Dict[str, Any],
175
+ threshold: float = 0.0,
176
+ k: int = 1,
173
177
  ) -> SearchResultList:
174
178
  logger.debug("information_retrieval.faiss_store.search", query=query)
175
179
  try:
180
+ # TODO: make use of k
176
181
  documents = await self.index.as_retriever().ainvoke(query)
177
182
  except Exception as exc:
178
183
  raise InformationRetrievalException from exc
@@ -36,6 +36,19 @@ class SearchResult:
36
36
  """Construct a SearchResult object from Langchain Document object."""
37
37
  return cls(text=document.page_content, metadata=document.metadata)
38
38
 
39
+ @classmethod
40
+ def from_dict(cls, data: dict[str, Any]) -> "SearchResult":
41
+ """Construct a SearchResult object from a JSON object."""
42
+ return cls(text=data["text"], metadata=data["metadata"], score=data["score"])
43
+
44
+ def to_dict(self) -> dict[str, Any]:
45
+ """Convert the SearchResult object to a dictionary."""
46
+ return {
47
+ "text": self.text,
48
+ "metadata": self.metadata,
49
+ "score": self.score,
50
+ }
51
+
39
52
 
40
53
  @dataclass
41
54
  class SearchResultList:
@@ -44,8 +57,7 @@ class SearchResultList:
44
57
 
45
58
  @classmethod
46
59
  def from_document_list(cls, documents: List["Document"]) -> "SearchResultList":
47
- """
48
- Convert a list of Langchain Documents to a SearchResultList object.
60
+ """Convert a list of Langchain Documents to a SearchResultList object.
49
61
 
50
62
  Args:
51
63
  documents: List of Langchain Documents.
@@ -58,6 +70,31 @@ class SearchResultList:
58
70
  metadata={"total_results": len(documents)},
59
71
  )
60
72
 
73
+ @classmethod
74
+ def from_dict(cls, data: dict[str, Any]) -> "SearchResultList":
75
+ """Convert a JSON object to a SearchResultList object.
76
+
77
+ Args:
78
+ data: JSON object.
79
+
80
+ Returns:
81
+ SearchResultList object.
82
+ """
83
+ if not data:
84
+ return cls(results=[], metadata={})
85
+
86
+ return cls(
87
+ results=[SearchResult.from_dict(result) for result in data["results"]],
88
+ metadata=data["metadata"],
89
+ )
90
+
91
+ def to_dict(self) -> dict[str, Any]:
92
+ """Convert the SearchResultList object to a dictionary."""
93
+ return {
94
+ "results": [result.to_dict() for result in self.results],
95
+ "metadata": self.metadata,
96
+ }
97
+
61
98
 
62
99
  class InformationRetrievalException(RasaException):
63
100
  """Base class for exceptions raised by InformationRetrieval operations."""
@@ -89,6 +126,7 @@ class InformationRetrieval:
89
126
  query: Text,
90
127
  tracker_state: dict[str, Any],
91
128
  threshold: float = 0.0,
129
+ k: int = 1,
92
130
  ) -> SearchResultList:
93
131
  """Search for a document in the InformationRetrieval system."""
94
132
  raise NotImplementedError(
@@ -31,20 +31,25 @@ class Milvus_Store(InformationRetrieval):
31
31
  )
32
32
 
33
33
  async def search(
34
- self, query: Text, tracker_state: Dict[str, Any], threshold: float = 0.0
34
+ self,
35
+ query: Text,
36
+ tracker_state: Dict[str, Any],
37
+ threshold: float = 0.0,
38
+ k: int = 1,
35
39
  ) -> SearchResultList:
36
40
  """Search for documents in the Milvus store.
37
41
 
38
42
  Args:
39
43
  query: The query to search for.
40
44
  threshold: minimum similarity score to consider a document a match.
45
+ k: number of results to return.
41
46
 
42
47
  Returns:
43
48
  A list of documents that match the query.
44
49
  """
45
50
  logger.debug("information_retrieval.milvus_store.search", query=query)
46
51
  try:
47
- hits = await self.client.asimilarity_search_with_score(query, k=4)
52
+ hits = await self.client.asimilarity_search_with_score(query, k=k)
48
53
  except Exception as exc:
49
54
  raise InformationRetrievalException from exc
50
55
 
@@ -66,13 +66,18 @@ class Qdrant_Store(InformationRetrieval):
66
66
  )
67
67
 
68
68
  async def search(
69
- self, query: Text, tracker_state: Dict[str, Any], threshold: float = 0.0
69
+ self,
70
+ query: Text,
71
+ tracker_state: Dict[str, Any],
72
+ threshold: float = 0.0,
73
+ k: int = 1,
70
74
  ) -> SearchResultList:
71
75
  """Search for a document in the Qdrant vector store.
72
76
 
73
77
  Args:
74
78
  query: The query to search for.
75
79
  threshold: minimum similarity score to consider a document a match.
80
+ k: number of results to return.
76
81
 
77
82
  Returns:
78
83
  A list of documents that match the query.
@@ -80,7 +85,7 @@ class Qdrant_Store(InformationRetrieval):
80
85
  logger.debug("information_retrieval.qdrant_store.search", query=query)
81
86
  try:
82
87
  hits = await self.client.asimilarity_search(
83
- query, k=4, score_threshold=threshold
88
+ query, k=k, score_threshold=threshold
84
89
  )
85
90
  except ValidationError as e:
86
91
  raise PayloadNotFoundException(