rasa-pro 3.14.0rc1__py3-none-any.whl → 3.14.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rasa-pro might be problematic; see the registry listing for details.

rasa/builder/jobs.py CHANGED
@@ -1,9 +1,17 @@
-from typing import Any, Optional
+from typing import Any, Dict, Optional
 
 import structlog
 from sanic import Sanic
 
 from rasa.builder import config
+from rasa.builder.copilot.models import (
+    CopilotContext,
+    FileContent,
+    InternalCopilotRequestChatMessage,
+    LogContent,
+    ResponseCategory,
+    TrainingErrorLog,
+)
 from rasa.builder.exceptions import (
     LLMGenerationError,
     ProjectGenerationError,
@@ -11,6 +19,7 @@ from rasa.builder.exceptions import (
     ValidationError,
 )
 from rasa.builder.job_manager import JobInfo, job_manager
+from rasa.builder.llm_service import llm_service
 from rasa.builder.models import (
     JobStatus,
     JobStatusEvent,
@@ -28,9 +37,14 @@ structlogger = structlog.get_logger()
 
 
 async def push_job_status_event(
-    job: JobInfo, status: JobStatus, message: Optional[str] = None
+    job: JobInfo,
+    status: JobStatus,
+    message: Optional[str] = None,
+    payload: Optional[Dict[str, Any]] = None,
 ) -> None:
-    event = JobStatusEvent.from_status(status=status.value, message=message)
+    event = JobStatusEvent.from_status(
+        status=status.value, message=message, payload=payload
+    )
     job.status = status.value
     await job.put(event)
 
@@ -210,7 +224,7 @@ async def run_template_to_bot_job(
 async def run_replace_all_files_job(
     app: "Sanic",
     job: JobInfo,
-    bot_files: dict,
+    bot_files: Dict[str, Any],
 ) -> None:
     """Run the replace-all-files job in the background.
 
@@ -228,7 +242,7 @@ async def run_replace_all_files_job(
     try:
         project_generator.replace_all_bot_files(bot_files)
 
-        # 1. Validating
+        # Validating
         await push_job_status_event(job, JobStatus.validating)
         training_input = project_generator.get_training_input()
         validation_error = await validate_project(training_input.importer)
@@ -236,12 +250,13 @@ async def run_replace_all_files_job(
             raise ValidationError(validation_error)
         await push_job_status_event(job, JobStatus.validation_success)
 
-        # 2. Training
+        # Training
         await push_job_status_event(job, JobStatus.training)
         agent = await train_and_load_agent(training_input)
         update_agent(agent, app)
         await push_job_status_event(job, JobStatus.train_success)
 
+        # Send final done event
        await push_job_status_event(job, JobStatus.done)
         job_manager.mark_done(job)
 
@@ -257,26 +272,181 @@ async def run_replace_all_files_job(
             included_log_levels=log_levels,
         )
         error_message = exc.get_error_message_with_logs(log_levels=log_levels)
-        await push_job_status_event(
-            job, JobStatus.validation_error, message=error_message
+        # Push error event and start copilot analysis job
+        await push_error_and_start_copilot_analysis(
+            app,
+            job,
+            JobStatus.validation_error,
+            error_message,
+            bot_files,
         )
+
+        # After error mark job as done
         job_manager.mark_done(job, error=error_message)
 
     except TrainingError as exc:
+        error_message = str(exc)
         structlogger.debug(
             "replace_all_files_job.train_error",
             job_id=job.id,
-            error=str(exc),
+            error=error_message,
         )
-        await push_job_status_event(job, JobStatus.train_error, message=str(exc))
-        job_manager.mark_done(job, error=str(exc))
+        # Push error event and start copilot analysis job
+        await push_error_and_start_copilot_analysis(
+            app,
+            job,
+            JobStatus.train_error,
+            error_message,
+            bot_files,
+        )
+
+        # After error mark job as done
+        job_manager.mark_done(job, error=error_message)
 
     except Exception as exc:
         # Capture full traceback for anything truly unexpected
+        error_message = str(exc)
         structlogger.exception(
             "replace_all_files_job.unexpected_error",
             job_id=job.id,
+            error=error_message,
+        )
+
+        # Push error event and start copilot analysis job
+        await push_error_and_start_copilot_analysis(
+            app,
+            job,
+            JobStatus.error,
+            error_message,
+            bot_files,
+        )
+
+        # After error mark job as done
+        job_manager.mark_done(job, error=str(exc))
+
+
+async def push_error_and_start_copilot_analysis(
+    app: "Sanic",
+    original_job: JobInfo,
+    original_job_status: JobStatus,
+    error_message: str,
+    bot_files: Dict[str, Any],
+) -> None:
+    """Start a copilot analysis job and notify the client.
+
+    Creates a copilot analysis job and sends the new job ID to the client. The new
+    job runs in the background.
+
+    Args:
+        app: The Sanic application instance
+        original_job: The original job that failed
+        original_job_status: The status of the job that failed
+        error_message: The error message to analyze
+        bot_files: The bot files to include in analysis
+    """
+    # Create a copilot analysis job. Send the new job ID to the client and
+    # run the Copilot Analysis job in the background.
+    message = "Failed to train the assistant. Starting copilot analysis."
+
+    copilot_job = job_manager.create_job()
+    # Push the error status event for the original job
+    await push_job_status_event(
+        original_job,
+        original_job_status,
+        message=message,
+        payload={"copilot_job_id": copilot_job.id},
+    )
+    # Run the copilot analysis job in the background
+    app.add_task(
+        run_copilot_training_error_analysis_job(
+            app, copilot_job, error_message, bot_files
+        )
+    )
+    structlogger.debug(
+        f"update_files_job.{original_job_status.value}.copilot_analysis_start",
+        event_info=message,
+        job_id=original_job.id,
+        error=error_message,
+        copilot_job_id=copilot_job.id,
+    )
+
+
+async def run_copilot_training_error_analysis_job(
+    app: "Sanic",
+    job: JobInfo,
+    training_error_message: str,
+    bot_files: Dict[str, Any],
+) -> None:
+    """Run copilot training error analysis job."""
+    await push_job_status_event(job, JobStatus.received)
+
+    try:
+        # Create message content blocks with log content and available files
+        log_content_block = LogContent(
+            type="log", content=training_error_message, context="training_error"
+        )
+        file_content_blocks = [
+            FileContent(type="file", file_path=file_path, file_content=file_content)
+            for file_path, file_content in bot_files.items()
+        ]
+        context = CopilotContext(
+            tracker_context=None,  # No conversation context needed
+            copilot_chat_history=[
+                InternalCopilotRequestChatMessage(
+                    role="internal_copilot_request",
+                    content=[log_content_block, *file_content_blocks],
+                    response_category=ResponseCategory.TRAINING_ERROR_LOG_ANALYSIS,
+                )
+            ],
+        )
+
+        # Generate copilot response
+        copilot_client = llm_service.instantiate_copilot()
+        (
+            original_stream,
+            generation_context,
+        ) = await copilot_client.generate_response(context)
+
+        copilot_response_handler = llm_service.instantiate_handler(
+            config.COPILOT_HANDLER_ROLLING_BUFFER_SIZE
+        )
+        intercepted_stream = copilot_response_handler.handle_response(original_stream)
+
+        # Stream the copilot response as job events
+        async for token in intercepted_stream:
+            # Send each token as a job event using the same format as /copilot endpoint
+            await push_job_status_event(
+                job, JobStatus.copilot_analyzing, payload=token.sse_data
+            )
+
+        # Send references (if any) as part of copilot_analyzing stream
+        if generation_context.relevant_documents:
+            reference_section = copilot_response_handler.extract_references(
+                generation_context.relevant_documents
+            )
+            await push_job_status_event(
+                job, JobStatus.copilot_analyzing, payload=reference_section.sse_data
+            )
+
+        # Send original error log as part of copilot_analyzing stream
+        training_error_log = TrainingErrorLog(logs=[log_content_block])
+        await push_job_status_event(
+            job, JobStatus.copilot_analyzing, payload=training_error_log.sse_data
+        )
+
+        # Send success status
+        await push_job_status_event(job, JobStatus.copilot_analysis_success)
+
+        await push_job_status_event(job, JobStatus.done)
+        job_manager.mark_done(job)
+
+    except Exception as exc:
+        structlogger.exception(
+            "copilot_training_error_analysis_job.error",
+            job_id=job.id,
             error=str(exc),
         )
-        await push_job_status_event(job, JobStatus.error, message=str(exc))
+        await push_job_status_event(
+            job, JobStatus.copilot_analysis_error, message=str(exc)
+        )
         job_manager.mark_done(job, error=str(exc))
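
Taken together, these handlers change what a client observes when a replace-all-files job fails: the error event now carries a pointer to a second, automatically spawned copilot job. A minimal sketch of the resulting payloads follows; the job ID is made up (real IDs come from job_manager.create_job()).

    # Hypothetical error event emitted by the failed job under rc2.
    error_event = {
        "status": "train_error",
        "message": "Failed to train the assistant. Starting copilot analysis.",
        "copilot_job_id": "d4f9c0aa",  # illustrative ID of the spawned analysis job
    }

    # Subscribing to that copilot job then yields, in order:
    copilot_status_sequence = [
        "received",
        "copilot_analyzing",  # streamed tokens, references, original error log
        "copilot_analysis_success",
        "done",
    ]
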
rasa/builder/models.py CHANGED
@@ -136,12 +136,14 @@ class JobStatusEvent(ServerSentEvent):
         cls,
         status: str,
         message: Optional[str] = None,
+        payload: Optional[Dict[str, Any]] = None,
     ) -> "JobStatusEvent":
         """Factory for job-status events.
 
         Args:
             status: The job status (e.g. "training", "train_success").
             message: Optional error message for error events.
+            payload: Optional additional payload data to include in the event.
 
         Returns:
             A JobStatusEvent instance with the appropriate event type and data.
@@ -150,11 +152,13 @@ class JobStatusEvent(ServerSentEvent):
             ServerSentEventType.error if message else ServerSentEventType.progress
         )
 
-        payload: Dict[str, Any] = {"status": status}
+        event_payload: Dict[str, Any] = {"status": status}
         if message:
-            payload["message"] = message
+            event_payload["message"] = message
+        if payload:
+            event_payload.update(payload)
 
-        return cls(event=event_type.value, data=payload)
+        return cls(event=event_type.value, data=event_payload)
 
 
 class ValidationResult(BaseModel):
@@ -193,6 +197,11 @@ class JobStatus(str, Enum):
     validation_success = "validation_success"
     validation_error = "validation_error"
 
+    copilot_analysis_start = "copilot_analysis_start"
+    copilot_analyzing = "copilot_analyzing"
+    copilot_analysis_success = "copilot_analysis_success"
+    copilot_analysis_error = "copilot_analysis_error"
+
 
 class JobCreateResponse(BaseModel):
     job_id: str = Field(...)
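
The payload handling in from_status is a plain dict merge. A standalone sketch of the semantics, reimplemented outside the class (build_event_data is a hypothetical helper for illustration, not part of the package):

    from typing import Any, Dict, Optional

    def build_event_data(
        status: str,
        message: Optional[str] = None,
        payload: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        # "status" is always present; caller-supplied payload keys are merged last,
        # so extras such as "copilot_job_id" sit alongside status and message.
        event_payload: Dict[str, Any] = {"status": status}
        if message:
            event_payload["message"] = message
        if payload:
            event_payload.update(payload)
        return event_payload

    assert build_event_data(
        "train_error",
        message="Failed to train the assistant. Starting copilot analysis.",
        payload={"copilot_job_id": "abc123"},
    ) == {
        "status": "train_error",
        "message": "Failed to train the assistant. Starting copilot analysis.",
        "copilot_job_id": "abc123",
    }
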
rasa/builder/service.py CHANGED
@@ -532,7 +532,19 @@ async def get_bot_files(request: Request) -> HTTPResponse:
         "**Error Events (can occur at any time):**\n"
         "- `validation_error` - Bot configuration files are invalid\n"
         "- `train_error` - Files updated but training failed\n"
+        "- `copilot_analysis_start` - Copilot analysis started (includes `copilot_job_id` "
+        "in payload)\n"
         "- `error` - Unexpected error occurred\n\n"
+        "**Copilot Analysis Events:**\n"
+        "When training or validation fails, a separate copilot analysis job is "
+        "automatically started. The `copilot_analysis_start` event includes a "
+        "`copilot_job_id` in the payload.\nConnect to `/job-events/<copilot_job_id>` to "
+        "receive the following events:\n"
+        "- `copilot_analyzing` - Copilot is analyzing errors and providing suggestions. "
+        "Uses the same SSE event payload format as the `/copilot` endpoint with `content`, "
+        "`response_category`, and `completeness` fields.\n"
+        "- `copilot_analysis_success` - Copilot analysis completed with references.\n"
+        "- `copilot_analysis_error` - Copilot analysis failed\n\n"
         "**Usage:**\n"
         "1. Send POST request with Content-Type: application/json\n"
         "2. The response will be a JSON object `{job_id: ...}`\n"
@@ -1042,7 +1054,8 @@ async def copilot(request: Request) -> None:
     # Offload telemetry logging to a background task
     request.app.add_task(
         asyncio.to_thread(
-            telemetry.log_user_turn, req.last_message.get_text_content()
+            telemetry.log_user_turn,
+            req.last_message.get_flattened_text_content(),
         )
     )
 
@@ -1159,7 +1172,7 @@ async def copilot(request: Request) -> None:
             system_message=generation_context.system_message,
             chat_history=generation_context.chat_history,
             last_user_message=(
-                req.last_message.get_text_content()
+                req.last_message.get_flattened_text_content()
                 if (req.last_message and req.last_message.role == ROLE_USER)
                 else None
            ),
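
Based on the endpoint description above, a client can follow the handoff by reading the first job's SSE stream and reconnecting to the copilot job it names. A hypothetical client sketch, assuming the service runs locally and /job-events/<job_id> is a standard `data:`-prefixed SSE stream; BASE_URL and the parsing details are illustrative, not confirmed by this diff:

    import json
    import httpx

    BASE_URL = "http://localhost:5005"  # assumption: local builder service

    def follow_job(job_id: str) -> None:
        with httpx.stream("GET", f"{BASE_URL}/job-events/{job_id}") as response:
            for line in response.iter_lines():
                if not line.startswith("data:"):
                    continue
                data = json.loads(line[len("data:"):].strip())
                print(data.get("status"), data.get("message", ""))
                # A failed run's error event names the spawned analysis job;
                # follow it to receive the copilot_analyzing token stream.
                if "copilot_job_id" in data:
                    follow_job(data["copilot_job_id"])
                if data.get("status") == "done":
                    break
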
@@ -12,6 +12,7 @@ import structlog
 
 import rasa.shared.utils.common
 from rasa.core.brokers.broker import EventBroker
+from rasa.core.constants import KAFKA_SERVICE_NAME
 from rasa.core.exceptions import KafkaProducerInitializationError
 from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
     IAMCredentialsProviderInput,
@@ -99,7 +100,10 @@ class KafkaEventBroker(EventBroker):
         self.queue_size = kwargs.get("queue_size")
         self.ssl_check_hostname = "https" if ssl_check_hostname else None
         self.iam_credentials_provider = create_iam_credentials_provider(
-            IAMCredentialsProviderInput(service_name=SupportedServiceType.EVENT_BROKER)
+            IAMCredentialsProviderInput(
+                service_type=SupportedServiceType.EVENT_BROKER,
+                service_name=KAFKA_SERVICE_NAME,
+            )
         )
 
         # PII management attributes
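
This hunk wires the Kafka event broker to the new (service_type, service_name) lookup. Per the provider factory later in this diff, the practical effect is that the broker only receives an IAM provider when the MSK flag is set; a minimal sketch of that gate:

    import os

    def kafka_iam_provider_enabled() -> bool:
        # Mirrors is_iam_enabled() for the (EVENT_BROKER, "kafka") pair.
        return os.getenv("KAFKA_MSK_AWS_IAM_ENABLED", "false").lower() == "true"

    assert not kafka_iam_provider_enabled()  # unset -> factory returns None
    os.environ["KAFKA_MSK_AWS_IAM_ENABLED"] = "true"
    assert kafka_iam_provider_enabled()      # enabled -> provider is created
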
@@ -6,6 +6,9 @@ from typing import Deque, Optional, Text
 import structlog
 from pydantic import ValidationError
 
+from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
+    SupportedServiceType,
+)
 from rasa.core.lock import Ticket, TicketLock
 from rasa.core.lock_store import (
     DEFAULT_SOCKET_TIMEOUT_IN_SECONDS,
@@ -108,6 +111,7 @@ class ConcurrentRedisLockStore(LockStore):
         redis_config = RedisConfig(
             host=host,
             port=port,
+            service_type=SupportedServiceType.LOCK_STORE,
             db=db,
             username=username,
             password=password,
@@ -150,32 +154,45 @@ class ConcurrentRedisLockStore(LockStore):
             ),
         )
 
-    def _get_keys_by_pattern(self, pattern: Text) -> list:
-        """Get keys by pattern, using SCAN for cluster mode and KEYS for others."""
-        if self.deployment_mode == DeploymentMode.CLUSTER.value:
-            # In cluster mode, use SCAN to get keys more reliably
-            keys = []
-            cursor = 0
-
-            while True:
-                try:
-                    cursor, batch_keys = self.red.scan(cursor, match=pattern, count=100)
-                    keys.extend(batch_keys)
+    def _scan_cluster_keys(self, pattern: Text) -> list:
+        """Scan keys in cluster mode with proper cursor handling."""
+        keys = []
+        cursor = 0
+
+        while True:
+            try:
+                cursor, batch_keys = self.red.scan(cursor, match=pattern, count=100)
+                keys.extend(batch_keys)
+
+                if isinstance(cursor, dict):
+                    # cursor is a dict mapping each node to its scan position. e.g
+                    # {'127.0.0.1:7000': 0, '127.0.0.1:7001': 5, '127.0.0.1:7002': 0}
+                    # A cursor value of 0 means that node has finished scanning
+                    # When all nodes show 0, the entire cluster scan is complete
+                    if all(v == 0 for v in cursor.values()):
+                        break
+                else:
+                    # if scan is complete
                     if cursor == 0:
                         break
-                except Exception as e:
-                    structlogger.warning(
-                        "concurrent_redis_lock_store._get_keys_by_pattern.scan_interrupted",
-                        event_info=f"SCAN interrupted in cluster mode: {e}. "
-                        f"Returning {len(keys)} keys found so far.",
-                    )
-                    break
-        else:
-            # Standard and sentinel modes use KEYS
-            keys = self.red.keys(pattern)
+
+            except Exception as e:
+                structlogger.warning(
+                    "concurrent_redis_lock_store._get_keys_by_pattern.scan_interrupted",
+                    event_info=f"SCAN interrupted in cluster mode: {e}. "
+                    f"Returning {len(keys)} keys found so far.",
+                )
+                break
 
         return keys
 
+    def _get_keys_by_pattern(self, pattern: Text) -> list:
+        """Get keys by pattern, using SCAN for cluster mode and KEYS for others."""
+        if self.deployment_mode == DeploymentMode.CLUSTER.value:
+            return self._scan_cluster_keys(pattern)
+        else:
+            return self.red.keys(pattern)
+
     def issue_ticket(
         self, conversation_id: Text, lock_lifetime: float = LOCK_LIFETIME
     ) -> int:
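
The refactored scan handles the fact that redis-py's cluster client returns a per-node cursor dict rather than a single integer, as the comments in the hunk describe. A standalone sketch of just the termination rule (scan_complete is a hypothetical helper for illustration):

    from typing import Dict, Union

    def scan_complete(cursor: Union[int, Dict[str, int]]) -> bool:
        if isinstance(cursor, dict):
            # e.g. {'127.0.0.1:7000': 0, '127.0.0.1:7001': 5}: node :7001 not done
            return all(position == 0 for position in cursor.values())
        return cursor == 0  # single-node scan: 0 means the iteration is complete

    assert scan_complete(0)
    assert not scan_complete(42)
    assert not scan_complete({"127.0.0.1:7000": 0, "127.0.0.1:7001": 5})
    assert scan_complete({"127.0.0.1:7000": 0, "127.0.0.1:7001": 0})
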
rasa/core/constants.py CHANGED
@@ -123,3 +123,9 @@ SQL_TRACKER_STORE_SSL_ROOT_CERTIFICATE_ENV_VAR_NAME = (
     "SQL_TRACKER_STORE_SSL_ROOT_CERTIFICATE"
 )
 AWS_ELASTICACHE_CLUSTER_NAME_ENV_VAR_NAME = "AWS_ELASTICACHE_CLUSTER_NAME"
+RDS_SQL_DB_AWS_IAM_ENABLED_ENV_VAR_NAME = "RDS_SQL_DB_AWS_IAM_ENABLED"
+ELASTICACHE_REDIS_AWS_IAM_ENABLED_ENV_VAR_NAME = "ELASTICACHE_REDIS_AWS_IAM_ENABLED"
+KAFKA_MSK_AWS_IAM_ENABLED_ENV_VAR_NAME = "KAFKA_MSK_AWS_IAM_ENABLED"
+SQL_SERVICE_NAME = "sql"
+REDIS_SERVICE_NAME = "redis"
+KAFKA_SERVICE_NAME = "kafka"
@@ -1,7 +1,7 @@
 import os
 import threading
 import time
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple
 from urllib.parse import ParseResult, urlencode, urlunparse
 
 import boto3
@@ -14,6 +14,14 @@ from botocore.session import get_session
 from botocore.signers import RequestSigner
 from cachetools import TTLCache, cached
 
+from rasa.core.constants import (
+    ELASTICACHE_REDIS_AWS_IAM_ENABLED_ENV_VAR_NAME,
+    KAFKA_MSK_AWS_IAM_ENABLED_ENV_VAR_NAME,
+    KAFKA_SERVICE_NAME,
+    RDS_SQL_DB_AWS_IAM_ENABLED_ENV_VAR_NAME,
+    REDIS_SERVICE_NAME,
+    SQL_SERVICE_NAME,
+)
 from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
     IAMCredentialsProvider,
     IAMCredentialsProviderInput,
@@ -24,6 +32,25 @@ from rasa.shared.exceptions import ConnectionException
 
 structlogger = structlog.get_logger(__name__)
 
+SERVICE_CONFIG: Dict[Tuple[SupportedServiceType, str], str] = {
+    (
+        SupportedServiceType.TRACKER_STORE,
+        SQL_SERVICE_NAME,
+    ): RDS_SQL_DB_AWS_IAM_ENABLED_ENV_VAR_NAME,
+    (
+        SupportedServiceType.TRACKER_STORE,
+        REDIS_SERVICE_NAME,
+    ): ELASTICACHE_REDIS_AWS_IAM_ENABLED_ENV_VAR_NAME,
+    (
+        SupportedServiceType.EVENT_BROKER,
+        KAFKA_SERVICE_NAME,
+    ): KAFKA_MSK_AWS_IAM_ENABLED_ENV_VAR_NAME,
+    (
+        SupportedServiceType.LOCK_STORE,
+        REDIS_SERVICE_NAME,
+    ): ELASTICACHE_REDIS_AWS_IAM_ENABLED_ENV_VAR_NAME,
+}
+
 
 class AWSRDSIAMCredentialsProvider(IAMCredentialsProvider):
     """Generates temporary credentials for AWS RDS using IAM roles."""
@@ -203,21 +230,59 @@ class AWSElasticacheRedisIAMCredentialsProvider(redis.CredentialProvider):
         return TemporaryCredentials()
 
 
+def is_iam_enabled(provider_input: "IAMCredentialsProviderInput") -> bool:
+    """Checks if IAM authentication is enabled for the given service."""
+    service_type = provider_input.service_type
+    service_name = provider_input.service_name
+    iam_enabled_env_var_name = SERVICE_CONFIG.get((service_type, service_name))
+
+    if not iam_enabled_env_var_name:
+        structlogger.warning(
+            "rasa.core.aws_iam_credentials_providers.is_iam_enabled.unsupported_service",
+            event_info=f"IAM authentication check requested for unsupported service: "
+            f"{service_name}",
+        )
+        return False
+
+    return os.getenv(iam_enabled_env_var_name, "false").lower() == "true"
+
+
 def create_aws_iam_credentials_provider(
     provider_input: "IAMCredentialsProviderInput",
 ) -> Optional["IAMCredentialsProvider"]:
     """Factory function to create an AWS IAM credentials provider."""
-    if provider_input.service_name == SupportedServiceType.TRACKER_STORE:
+    iam_enabled = is_iam_enabled(provider_input)
+    if not iam_enabled:
+        structlogger.debug(
+            "rasa.core.aws_iam_credentials_providers.create_provider.iam_not_enabled",
+            event_info=f"IAM authentication not enabled for service: "
+            f"{provider_input.service_type}",
+        )
+        return None
+
+    if (
+        provider_input.service_type == SupportedServiceType.TRACKER_STORE
+        and provider_input.service_name == SQL_SERVICE_NAME
+    ):
         return AWSRDSIAMCredentialsProvider(
             username=provider_input.username,
             host=provider_input.host,
             port=provider_input.port,
         )
 
-    if provider_input.service_name == SupportedServiceType.EVENT_BROKER:
+    if (
+        provider_input.service_type == SupportedServiceType.TRACKER_STORE
+        and provider_input.service_name == REDIS_SERVICE_NAME
+    ):
+        return AWSElasticacheRedisIAMCredentialsProvider(
+            username=provider_input.username,
+            cluster_name=provider_input.cluster_name,
+        )
+
+    if provider_input.service_type == SupportedServiceType.EVENT_BROKER:
         return AWSMSKafkaIAMCredentialsProvider()
 
-    if provider_input.service_name == SupportedServiceType.LOCK_STORE:
+    if provider_input.service_type == SupportedServiceType.LOCK_STORE:
         return AWSElasticacheRedisIAMCredentialsProvider(
             username=provider_input.username,
             cluster_name=provider_input.cluster_name,
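
The key change in this module is that provider selection is now keyed on the (service_type, service_name) pair, so two tracker-store backends can resolve to different IAM flags. A sketch of the lookup with a simplified stand-in enum (not the package's own definitions):

    from enum import Enum
    from typing import Dict, Tuple

    class SupportedServiceType(Enum):  # simplified stand-in for the real enum
        TRACKER_STORE = "tracker_store"
        EVENT_BROKER = "event_broker"
        LOCK_STORE = "lock_store"

    SERVICE_CONFIG: Dict[Tuple[SupportedServiceType, str], str] = {
        (SupportedServiceType.TRACKER_STORE, "sql"): "RDS_SQL_DB_AWS_IAM_ENABLED",
        (SupportedServiceType.TRACKER_STORE, "redis"): "ELASTICACHE_REDIS_AWS_IAM_ENABLED",
        (SupportedServiceType.EVENT_BROKER, "kafka"): "KAFKA_MSK_AWS_IAM_ENABLED",
        (SupportedServiceType.LOCK_STORE, "redis"): "ELASTICACHE_REDIS_AWS_IAM_ENABLED",
    }

    # Same service type, different backend -> different env var.
    assert (
        SERVICE_CONFIG[(SupportedServiceType.TRACKER_STORE, "sql")]
        != SERVICE_CONFIG[(SupportedServiceType.TRACKER_STORE, "redis")]
    )
    # Unsupported pairs miss the table; is_iam_enabled() warns and returns False.
    assert SERVICE_CONFIG.get((SupportedServiceType.TRACKER_STORE, "mongo")) is None
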
@@ -47,7 +47,8 @@ class SupportedServiceType(Enum):
 class IAMCredentialsProviderInput(BaseModel):
     """Input data for creating an IAM credentials provider."""
 
-    service_name: SupportedServiceType
+    service_type: SupportedServiceType
+    service_name: str
     username: Optional[str] = None
     host: Optional[str] = None
     port: Optional[int] = None
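
This is the breaking piece the rest of the diff adapts to: the enum moves to a new service_type field, and service_name becomes a free-form backend string. A minimal sketch with simplified stand-in classes (not the package's own definitions):

    from enum import Enum
    from typing import Optional
    from pydantic import BaseModel

    class SupportedServiceType(Enum):  # stand-in
        EVENT_BROKER = "event_broker"

    class IAMCredentialsProviderInput(BaseModel):  # stand-in
        service_type: SupportedServiceType
        service_name: str
        username: Optional[str] = None

    # rc1 call sites passed the enum as service_name; rc2 call sites pass both:
    provider_input = IAMCredentialsProviderInput(
        service_type=SupportedServiceType.EVENT_BROKER,
        service_name="kafka",
    )
    assert provider_input.service_name == "kafka"
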
rasa/core/lock_store.py CHANGED
@@ -18,6 +18,9 @@ from pydantic import (
 
 import rasa.shared.utils.common
 from rasa.core.constants import DEFAULT_LOCK_LIFETIME, IAM_CLOUD_PROVIDER_ENV_VAR_NAME
+from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
+    SupportedServiceType,
+)
 from rasa.core.lock import TicketLock
 from rasa.core.redis_connection_factory import (
     DeploymentMode,
@@ -336,6 +339,7 @@ class RedisLockStore(LockStore):
         redis_config = RedisConfig(
             host=str(self.config.host),
             port=self.config.port,
+            service_type=SupportedServiceType.LOCK_STORE,
             db=self.config.db,
             username=self.config.username,
             password=self.config.password,
@@ -6,7 +6,10 @@ import redis
 import structlog
 from pydantic import BaseModel, ConfigDict
 
-from rasa.core.constants import AWS_ELASTICACHE_CLUSTER_NAME_ENV_VAR_NAME
+from rasa.core.constants import (
+    AWS_ELASTICACHE_CLUSTER_NAME_ENV_VAR_NAME,
+    REDIS_SERVICE_NAME,
+)
 from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
     IAMCredentialsProvider,
     IAMCredentialsProviderInput,
@@ -65,6 +68,7 @@ class RedisConfig(BaseModel):
 
     host: Text = "localhost"
     port: int = 6379
+    service_type: SupportedServiceType
     username: Optional[Text] = None
     password: Optional[Text] = None
     use_ssl: bool = False
@@ -117,7 +121,8 @@ class RedisConnectionFactory:
 
         iam_credentials_provider = create_iam_credentials_provider(
             IAMCredentialsProviderInput(
-                service_name=SupportedServiceType.LOCK_STORE,
+                service_type=config.service_type,
+                service_name=REDIS_SERVICE_NAME,
                 username=config.username,
                 cluster_name=os.getenv(AWS_ELASTICACHE_CLUSTER_NAME_ENV_VAR_NAME),
             )
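
service_type is declared without a default, which is why every RedisConfig construction site in this diff gained the argument. A sketch of the validation behavior with a simplified stand-in model:

    from enum import Enum
    from pydantic import BaseModel, ValidationError

    class SupportedServiceType(Enum):  # stand-in
        TRACKER_STORE = "tracker_store"
        LOCK_STORE = "lock_store"

    class RedisConfig(BaseModel):  # stand-in: only the fields relevant here
        host: str = "localhost"
        port: int = 6379
        service_type: SupportedServiceType  # required: no default value

    try:
        RedisConfig(host="redis.internal")
    except ValidationError as exc:
        print(exc.errors()[0]["loc"])  # ('service_type',)

    RedisConfig(host="redis.internal", service_type=SupportedServiceType.LOCK_STORE)
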
@@ -8,6 +8,9 @@ from pydantic import ValidationError
 
 import rasa.shared
 from rasa.core.brokers.broker import EventBroker
+from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
+    SupportedServiceType,
+)
 from rasa.core.redis_connection_factory import (
     DeploymentMode,
     RedisConfig,
@@ -54,6 +57,7 @@ class RedisTrackerStore(TrackerStore, SerializedTrackerAsText):
         config = RedisConfig(
             host=host,
             port=port,
+            service_type=SupportedServiceType.TRACKER_STORE,
             db=db,
             username=username,
             password=password,
@@ -27,6 +27,7 @@ from rasa.core.constants import (
     POSTGRESQL_MAX_OVERFLOW,
     POSTGRESQL_POOL_SIZE,
     POSTGRESQL_SCHEMA,
+    SQL_SERVICE_NAME,
     SQL_TRACKER_STORE_SSL_MODE_ENV_VAR_NAME,
     SQL_TRACKER_STORE_SSL_ROOT_CERTIFICATE_ENV_VAR_NAME,
 )
@@ -229,7 +230,8 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
 
         iam_credentials_provider = create_iam_credentials_provider(
             IAMCredentialsProviderInput(
-                service_name=SupportedServiceType.TRACKER_STORE,
+                service_type=SupportedServiceType.TRACKER_STORE,
+                service_name=SQL_SERVICE_NAME,
                 username=username,
                 host=host,
                 port=port,
rasa/version.py CHANGED
@@ -1,3 +1,3 @@
 # this file will automatically be changed,
 # do not add anything but the version number here!
-__version__ = "3.14.0rc1"
+__version__ = "3.14.0rc2"