rasa-pro 3.14.0rc1__py3-none-any.whl → 3.14.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/builder/copilot/constants.py +4 -1
- rasa/builder/copilot/copilot.py +155 -79
- rasa/builder/copilot/models.py +304 -108
- rasa/builder/copilot/prompts/copilot_training_error_handler_prompt.jinja2 +53 -0
- rasa/builder/guardrails/policy_checker.py +1 -1
- rasa/builder/jobs.py +182 -12
- rasa/builder/models.py +12 -3
- rasa/builder/service.py +15 -2
- rasa/cli/project_templates/finance/domain/general/help.yml +0 -0
- rasa/core/brokers/kafka.py +5 -1
- rasa/core/concurrent_lock_store.py +38 -21
- rasa/core/constants.py +6 -0
- rasa/core/iam_credentials_providers/aws_iam_credentials_providers.py +69 -4
- rasa/core/iam_credentials_providers/credentials_provider_protocol.py +2 -1
- rasa/core/lock_store.py +4 -0
- rasa/core/redis_connection_factory.py +7 -2
- rasa/core/tracker_stores/redis_tracker_store.py +4 -0
- rasa/core/tracker_stores/sql_tracker_store.py +3 -1
- rasa/version.py +1 -1
- {rasa_pro-3.14.0rc1.dist-info → rasa_pro-3.14.0rc2.dist-info}/METADATA +2 -1
- {rasa_pro-3.14.0rc1.dist-info → rasa_pro-3.14.0rc2.dist-info}/RECORD +24 -22
- {rasa_pro-3.14.0rc1.dist-info → rasa_pro-3.14.0rc2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.14.0rc1.dist-info → rasa_pro-3.14.0rc2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.14.0rc1.dist-info → rasa_pro-3.14.0rc2.dist-info}/entry_points.txt +0 -0
rasa/builder/jobs.py
CHANGED
|
@@ -1,9 +1,17 @@
|
|
|
1
|
-
from typing import Any, Optional
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
2
|
|
|
3
3
|
import structlog
|
|
4
4
|
from sanic import Sanic
|
|
5
5
|
|
|
6
6
|
from rasa.builder import config
|
|
7
|
+
from rasa.builder.copilot.models import (
|
|
8
|
+
CopilotContext,
|
|
9
|
+
FileContent,
|
|
10
|
+
InternalCopilotRequestChatMessage,
|
|
11
|
+
LogContent,
|
|
12
|
+
ResponseCategory,
|
|
13
|
+
TrainingErrorLog,
|
|
14
|
+
)
|
|
7
15
|
from rasa.builder.exceptions import (
|
|
8
16
|
LLMGenerationError,
|
|
9
17
|
ProjectGenerationError,
|
|
@@ -11,6 +19,7 @@ from rasa.builder.exceptions import (
|
|
|
11
19
|
ValidationError,
|
|
12
20
|
)
|
|
13
21
|
from rasa.builder.job_manager import JobInfo, job_manager
|
|
22
|
+
from rasa.builder.llm_service import llm_service
|
|
14
23
|
from rasa.builder.models import (
|
|
15
24
|
JobStatus,
|
|
16
25
|
JobStatusEvent,
|
|
@@ -28,9 +37,14 @@ structlogger = structlog.get_logger()
|
|
|
28
37
|
|
|
29
38
|
|
|
30
39
|
async def push_job_status_event(
|
|
31
|
-
job: JobInfo,
|
|
40
|
+
job: JobInfo,
|
|
41
|
+
status: JobStatus,
|
|
42
|
+
message: Optional[str] = None,
|
|
43
|
+
payload: Optional[Dict[str, Any]] = None,
|
|
32
44
|
) -> None:
|
|
33
|
-
event = JobStatusEvent.from_status(
|
|
45
|
+
event = JobStatusEvent.from_status(
|
|
46
|
+
status=status.value, message=message, payload=payload
|
|
47
|
+
)
|
|
34
48
|
job.status = status.value
|
|
35
49
|
await job.put(event)
|
|
36
50
|
|
|
@@ -210,7 +224,7 @@ async def run_template_to_bot_job(
|
|
|
210
224
|
async def run_replace_all_files_job(
|
|
211
225
|
app: "Sanic",
|
|
212
226
|
job: JobInfo,
|
|
213
|
-
bot_files:
|
|
227
|
+
bot_files: Dict[str, Any],
|
|
214
228
|
) -> None:
|
|
215
229
|
"""Run the replace-all-files job in the background.
|
|
216
230
|
|
|
@@ -228,7 +242,7 @@ async def run_replace_all_files_job(
|
|
|
228
242
|
try:
|
|
229
243
|
project_generator.replace_all_bot_files(bot_files)
|
|
230
244
|
|
|
231
|
-
#
|
|
245
|
+
# Validating
|
|
232
246
|
await push_job_status_event(job, JobStatus.validating)
|
|
233
247
|
training_input = project_generator.get_training_input()
|
|
234
248
|
validation_error = await validate_project(training_input.importer)
|
|
@@ -236,12 +250,13 @@ async def run_replace_all_files_job(
|
|
|
236
250
|
raise ValidationError(validation_error)
|
|
237
251
|
await push_job_status_event(job, JobStatus.validation_success)
|
|
238
252
|
|
|
239
|
-
#
|
|
253
|
+
# Training
|
|
240
254
|
await push_job_status_event(job, JobStatus.training)
|
|
241
255
|
agent = await train_and_load_agent(training_input)
|
|
242
256
|
update_agent(agent, app)
|
|
243
257
|
await push_job_status_event(job, JobStatus.train_success)
|
|
244
258
|
|
|
259
|
+
# Send final done event
|
|
245
260
|
await push_job_status_event(job, JobStatus.done)
|
|
246
261
|
job_manager.mark_done(job)
|
|
247
262
|
|
|
@@ -257,26 +272,181 @@ async def run_replace_all_files_job(
|
|
|
257
272
|
included_log_levels=log_levels,
|
|
258
273
|
)
|
|
259
274
|
error_message = exc.get_error_message_with_logs(log_levels=log_levels)
|
|
260
|
-
|
|
261
|
-
|
|
275
|
+
# Push error event and start copilot analysis job
|
|
276
|
+
await push_error_and_start_copilot_analysis(
|
|
277
|
+
app,
|
|
278
|
+
job,
|
|
279
|
+
JobStatus.validation_error,
|
|
280
|
+
error_message,
|
|
281
|
+
bot_files,
|
|
262
282
|
)
|
|
283
|
+
|
|
284
|
+
# After error mark job as done
|
|
263
285
|
job_manager.mark_done(job, error=error_message)
|
|
264
286
|
|
|
265
287
|
except TrainingError as exc:
|
|
288
|
+
error_message = str(exc)
|
|
266
289
|
structlogger.debug(
|
|
267
290
|
"replace_all_files_job.train_error",
|
|
268
291
|
job_id=job.id,
|
|
269
|
-
error=
|
|
292
|
+
error=error_message,
|
|
270
293
|
)
|
|
271
|
-
|
|
272
|
-
|
|
294
|
+
# Push error event and start copilot analysis job
|
|
295
|
+
await push_error_and_start_copilot_analysis(
|
|
296
|
+
app,
|
|
297
|
+
job,
|
|
298
|
+
JobStatus.train_error,
|
|
299
|
+
error_message,
|
|
300
|
+
bot_files,
|
|
301
|
+
)
|
|
302
|
+
|
|
303
|
+
# After error mark job as done
|
|
304
|
+
job_manager.mark_done(job, error=error_message)
|
|
273
305
|
|
|
274
306
|
except Exception as exc:
|
|
275
307
|
# Capture full traceback for anything truly unexpected
|
|
308
|
+
error_message = str(exc)
|
|
276
309
|
structlogger.exception(
|
|
277
310
|
"replace_all_files_job.unexpected_error",
|
|
278
311
|
job_id=job.id,
|
|
312
|
+
error=error_message,
|
|
313
|
+
)
|
|
314
|
+
|
|
315
|
+
# Push error event and start copilot analysis job
|
|
316
|
+
await push_error_and_start_copilot_analysis(
|
|
317
|
+
app,
|
|
318
|
+
job,
|
|
319
|
+
JobStatus.error,
|
|
320
|
+
error_message,
|
|
321
|
+
bot_files,
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
# After error mark job as done
|
|
325
|
+
job_manager.mark_done(job, error=str(exc))
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
async def push_error_and_start_copilot_analysis(
|
|
329
|
+
app: "Sanic",
|
|
330
|
+
original_job: JobInfo,
|
|
331
|
+
original_job_status: JobStatus,
|
|
332
|
+
error_message: str,
|
|
333
|
+
bot_files: Dict[str, Any],
|
|
334
|
+
) -> None:
|
|
335
|
+
"""Start a copilot analysis job and notify the client.
|
|
336
|
+
|
|
337
|
+
Creates a copilot analysis job and sends the new job ID to the client. The new
|
|
338
|
+
job runs in the background.
|
|
339
|
+
|
|
340
|
+
Args:
|
|
341
|
+
app: The Sanic application instance
|
|
342
|
+
original_job: The original job that failed
|
|
343
|
+
original_job_status: The status of the job that failed
|
|
344
|
+
error_message: The error message to analyze
|
|
345
|
+
bot_files: The bot files to include in analysis
|
|
346
|
+
"""
|
|
347
|
+
# Create a copilot analysis job. Send the new job ID to the client and
|
|
348
|
+
# run the Copilot Analysis job in the background.
|
|
349
|
+
message = "Failed to train the assistant. Starting copilot analysis."
|
|
350
|
+
|
|
351
|
+
copilot_job = job_manager.create_job()
|
|
352
|
+
# Push the error status event for the original job
|
|
353
|
+
await push_job_status_event(
|
|
354
|
+
original_job,
|
|
355
|
+
original_job_status,
|
|
356
|
+
message=message,
|
|
357
|
+
payload={"copilot_job_id": copilot_job.id},
|
|
358
|
+
)
|
|
359
|
+
# Run the copilot analysis job in the background
|
|
360
|
+
app.add_task(
|
|
361
|
+
run_copilot_training_error_analysis_job(
|
|
362
|
+
app, copilot_job, error_message, bot_files
|
|
363
|
+
)
|
|
364
|
+
)
|
|
365
|
+
structlogger.debug(
|
|
366
|
+
f"update_files_job.{original_job_status.value}.copilot_analysis_start",
|
|
367
|
+
event_info=message,
|
|
368
|
+
job_id=original_job.id,
|
|
369
|
+
error=error_message,
|
|
370
|
+
copilot_job_id=copilot_job.id,
|
|
371
|
+
)
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
async def run_copilot_training_error_analysis_job(
|
|
375
|
+
app: "Sanic",
|
|
376
|
+
job: JobInfo,
|
|
377
|
+
training_error_message: str,
|
|
378
|
+
bot_files: Dict[str, Any],
|
|
379
|
+
) -> None:
|
|
380
|
+
"""Run copilot training error analysis job."""
|
|
381
|
+
await push_job_status_event(job, JobStatus.received)
|
|
382
|
+
|
|
383
|
+
try:
|
|
384
|
+
# Create message content blocks with log content and available files
|
|
385
|
+
log_content_block = LogContent(
|
|
386
|
+
type="log", content=training_error_message, context="training_error"
|
|
387
|
+
)
|
|
388
|
+
file_content_blocks = [
|
|
389
|
+
FileContent(type="file", file_path=file_path, file_content=file_content)
|
|
390
|
+
for file_path, file_content in bot_files.items()
|
|
391
|
+
]
|
|
392
|
+
context = CopilotContext(
|
|
393
|
+
tracker_context=None, # No conversation context needed
|
|
394
|
+
copilot_chat_history=[
|
|
395
|
+
InternalCopilotRequestChatMessage(
|
|
396
|
+
role="internal_copilot_request",
|
|
397
|
+
content=[log_content_block, *file_content_blocks],
|
|
398
|
+
response_category=ResponseCategory.TRAINING_ERROR_LOG_ANALYSIS,
|
|
399
|
+
)
|
|
400
|
+
],
|
|
401
|
+
)
|
|
402
|
+
|
|
403
|
+
# Generate copilot response
|
|
404
|
+
copilot_client = llm_service.instantiate_copilot()
|
|
405
|
+
(
|
|
406
|
+
original_stream,
|
|
407
|
+
generation_context,
|
|
408
|
+
) = await copilot_client.generate_response(context)
|
|
409
|
+
|
|
410
|
+
copilot_response_handler = llm_service.instantiate_handler(
|
|
411
|
+
config.COPILOT_HANDLER_ROLLING_BUFFER_SIZE
|
|
412
|
+
)
|
|
413
|
+
intercepted_stream = copilot_response_handler.handle_response(original_stream)
|
|
414
|
+
|
|
415
|
+
# Stream the copilot response as job events
|
|
416
|
+
async for token in intercepted_stream:
|
|
417
|
+
# Send each token as a job event using the same format as /copilot endpoint
|
|
418
|
+
await push_job_status_event(
|
|
419
|
+
job, JobStatus.copilot_analyzing, payload=token.sse_data
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
# Send references (if any) as part of copilot_analyzing stream
|
|
423
|
+
if generation_context.relevant_documents:
|
|
424
|
+
reference_section = copilot_response_handler.extract_references(
|
|
425
|
+
generation_context.relevant_documents
|
|
426
|
+
)
|
|
427
|
+
await push_job_status_event(
|
|
428
|
+
job, JobStatus.copilot_analyzing, payload=reference_section.sse_data
|
|
429
|
+
)
|
|
430
|
+
|
|
431
|
+
# Send original error log as part of copilot_analyzing stream
|
|
432
|
+
training_error_log = TrainingErrorLog(logs=[log_content_block])
|
|
433
|
+
await push_job_status_event(
|
|
434
|
+
job, JobStatus.copilot_analyzing, payload=training_error_log.sse_data
|
|
435
|
+
)
|
|
436
|
+
|
|
437
|
+
# Send success status
|
|
438
|
+
await push_job_status_event(job, JobStatus.copilot_analysis_success)
|
|
439
|
+
|
|
440
|
+
await push_job_status_event(job, JobStatus.done)
|
|
441
|
+
job_manager.mark_done(job)
|
|
442
|
+
|
|
443
|
+
except Exception as exc:
|
|
444
|
+
structlogger.exception(
|
|
445
|
+
"copilot_training_error_analysis_job.error",
|
|
446
|
+
job_id=job.id,
|
|
279
447
|
error=str(exc),
|
|
280
448
|
)
|
|
281
|
-
await push_job_status_event(
|
|
449
|
+
await push_job_status_event(
|
|
450
|
+
job, JobStatus.copilot_analysis_error, message=str(exc)
|
|
451
|
+
)
|
|
282
452
|
job_manager.mark_done(job, error=str(exc))
|
rasa/builder/models.py
CHANGED
|
@@ -136,12 +136,14 @@ class JobStatusEvent(ServerSentEvent):
|
|
|
136
136
|
cls,
|
|
137
137
|
status: str,
|
|
138
138
|
message: Optional[str] = None,
|
|
139
|
+
payload: Optional[Dict[str, Any]] = None,
|
|
139
140
|
) -> "JobStatusEvent":
|
|
140
141
|
"""Factory for job-status events.
|
|
141
142
|
|
|
142
143
|
Args:
|
|
143
144
|
status: The job status (e.g. "training", "train_success").
|
|
144
145
|
message: Optional error message for error events.
|
|
146
|
+
payload: Optional additional payload data to include in the event.
|
|
145
147
|
|
|
146
148
|
Returns:
|
|
147
149
|
A JobStatusEvent instance with the appropriate event type and data.
|
|
@@ -150,11 +152,13 @@ class JobStatusEvent(ServerSentEvent):
|
|
|
150
152
|
ServerSentEventType.error if message else ServerSentEventType.progress
|
|
151
153
|
)
|
|
152
154
|
|
|
153
|
-
|
|
155
|
+
event_payload: Dict[str, Any] = {"status": status}
|
|
154
156
|
if message:
|
|
155
|
-
|
|
157
|
+
event_payload["message"] = message
|
|
158
|
+
if payload:
|
|
159
|
+
event_payload.update(payload)
|
|
156
160
|
|
|
157
|
-
return cls(event=event_type.value, data=
|
|
161
|
+
return cls(event=event_type.value, data=event_payload)
|
|
158
162
|
|
|
159
163
|
|
|
160
164
|
class ValidationResult(BaseModel):
|
|
@@ -193,6 +197,11 @@ class JobStatus(str, Enum):
|
|
|
193
197
|
validation_success = "validation_success"
|
|
194
198
|
validation_error = "validation_error"
|
|
195
199
|
|
|
200
|
+
copilot_analysis_start = "copilot_analysis_start"
|
|
201
|
+
copilot_analyzing = "copilot_analyzing"
|
|
202
|
+
copilot_analysis_success = "copilot_analysis_success"
|
|
203
|
+
copilot_analysis_error = "copilot_analysis_error"
|
|
204
|
+
|
|
196
205
|
|
|
197
206
|
class JobCreateResponse(BaseModel):
|
|
198
207
|
job_id: str = Field(...)
|
rasa/builder/service.py
CHANGED
|
@@ -532,7 +532,19 @@ async def get_bot_files(request: Request) -> HTTPResponse:
|
|
|
532
532
|
"**Error Events (can occur at any time):**\n"
|
|
533
533
|
"- `validation_error` - Bot configuration files are invalid\n"
|
|
534
534
|
"- `train_error` - Files updated but training failed\n"
|
|
535
|
+
"- `copilot_analysis_start` - Copilot analysis started (includes `copilot_job_id` "
|
|
536
|
+
"in payload)\n"
|
|
535
537
|
"- `error` - Unexpected error occurred\n\n"
|
|
538
|
+
"**Copilot Analysis Events:**\n"
|
|
539
|
+
"When training or validation fails, a separate copilot analysis job is "
|
|
540
|
+
"automatically started. The `copilot_analysis_start` event includes a "
|
|
541
|
+
"`copilot_job_id` in the payload.\nConnect to `/job-events/<copilot_job_id>` to "
|
|
542
|
+
"receive the following events:\n"
|
|
543
|
+
"- `copilot_analyzing` - Copilot is analyzing errors and providing suggestions. "
|
|
544
|
+
"Uses the same SSE event payload format as the `/copilot` endpoint with `content`, "
|
|
545
|
+
"`response_category`, and `completeness` fields.\n"
|
|
546
|
+
"- `copilot_analysis_success` - Copilot analysis completed with references.\n"
|
|
547
|
+
"- `copilot_analysis_error` - Copilot analysis failed\n\n"
|
|
536
548
|
"**Usage:**\n"
|
|
537
549
|
"1. Send POST request with Content-Type: application/json\n"
|
|
538
550
|
"2. The response will be a JSON object `{job_id: ...}`\n"
|
|
@@ -1042,7 +1054,8 @@ async def copilot(request: Request) -> None:
|
|
|
1042
1054
|
# Offload telemetry logging to a background task
|
|
1043
1055
|
request.app.add_task(
|
|
1044
1056
|
asyncio.to_thread(
|
|
1045
|
-
telemetry.log_user_turn,
|
|
1057
|
+
telemetry.log_user_turn,
|
|
1058
|
+
req.last_message.get_flattened_text_content(),
|
|
1046
1059
|
)
|
|
1047
1060
|
)
|
|
1048
1061
|
|
|
@@ -1159,7 +1172,7 @@ async def copilot(request: Request) -> None:
|
|
|
1159
1172
|
system_message=generation_context.system_message,
|
|
1160
1173
|
chat_history=generation_context.chat_history,
|
|
1161
1174
|
last_user_message=(
|
|
1162
|
-
req.last_message.
|
|
1175
|
+
req.last_message.get_flattened_text_content()
|
|
1163
1176
|
if (req.last_message and req.last_message.role == ROLE_USER)
|
|
1164
1177
|
else None
|
|
1165
1178
|
),
|
|
File without changes
|
rasa/core/brokers/kafka.py
CHANGED
|
@@ -12,6 +12,7 @@ import structlog
|
|
|
12
12
|
|
|
13
13
|
import rasa.shared.utils.common
|
|
14
14
|
from rasa.core.brokers.broker import EventBroker
|
|
15
|
+
from rasa.core.constants import KAFKA_SERVICE_NAME
|
|
15
16
|
from rasa.core.exceptions import KafkaProducerInitializationError
|
|
16
17
|
from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
|
|
17
18
|
IAMCredentialsProviderInput,
|
|
@@ -99,7 +100,10 @@ class KafkaEventBroker(EventBroker):
|
|
|
99
100
|
self.queue_size = kwargs.get("queue_size")
|
|
100
101
|
self.ssl_check_hostname = "https" if ssl_check_hostname else None
|
|
101
102
|
self.iam_credentials_provider = create_iam_credentials_provider(
|
|
102
|
-
IAMCredentialsProviderInput(
|
|
103
|
+
IAMCredentialsProviderInput(
|
|
104
|
+
service_type=SupportedServiceType.EVENT_BROKER,
|
|
105
|
+
service_name=KAFKA_SERVICE_NAME,
|
|
106
|
+
)
|
|
103
107
|
)
|
|
104
108
|
|
|
105
109
|
# PII management attributes
|
|
@@ -6,6 +6,9 @@ from typing import Deque, Optional, Text
|
|
|
6
6
|
import structlog
|
|
7
7
|
from pydantic import ValidationError
|
|
8
8
|
|
|
9
|
+
from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
|
|
10
|
+
SupportedServiceType,
|
|
11
|
+
)
|
|
9
12
|
from rasa.core.lock import Ticket, TicketLock
|
|
10
13
|
from rasa.core.lock_store import (
|
|
11
14
|
DEFAULT_SOCKET_TIMEOUT_IN_SECONDS,
|
|
@@ -108,6 +111,7 @@ class ConcurrentRedisLockStore(LockStore):
|
|
|
108
111
|
redis_config = RedisConfig(
|
|
109
112
|
host=host,
|
|
110
113
|
port=port,
|
|
114
|
+
service_type=SupportedServiceType.LOCK_STORE,
|
|
111
115
|
db=db,
|
|
112
116
|
username=username,
|
|
113
117
|
password=password,
|
|
@@ -150,32 +154,45 @@ class ConcurrentRedisLockStore(LockStore):
|
|
|
150
154
|
),
|
|
151
155
|
)
|
|
152
156
|
|
|
153
|
-
def
|
|
154
|
-
"""
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
157
|
+
def _scan_cluster_keys(self, pattern: Text) -> list:
|
|
158
|
+
"""Scan keys in cluster mode with proper cursor handling."""
|
|
159
|
+
keys = []
|
|
160
|
+
cursor = 0
|
|
161
|
+
|
|
162
|
+
while True:
|
|
163
|
+
try:
|
|
164
|
+
cursor, batch_keys = self.red.scan(cursor, match=pattern, count=100)
|
|
165
|
+
keys.extend(batch_keys)
|
|
166
|
+
|
|
167
|
+
if isinstance(cursor, dict):
|
|
168
|
+
# cursor is a dict mapping each node to its scan position. e.g
|
|
169
|
+
# {'127.0.0.1:7000': 0, '127.0.0.1:7001': 5, '127.0.0.1:7002': 0}
|
|
170
|
+
# A cursor value of 0 means that node has finished scanning
|
|
171
|
+
# When all nodes show 0, the entire cluster scan is complete
|
|
172
|
+
if all(v == 0 for v in cursor.values()):
|
|
173
|
+
break
|
|
174
|
+
else:
|
|
175
|
+
# if scan is complete
|
|
164
176
|
if cursor == 0:
|
|
165
177
|
break
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
)
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
# Standard and sentinel modes use KEYS
|
|
175
|
-
keys = self.red.keys(pattern)
|
|
178
|
+
|
|
179
|
+
except Exception as e:
|
|
180
|
+
structlogger.warning(
|
|
181
|
+
"concurrent_redis_lock_store._get_keys_by_pattern.scan_interrupted",
|
|
182
|
+
event_info=f"SCAN interrupted in cluster mode: {e}. "
|
|
183
|
+
f"Returning {len(keys)} keys found so far.",
|
|
184
|
+
)
|
|
185
|
+
break
|
|
176
186
|
|
|
177
187
|
return keys
|
|
178
188
|
|
|
189
|
+
def _get_keys_by_pattern(self, pattern: Text) -> list:
|
|
190
|
+
"""Get keys by pattern, using SCAN for cluster mode and KEYS for others."""
|
|
191
|
+
if self.deployment_mode == DeploymentMode.CLUSTER.value:
|
|
192
|
+
return self._scan_cluster_keys(pattern)
|
|
193
|
+
else:
|
|
194
|
+
return self.red.keys(pattern)
|
|
195
|
+
|
|
179
196
|
def issue_ticket(
|
|
180
197
|
self, conversation_id: Text, lock_lifetime: float = LOCK_LIFETIME
|
|
181
198
|
) -> int:
|
rasa/core/constants.py
CHANGED
|
@@ -123,3 +123,9 @@ SQL_TRACKER_STORE_SSL_ROOT_CERTIFICATE_ENV_VAR_NAME = (
|
|
|
123
123
|
"SQL_TRACKER_STORE_SSL_ROOT_CERTIFICATE"
|
|
124
124
|
)
|
|
125
125
|
AWS_ELASTICACHE_CLUSTER_NAME_ENV_VAR_NAME = "AWS_ELASTICACHE_CLUSTER_NAME"
|
|
126
|
+
RDS_SQL_DB_AWS_IAM_ENABLED_ENV_VAR_NAME = "RDS_SQL_DB_AWS_IAM_ENABLED"
|
|
127
|
+
ELASTICACHE_REDIS_AWS_IAM_ENABLED_ENV_VAR_NAME = "ELASTICACHE_REDIS_AWS_IAM_ENABLED"
|
|
128
|
+
KAFKA_MSK_AWS_IAM_ENABLED_ENV_VAR_NAME = "KAFKA_MSK_AWS_IAM_ENABLED"
|
|
129
|
+
SQL_SERVICE_NAME = "sql"
|
|
130
|
+
REDIS_SERVICE_NAME = "redis"
|
|
131
|
+
KAFKA_SERVICE_NAME = "kafka"
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import threading
|
|
3
3
|
import time
|
|
4
|
-
from typing import Optional, Tuple
|
|
4
|
+
from typing import Dict, Optional, Tuple
|
|
5
5
|
from urllib.parse import ParseResult, urlencode, urlunparse
|
|
6
6
|
|
|
7
7
|
import boto3
|
|
@@ -14,6 +14,14 @@ from botocore.session import get_session
|
|
|
14
14
|
from botocore.signers import RequestSigner
|
|
15
15
|
from cachetools import TTLCache, cached
|
|
16
16
|
|
|
17
|
+
from rasa.core.constants import (
|
|
18
|
+
ELASTICACHE_REDIS_AWS_IAM_ENABLED_ENV_VAR_NAME,
|
|
19
|
+
KAFKA_MSK_AWS_IAM_ENABLED_ENV_VAR_NAME,
|
|
20
|
+
KAFKA_SERVICE_NAME,
|
|
21
|
+
RDS_SQL_DB_AWS_IAM_ENABLED_ENV_VAR_NAME,
|
|
22
|
+
REDIS_SERVICE_NAME,
|
|
23
|
+
SQL_SERVICE_NAME,
|
|
24
|
+
)
|
|
17
25
|
from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
|
|
18
26
|
IAMCredentialsProvider,
|
|
19
27
|
IAMCredentialsProviderInput,
|
|
@@ -24,6 +32,25 @@ from rasa.shared.exceptions import ConnectionException
|
|
|
24
32
|
|
|
25
33
|
structlogger = structlog.get_logger(__name__)
|
|
26
34
|
|
|
35
|
+
SERVICE_CONFIG: Dict[Tuple[SupportedServiceType, str], str] = {
|
|
36
|
+
(
|
|
37
|
+
SupportedServiceType.TRACKER_STORE,
|
|
38
|
+
SQL_SERVICE_NAME,
|
|
39
|
+
): RDS_SQL_DB_AWS_IAM_ENABLED_ENV_VAR_NAME,
|
|
40
|
+
(
|
|
41
|
+
SupportedServiceType.TRACKER_STORE,
|
|
42
|
+
REDIS_SERVICE_NAME,
|
|
43
|
+
): ELASTICACHE_REDIS_AWS_IAM_ENABLED_ENV_VAR_NAME,
|
|
44
|
+
(
|
|
45
|
+
SupportedServiceType.EVENT_BROKER,
|
|
46
|
+
KAFKA_SERVICE_NAME,
|
|
47
|
+
): KAFKA_MSK_AWS_IAM_ENABLED_ENV_VAR_NAME,
|
|
48
|
+
(
|
|
49
|
+
SupportedServiceType.LOCK_STORE,
|
|
50
|
+
REDIS_SERVICE_NAME,
|
|
51
|
+
): ELASTICACHE_REDIS_AWS_IAM_ENABLED_ENV_VAR_NAME,
|
|
52
|
+
}
|
|
53
|
+
|
|
27
54
|
|
|
28
55
|
class AWSRDSIAMCredentialsProvider(IAMCredentialsProvider):
|
|
29
56
|
"""Generates temporary credentials for AWS RDS using IAM roles."""
|
|
@@ -203,21 +230,59 @@ class AWSElasticacheRedisIAMCredentialsProvider(redis.CredentialProvider):
|
|
|
203
230
|
return TemporaryCredentials()
|
|
204
231
|
|
|
205
232
|
|
|
233
|
+
def is_iam_enabled(provider_input: "IAMCredentialsProviderInput") -> bool:
|
|
234
|
+
"""Checks if IAM authentication is enabled for the given service."""
|
|
235
|
+
service_type = provider_input.service_type
|
|
236
|
+
service_name = provider_input.service_name
|
|
237
|
+
iam_enabled_env_var_name = SERVICE_CONFIG.get((service_type, service_name))
|
|
238
|
+
|
|
239
|
+
if not iam_enabled_env_var_name:
|
|
240
|
+
structlogger.warning(
|
|
241
|
+
"rasa.core.aws_iam_credentials_providers.is_iam_enabled.unsupported_service",
|
|
242
|
+
event_info=f"IAM authentication check requested for unsupported service: "
|
|
243
|
+
f"{service_name}",
|
|
244
|
+
)
|
|
245
|
+
return False
|
|
246
|
+
|
|
247
|
+
return os.getenv(iam_enabled_env_var_name, "false").lower() == "true"
|
|
248
|
+
|
|
249
|
+
|
|
206
250
|
def create_aws_iam_credentials_provider(
|
|
207
251
|
provider_input: "IAMCredentialsProviderInput",
|
|
208
252
|
) -> Optional["IAMCredentialsProvider"]:
|
|
209
253
|
"""Factory function to create an AWS IAM credentials provider."""
|
|
210
|
-
|
|
254
|
+
iam_enabled = is_iam_enabled(provider_input)
|
|
255
|
+
if not iam_enabled:
|
|
256
|
+
structlogger.debug(
|
|
257
|
+
"rasa.core.aws_iam_credentials_providers.create_provider.iam_not_enabled",
|
|
258
|
+
event_info=f"IAM authentication not enabled for service: "
|
|
259
|
+
f"{provider_input.service_type}",
|
|
260
|
+
)
|
|
261
|
+
return None
|
|
262
|
+
|
|
263
|
+
if (
|
|
264
|
+
provider_input.service_type == SupportedServiceType.TRACKER_STORE
|
|
265
|
+
and provider_input.service_name == SQL_SERVICE_NAME
|
|
266
|
+
):
|
|
211
267
|
return AWSRDSIAMCredentialsProvider(
|
|
212
268
|
username=provider_input.username,
|
|
213
269
|
host=provider_input.host,
|
|
214
270
|
port=provider_input.port,
|
|
215
271
|
)
|
|
216
272
|
|
|
217
|
-
if
|
|
273
|
+
if (
|
|
274
|
+
provider_input.service_type == SupportedServiceType.TRACKER_STORE
|
|
275
|
+
and provider_input.service_name == REDIS_SERVICE_NAME
|
|
276
|
+
):
|
|
277
|
+
return AWSElasticacheRedisIAMCredentialsProvider(
|
|
278
|
+
username=provider_input.username,
|
|
279
|
+
cluster_name=provider_input.cluster_name,
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
if provider_input.service_type == SupportedServiceType.EVENT_BROKER:
|
|
218
283
|
return AWSMSKafkaIAMCredentialsProvider()
|
|
219
284
|
|
|
220
|
-
if provider_input.
|
|
285
|
+
if provider_input.service_type == SupportedServiceType.LOCK_STORE:
|
|
221
286
|
return AWSElasticacheRedisIAMCredentialsProvider(
|
|
222
287
|
username=provider_input.username,
|
|
223
288
|
cluster_name=provider_input.cluster_name,
|
|
@@ -47,7 +47,8 @@ class SupportedServiceType(Enum):
|
|
|
47
47
|
class IAMCredentialsProviderInput(BaseModel):
|
|
48
48
|
"""Input data for creating an IAM credentials provider."""
|
|
49
49
|
|
|
50
|
-
|
|
50
|
+
service_type: SupportedServiceType
|
|
51
|
+
service_name: str
|
|
51
52
|
username: Optional[str] = None
|
|
52
53
|
host: Optional[str] = None
|
|
53
54
|
port: Optional[int] = None
|
rasa/core/lock_store.py
CHANGED
|
@@ -18,6 +18,9 @@ from pydantic import (
|
|
|
18
18
|
|
|
19
19
|
import rasa.shared.utils.common
|
|
20
20
|
from rasa.core.constants import DEFAULT_LOCK_LIFETIME, IAM_CLOUD_PROVIDER_ENV_VAR_NAME
|
|
21
|
+
from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
|
|
22
|
+
SupportedServiceType,
|
|
23
|
+
)
|
|
21
24
|
from rasa.core.lock import TicketLock
|
|
22
25
|
from rasa.core.redis_connection_factory import (
|
|
23
26
|
DeploymentMode,
|
|
@@ -336,6 +339,7 @@ class RedisLockStore(LockStore):
|
|
|
336
339
|
redis_config = RedisConfig(
|
|
337
340
|
host=str(self.config.host),
|
|
338
341
|
port=self.config.port,
|
|
342
|
+
service_type=SupportedServiceType.LOCK_STORE,
|
|
339
343
|
db=self.config.db,
|
|
340
344
|
username=self.config.username,
|
|
341
345
|
password=self.config.password,
|
|
@@ -6,7 +6,10 @@ import redis
|
|
|
6
6
|
import structlog
|
|
7
7
|
from pydantic import BaseModel, ConfigDict
|
|
8
8
|
|
|
9
|
-
from rasa.core.constants import
|
|
9
|
+
from rasa.core.constants import (
|
|
10
|
+
AWS_ELASTICACHE_CLUSTER_NAME_ENV_VAR_NAME,
|
|
11
|
+
REDIS_SERVICE_NAME,
|
|
12
|
+
)
|
|
10
13
|
from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
|
|
11
14
|
IAMCredentialsProvider,
|
|
12
15
|
IAMCredentialsProviderInput,
|
|
@@ -65,6 +68,7 @@ class RedisConfig(BaseModel):
|
|
|
65
68
|
|
|
66
69
|
host: Text = "localhost"
|
|
67
70
|
port: int = 6379
|
|
71
|
+
service_type: SupportedServiceType
|
|
68
72
|
username: Optional[Text] = None
|
|
69
73
|
password: Optional[Text] = None
|
|
70
74
|
use_ssl: bool = False
|
|
@@ -117,7 +121,8 @@ class RedisConnectionFactory:
|
|
|
117
121
|
|
|
118
122
|
iam_credentials_provider = create_iam_credentials_provider(
|
|
119
123
|
IAMCredentialsProviderInput(
|
|
120
|
-
|
|
124
|
+
service_type=config.service_type,
|
|
125
|
+
service_name=REDIS_SERVICE_NAME,
|
|
121
126
|
username=config.username,
|
|
122
127
|
cluster_name=os.getenv(AWS_ELASTICACHE_CLUSTER_NAME_ENV_VAR_NAME),
|
|
123
128
|
)
|
|
@@ -8,6 +8,9 @@ from pydantic import ValidationError
|
|
|
8
8
|
|
|
9
9
|
import rasa.shared
|
|
10
10
|
from rasa.core.brokers.broker import EventBroker
|
|
11
|
+
from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
|
|
12
|
+
SupportedServiceType,
|
|
13
|
+
)
|
|
11
14
|
from rasa.core.redis_connection_factory import (
|
|
12
15
|
DeploymentMode,
|
|
13
16
|
RedisConfig,
|
|
@@ -54,6 +57,7 @@ class RedisTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
54
57
|
config = RedisConfig(
|
|
55
58
|
host=host,
|
|
56
59
|
port=port,
|
|
60
|
+
service_type=SupportedServiceType.TRACKER_STORE,
|
|
57
61
|
db=db,
|
|
58
62
|
username=username,
|
|
59
63
|
password=password,
|
|
@@ -27,6 +27,7 @@ from rasa.core.constants import (
|
|
|
27
27
|
POSTGRESQL_MAX_OVERFLOW,
|
|
28
28
|
POSTGRESQL_POOL_SIZE,
|
|
29
29
|
POSTGRESQL_SCHEMA,
|
|
30
|
+
SQL_SERVICE_NAME,
|
|
30
31
|
SQL_TRACKER_STORE_SSL_MODE_ENV_VAR_NAME,
|
|
31
32
|
SQL_TRACKER_STORE_SSL_ROOT_CERTIFICATE_ENV_VAR_NAME,
|
|
32
33
|
)
|
|
@@ -229,7 +230,8 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
229
230
|
|
|
230
231
|
iam_credentials_provider = create_iam_credentials_provider(
|
|
231
232
|
IAMCredentialsProviderInput(
|
|
232
|
-
|
|
233
|
+
service_type=SupportedServiceType.TRACKER_STORE,
|
|
234
|
+
service_name=SQL_SERVICE_NAME,
|
|
233
235
|
username=username,
|
|
234
236
|
host=host,
|
|
235
237
|
port=port,
|
rasa/version.py
CHANGED