rasa-pro 3.14.0.dev7__py3-none-any.whl → 3.14.0.dev8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/agents/agent_manager.py +1 -1
- rasa/agents/constants.py +2 -2
- rasa/agents/protocol/a2a/a2a_agent.py +385 -227
- rasa/agents/protocol/mcp/mcp_base_agent.py +30 -13
- rasa/agents/protocol/mcp/mcp_open_agent.py +31 -8
- rasa/agents/protocol/mcp/mcp_task_agent.py +32 -9
- rasa/agents/schemas/agent_output.py +1 -1
- rasa/agents/utils.py +90 -1
- rasa/builder/README.md +120 -0
- rasa/builder/__init__.py +0 -0
- rasa/builder/auth.py +176 -0
- rasa/builder/config.py +92 -0
- rasa/builder/copilot/__init__.py +0 -0
- rasa/builder/copilot/constants.py +31 -0
- rasa/builder/copilot/copilot.py +450 -0
- rasa/builder/copilot/copilot_response_handler.py +522 -0
- rasa/builder/copilot/copilot_templated_message_provider.py +58 -0
- rasa/builder/copilot/exceptions.py +32 -0
- rasa/builder/copilot/models.py +500 -0
- rasa/builder/copilot/prompts/__init__.py +0 -0
- rasa/builder/copilot/prompts/copilot_system_prompt.jinja2 +766 -0
- rasa/builder/copilot/prompts/latest_user_message_context_prompt.jinja2 +61 -0
- rasa/builder/copilot/signing.py +305 -0
- rasa/builder/copilot/telemetry.py +210 -0
- rasa/builder/copilot/templated_messages/__init__.py +0 -0
- rasa/builder/copilot/templated_messages/copilot_internal_messages_templates.yml +16 -0
- rasa/builder/copilot/templated_messages/copilot_templated_responses.yml +38 -0
- rasa/builder/document_retrieval/__init__.py +0 -0
- rasa/builder/document_retrieval/constants.py +15 -0
- rasa/builder/document_retrieval/inkeep-rag-response-schema.json +64 -0
- rasa/builder/document_retrieval/inkeep_document_retrieval.py +238 -0
- rasa/builder/document_retrieval/models.py +62 -0
- rasa/builder/download.py +140 -0
- rasa/builder/exceptions.py +91 -0
- rasa/builder/guardrails/__init__.py +1 -0
- rasa/builder/guardrails/constants.py +9 -0
- rasa/builder/guardrails/exceptions.py +4 -0
- rasa/builder/guardrails/lakera.py +206 -0
- rasa/builder/guardrails/models.py +231 -0
- rasa/builder/guardrails/store.py +238 -0
- rasa/builder/guardrails/utils.py +328 -0
- rasa/builder/job_manager.py +87 -0
- rasa/builder/jobs.py +282 -0
- rasa/builder/llm_service.py +246 -0
- rasa/builder/logging_utils.py +265 -0
- rasa/builder/main.py +243 -0
- rasa/builder/models.py +216 -0
- rasa/builder/project_generator.py +458 -0
- rasa/builder/project_info.py +72 -0
- rasa/builder/scrape_rasa_docs.py +97 -0
- rasa/builder/service.py +1345 -0
- rasa/builder/shared/tracker_context.py +212 -0
- rasa/builder/skill_to_bot_prompt.jinja2 +164 -0
- rasa/builder/template_cache.py +244 -0
- rasa/builder/training_service.py +194 -0
- rasa/builder/validation_service.py +97 -0
- rasa/cli/project_templates/basic/README.md +23 -0
- rasa/cli/project_templates/basic/actions/__init__ +0 -0
- rasa/cli/project_templates/basic/actions/action_human_handoff.py +40 -0
- rasa/cli/project_templates/basic/actions/actions.md +10 -0
- rasa/cli/project_templates/basic/config.yml +29 -0
- rasa/cli/project_templates/basic/credentials.yml +33 -0
- rasa/cli/project_templates/basic/data/data.md +9 -0
- rasa/cli/project_templates/basic/data/general/feedback.yml +21 -0
- rasa/cli/project_templates/basic/data/general/goodbye.yml +6 -0
- rasa/cli/project_templates/basic/data/general/hello.yml +6 -0
- rasa/cli/project_templates/basic/data/general/help.yml +6 -0
- rasa/cli/project_templates/basic/data/general/human_handoff.yml +16 -0
- rasa/cli/project_templates/basic/data/general/show_faqs.yml +6 -0
- rasa/cli/project_templates/basic/data/system/patterns/pattern_cannot_handle.yml +7 -0
- rasa/cli/project_templates/basic/data/system/patterns/pattern_completed.yml +7 -0
- rasa/cli/project_templates/basic/data/system/patterns/pattern_correction.yml +7 -0
- rasa/cli/project_templates/basic/data/system/patterns/pattern_search.yml +8 -0
- rasa/cli/project_templates/basic/data/system/patterns/pattern_session_start.yml +8 -0
- rasa/cli/project_templates/basic/docs/docs.md +5 -0
- rasa/cli/project_templates/basic/docs/template.txt +28 -0
- rasa/cli/project_templates/basic/domain/domain.md +8 -0
- rasa/cli/project_templates/basic/domain/general/feedback.yml +25 -0
- rasa/cli/project_templates/basic/domain/general/goodbye.yml +9 -0
- rasa/cli/project_templates/basic/domain/general/hello.yml +7 -0
- rasa/cli/project_templates/basic/domain/general/help.yml +21 -0
- rasa/cli/project_templates/basic/domain/general/human_handoff.yml +32 -0
- rasa/cli/project_templates/basic/domain/general/show_faqs.yml +14 -0
- rasa/cli/project_templates/basic/domain/system/patterns/pattern_cannot_handle.yml +5 -0
- rasa/cli/project_templates/basic/domain/system/patterns/pattern_session_start.yml +19 -0
- rasa/cli/project_templates/basic/endpoints.yml +67 -0
- rasa/cli/project_templates/basic/prompts/rephraser_demo_personality_prompt.jinja2 +38 -0
- rasa/cli/project_templates/default/config.yml +4 -0
- rasa/cli/project_templates/default/endpoints.yml +4 -0
- rasa/cli/project_templates/finance/README.md +25 -0
- rasa/cli/project_templates/finance/actions/__init__.py +46 -0
- rasa/cli/project_templates/finance/actions/accounts/__init__.py +0 -0
- rasa/cli/project_templates/finance/actions/accounts/action_ask_account.py +47 -0
- rasa/cli/project_templates/finance/actions/accounts/action_check_balance.py +40 -0
- rasa/cli/project_templates/finance/actions/action_session_start.py +74 -0
- rasa/cli/project_templates/finance/actions/actions.md +15 -0
- rasa/cli/project_templates/finance/actions/cards/__init__.py +0 -0
- rasa/cli/project_templates/finance/actions/cards/action_ask_card.py +48 -0
- rasa/cli/project_templates/finance/actions/cards/action_check_card_existence.py +36 -0
- rasa/cli/project_templates/finance/actions/cards/action_update_card_status.py +54 -0
- rasa/cli/project_templates/finance/actions/database.py +277 -0
- rasa/cli/project_templates/finance/actions/transfers/__init__.py +0 -0
- rasa/cli/project_templates/finance/actions/transfers/action_add_payee.py +52 -0
- rasa/cli/project_templates/finance/actions/transfers/action_ask_account_from.py +51 -0
- rasa/cli/project_templates/finance/actions/transfers/action_check_payee_existence.py +40 -0
- rasa/cli/project_templates/finance/actions/transfers/action_check_sufficient_funds.py +40 -0
- rasa/cli/project_templates/finance/actions/transfers/action_list_payees.py +46 -0
- rasa/cli/project_templates/finance/actions/transfers/action_process_immediate_payment.py +18 -0
- rasa/cli/project_templates/finance/actions/transfers/action_remove_payee.py +49 -0
- rasa/cli/project_templates/finance/actions/transfers/action_schedule_payment.py +19 -0
- rasa/cli/project_templates/finance/actions/transfers/action_validate_payment_date.py +36 -0
- rasa/cli/project_templates/finance/config.yml +23 -0
- rasa/cli/project_templates/finance/credentials.yml +32 -0
- rasa/cli/project_templates/finance/csvs/accounts.csv +8 -0
- rasa/cli/project_templates/finance/csvs/advisors.csv +7 -0
- rasa/cli/project_templates/finance/csvs/appointments.csv +211 -0
- rasa/cli/project_templates/finance/csvs/branches.csv +10 -0
- rasa/cli/project_templates/finance/csvs/cards.csv +11 -0
- rasa/cli/project_templates/finance/csvs/payees.csv +11 -0
- rasa/cli/project_templates/finance/csvs/transactions.csv +71 -0
- rasa/cli/project_templates/finance/csvs/users.csv +4 -0
- rasa/cli/project_templates/finance/data/accounts/check_balance.yml +10 -0
- rasa/cli/project_templates/finance/data/cards/block_card.yml +66 -0
- rasa/cli/project_templates/finance/data/cards/select_card.yml +12 -0
- rasa/cli/project_templates/finance/data/data.md +11 -0
- rasa/cli/project_templates/finance/data/general/bot_identity.yml +6 -0
- rasa/cli/project_templates/finance/data/general/feedback.yml +20 -0
- rasa/cli/project_templates/finance/data/general/goodbye.yml +6 -0
- rasa/cli/project_templates/finance/data/general/hello.yml +7 -0
- rasa/cli/project_templates/finance/data/general/help.yml +9 -0
- rasa/cli/project_templates/finance/data/general/human_handoff.yml +16 -0
- rasa/cli/project_templates/finance/data/general/welcome.yml +9 -0
- rasa/cli/project_templates/finance/data/system/patterns/pattern_chitchat.yml +5 -0
- rasa/cli/project_templates/finance/data/system/patterns/pattern_completed.yml +7 -0
- rasa/cli/project_templates/finance/data/system/patterns/pattern_correction.yml +7 -0
- rasa/cli/project_templates/finance/data/system/patterns/pattern_search.yml +8 -0
- rasa/cli/project_templates/finance/data/system/patterns/pattern_session_start.yml +8 -0
- rasa/cli/project_templates/finance/data/system/source/accounts.json +51 -0
- rasa/cli/project_templates/finance/data/system/source/advisors.json +44 -0
- rasa/cli/project_templates/finance/data/system/source/appointments.json +1474 -0
- rasa/cli/project_templates/finance/data/system/source/branches.json +47 -0
- rasa/cli/project_templates/finance/data/system/source/cards.json +72 -0
- rasa/cli/project_templates/finance/data/system/source/payees.json +74 -0
- rasa/cli/project_templates/finance/data/system/source/transactions.json +492 -0
- rasa/cli/project_templates/finance/data/system/source/users.json +29 -0
- rasa/cli/project_templates/finance/data/transfers/add_payee.yml +29 -0
- rasa/cli/project_templates/finance/data/transfers/list_payees.yml +5 -0
- rasa/cli/project_templates/finance/data/transfers/remove_payee.yml +21 -0
- rasa/cli/project_templates/finance/data/transfers/transfer_money.yml +67 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/block_card/consequences_of_blocking_card.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/block_card/reasons_to_block_card.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/block_card/recovering_from_card_fraud.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/block_card/tips_for_card_security.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/block_card/what_to_do_if_card_is_lost.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/check_balance/account_balance_security.txt +7 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/check_balance/common_balance_inquiries.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/check_balance/methods_to_check_balance.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/check_balance/understanding_balance_updates.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/check_balance/what_to_do_if_balance_is_incorrect.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/manage_payees/benefits_of_authorised_payees.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/manage_payees/common_issues_with_payees.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/manage_payees/general_payee_information.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/manage_payees/payee_management_tips.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/manage_payees/understanding_payee_types.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/transfer_money/common_transfer_errors.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/transfer_money/fees_for_transfers.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/transfer_money/general_transfer_information.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/transfer_money/security_tips_for_transfers.txt +8 -0
- rasa/cli/project_templates/finance/docs/bank_of_rasa_faq/transfer_money/transfer_processing_times.txt +8 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part1.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part10.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part11.txt +48 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part12.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part13.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part14.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part15.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part16.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part17.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part18.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part19.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part2.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part20.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part21.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part22.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part23.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part24.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part25.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part26.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part27.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part28.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part29.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part3.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part30.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part31.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part32.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part33.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part34.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part35.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part36.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part37.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part38.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part39.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part4.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part40.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part41.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part42.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part43.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part44.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part45.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part46.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part47.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part48.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part49.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part5.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part50.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part51.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part52.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part53.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part54.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part55.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part56.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part57.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part58.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part59.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part6.txt +47 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part60.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part61.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part7.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part8.txt +50 -0
- rasa/cli/project_templates/finance/docs/huggingface_alpaca_dataset/questions_part9.txt +47 -0
- rasa/cli/project_templates/finance/domain/accounts/check_balance.yml +11 -0
- rasa/cli/project_templates/finance/domain/cards/block_card.yml +101 -0
- rasa/cli/project_templates/finance/domain/cards/select_card.yml +12 -0
- rasa/cli/project_templates/finance/domain/domain.md +10 -0
- rasa/cli/project_templates/finance/domain/general/agent_details.yml +12 -0
- rasa/cli/project_templates/finance/domain/general/bot_identity.yml +5 -0
- rasa/cli/project_templates/finance/domain/general/cannot_handle.yml +5 -0
- rasa/cli/project_templates/finance/domain/general/defaults.yml +24 -0
- rasa/cli/project_templates/finance/domain/general/feedback.yml +28 -0
- rasa/cli/project_templates/finance/domain/general/goodbye.yml +7 -0
- rasa/cli/project_templates/finance/domain/general/help.yml +5 -0
- rasa/cli/project_templates/finance/domain/general/human_handoff.yml +30 -0
- rasa/cli/project_templates/finance/domain/general/utils.yml +13 -0
- rasa/cli/project_templates/finance/domain/general/welcome.yml +8 -0
- rasa/cli/project_templates/finance/domain/transfers/add_payee.yml +47 -0
- rasa/cli/project_templates/finance/domain/transfers/list_payees.yml +4 -0
- rasa/cli/project_templates/finance/domain/transfers/remove_payee.yml +16 -0
- rasa/cli/project_templates/finance/domain/transfers/transfer_money.yml +79 -0
- rasa/cli/project_templates/finance/endpoints.yml +66 -0
- rasa/cli/project_templates/finance/prompts/rephraser_demo_personality_prompt.jinja2 +19 -0
- rasa/cli/project_templates/telco/README.md +25 -0
- rasa/cli/project_templates/telco/actions/__init__.py +0 -0
- rasa/cli/project_templates/telco/actions/actions.md +12 -0
- rasa/cli/project_templates/telco/actions/billing/__init__.py +0 -0
- rasa/cli/project_templates/telco/actions/billing/actions_billing.py +204 -0
- rasa/cli/project_templates/telco/actions/general/__init__.py +0 -0
- rasa/cli/project_templates/telco/actions/general/action_human_handoff.py +49 -0
- rasa/cli/project_templates/telco/actions/network/__init__.py +0 -0
- rasa/cli/project_templates/telco/actions/network/actions_get_data_from_db.py +48 -0
- rasa/cli/project_templates/telco/actions/network/actions_run_diagnostics.py +28 -0
- rasa/cli/project_templates/telco/actions/network/actions_session_start.py +18 -0
- rasa/cli/project_templates/telco/config.yml +29 -0
- rasa/cli/project_templates/telco/credentials.yml +33 -0
- rasa/cli/project_templates/telco/csvs/billing.csv +19 -0
- rasa/cli/project_templates/telco/csvs/customers.csv +5 -0
- rasa/cli/project_templates/telco/data/billing/flow_understand_bill.yml +45 -0
- rasa/cli/project_templates/telco/data/data.md +11 -0
- rasa/cli/project_templates/telco/data/general/bot_challenge.yml +6 -0
- rasa/cli/project_templates/telco/data/general/feedback.yml +20 -0
- rasa/cli/project_templates/telco/data/general/goodbye.yml +6 -0
- rasa/cli/project_templates/telco/data/general/hello.yml +6 -0
- rasa/cli/project_templates/telco/data/general/human_handoff.yml +16 -0
- rasa/cli/project_templates/telco/data/general/patterns.yml +30 -0
- rasa/cli/project_templates/telco/data/network/flow_reboot_router.yml +8 -0
- rasa/cli/project_templates/telco/data/network/flow_reset_router.yml +7 -0
- rasa/cli/project_templates/telco/data/network/flow_solve_internet_issue.yml +73 -0
- rasa/cli/project_templates/telco/docs/docs.md +5 -0
- rasa/cli/project_templates/telco/docs/network/reset_vs_rboot_router.txt +1 -0
- rasa/cli/project_templates/telco/docs/network/restart_router.txt +6 -0
- rasa/cli/project_templates/telco/docs/network/run_speed_test.txt +6 -0
- rasa/cli/project_templates/telco/domain/billing/understand_bill.yml +102 -0
- rasa/cli/project_templates/telco/domain/domain.md +14 -0
- rasa/cli/project_templates/telco/domain/general/bot_challenge.yml +4 -0
- rasa/cli/project_templates/telco/domain/general/feedback.yml +25 -0
- rasa/cli/project_templates/telco/domain/general/goodbye.yml +7 -0
- rasa/cli/project_templates/telco/domain/general/hello.yml +5 -0
- rasa/cli/project_templates/telco/domain/general/human_handoff.yml +26 -0
- rasa/cli/project_templates/telco/domain/general/patterns.yml +33 -0
- rasa/cli/project_templates/telco/domain/network/reboot_router.yml +21 -0
- rasa/cli/project_templates/telco/domain/network/reset_router.yml +12 -0
- rasa/cli/project_templates/telco/domain/network/run_speed_test.yml +25 -0
- rasa/cli/project_templates/telco/domain/network/solve_internet_issue.yml +75 -0
- rasa/cli/project_templates/telco/domain/shared.yml +129 -0
- rasa/cli/project_templates/telco/endpoints.yml +67 -0
- rasa/cli/project_templates/telco/prompts/rephraser_demo_personality_prompt.jinja2 +40 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/billing/understand_bill.yml +67 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/general/bot_challenge.yml +8 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/general/feedback.yml +46 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/general/goodbye.yml +9 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/general/hello.yml +8 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/general/human_handoff.yml +35 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/general/patterns.yml +23 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/network/solve_internet_issue.yml +57 -0
- rasa/cli/project_templates/tutorial/config.yml +2 -1
- rasa/cli/scaffold.py +46 -2
- rasa/core/actions/action.py +0 -1
- rasa/core/available_agents.py +2 -0
- rasa/core/available_endpoints.py +17 -2
- rasa/core/channels/channel.py +4 -3
- rasa/core/channels/constants.py +3 -0
- rasa/core/channels/development_inspector.py +2 -22
- rasa/core/channels/inspector/README.md +26 -14
- rasa/core/channels/inspector/dist/assets/{arc-cce7e0a8.js → arc-edef10dd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-e2a49be7.js → blockDiagram-38ab4fdb-49f6762b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-3def7895.js → c4Diagram-3d4e48cf-313c08e6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/channel-63aa27d1.js +1 -0
- rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-e66fe4df.js → classDiagram-70f12bd4-35e41ce9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-eb874aaa.js → classDiagram-v2-f2320105-f346068d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/clone-5566bae8.js +1 -0
- rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-cf934643.js → createText-2e5e7dd3-7a44bce8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-8fdf9155.js → edges-e0da2a9e-d7cf78c7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-6106fb96.js → erDiagram-9861fffd-9813e81c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-4c2bb040.js → flowDb-956e92f1-d8ba0870.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-f0ff96af.js → flowDiagram-66a62f08-51f0db4d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-32936074.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-a21707ec.js → flowchart-elk-definition-4a651766-ff9ea384.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-c165acb1.js → ganttDiagram-c361ad54-a8e13b6b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-b0564cf1.js → gitGraphDiagram-72cf32ee-3b171c6d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{graph-e557e67a.js → graph-790ef78b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-3862675e-1ce60e9e.js → index-3862675e-ecdce073.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-d705da80.js +1352 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-893569e2.js → infoDiagram-f8f76790-f5a422fe.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-c29c864f.js → journeyDiagram-49397b02-3185b7ac.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-649a5eae.js → layout-837fd3aa.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-0e5685ed.js → line-7e05afcb.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-eaa320bd.js → linear-162eb295.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-f35df9e6.js → mindmap-definition-fc14e90a-f4978aee.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-78339e96.js → pieDiagram-8a3498a8-b25d0a52.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-9b5f2f14.js → quadrantDiagram-120e2f19-63db1afa.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-d05ddb3a.js → requirementDiagram-deff3bca-1b486cc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-d9be5dfd.js → sankeyDiagram-04a897e0-7e795291.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-0f1c4348.js → sequenceDiagram-704730f1-b8aba159.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-9ddf63b3.js → stateDiagram-587899a1-41529fd5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-bc2b81ed.js → stateDiagram-v2-d93cdb3a-b241043c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-0a287936.js → styles-6aaf32cf-b5b53234.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9a916d00-e3941990.js → styles-9a916d00-13d138e5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-c10674c1-ce4eca24.js → styles-c10674c1-94cbde3f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-d822b1a8.js → svgDrawCommon-08f97a94-453ae764.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-e144c7a7.js → timeline-definition-85554ec2-8dcb88a4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-ab7f4e14.js → xychartDiagram-e933f94c-376af5f0.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +2 -2
- rasa/core/channels/inspector/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +16 -42
- rasa/core/channels/inspector/src/components/Chat.tsx +2 -3
- rasa/core/channels/inspector/src/components/DialogueHistoryStack.tsx +1 -0
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +20 -3
- rasa/core/channels/inspector/src/components/LatencyDisplay.tsx +63 -35
- rasa/core/channels/inspector/src/helpers/audio/audiostream.ts +14 -0
- rasa/core/channels/inspector/src/types.ts +32 -7
- rasa/core/channels/socketio.py +212 -51
- rasa/core/channels/studio_chat.py +59 -57
- rasa/core/channels/voice_stream/asr/asr_event.py +1 -1
- rasa/core/channels/voice_stream/asr/azure.py +6 -3
- rasa/core/channels/voice_stream/asr/deepgram.py +1 -1
- rasa/core/channels/voice_stream/audiocodes.py +3 -0
- rasa/core/channels/voice_stream/browser_audio.py +53 -3
- rasa/core/channels/voice_stream/genesys.py +2 -1
- rasa/core/channels/voice_stream/jambonz.py +9 -1
- rasa/core/channels/voice_stream/twilio_media_streams.py +16 -0
- rasa/core/channels/voice_stream/voice_channel.py +66 -3
- rasa/core/constants.py +6 -0
- rasa/core/iam_credentials_providers/__init__.py +0 -0
- rasa/core/iam_credentials_providers/aws_iam_credentials_providers.py +66 -0
- rasa/core/iam_credentials_providers/credentials_provider_protocol.py +89 -0
- rasa/core/policies/enterprise_search_policy.py +4 -7
- rasa/core/policies/flows/flow_executor.py +14 -5
- rasa/core/policies/ted_policy.py +7 -5
- rasa/core/processor.py +32 -0
- rasa/core/redis_connection_factory.py +411 -0
- rasa/core/run.py +13 -3
- rasa/core/tracker_stores/redis_tracker_store.py +32 -14
- rasa/core/tracker_stores/sql_tracker_store.py +57 -1
- rasa/dialogue_understanding/generator/flow_retrieval.py +10 -9
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +10 -5
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +10 -5
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_claude_3_5_sonnet_20240620_template.jinja2 +20 -12
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_gpt_4o_2024_11_20_template.jinja2 +19 -12
- rasa/dialogue_understanding/generator/single_step/single_step_based_llm_command_generator.py +6 -35
- rasa/dialogue_understanding/patterns/cancel.py +27 -6
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +1 -1
- rasa/dialogue_understanding/processor/command_processor.py +35 -0
- rasa/engine/graph.py +5 -1
- rasa/engine/recipes/default_components.py +78 -10
- rasa/engine/recipes/default_recipe.py +41 -1
- rasa/engine/storage/local_model_storage.py +83 -3
- rasa/graph_components/validators/default_recipe_validator.py +153 -135
- rasa/model_manager/model_api.py +4 -5
- rasa/model_manager/runner_service.py +1 -1
- rasa/model_manager/socket_bridge.py +20 -15
- rasa/model_manager/trainer_service.py +12 -9
- rasa/model_manager/utils.py +1 -29
- rasa/model_manager/warm_rasa_process.py +1 -1
- rasa/model_training.py +14 -0
- rasa/nlu/classifiers/diet_classifier.py +22 -6
- rasa/nlu/classifiers/logistic_regression_classifier.py +18 -0
- rasa/nlu/extractors/extractor.py +1 -2
- rasa/shared/agents/auth/__init__.py +0 -0
- rasa/shared/agents/auth/agent_auth_factory.py +74 -0
- rasa/shared/agents/auth/agent_auth_manager.py +86 -0
- rasa/shared/agents/auth/auth_strategy/__init__.py +19 -0
- rasa/shared/agents/auth/auth_strategy/agent_auth_strategy.py +52 -0
- rasa/shared/agents/auth/auth_strategy/api_key_auth_strategy.py +42 -0
- rasa/shared/agents/auth/auth_strategy/bearer_token_auth_strategy.py +28 -0
- rasa/shared/agents/auth/auth_strategy/oauth2_auth_strategy.py +159 -0
- rasa/shared/agents/auth/constants.py +11 -0
- rasa/shared/agents/auth/types.py +11 -0
- rasa/shared/core/constants.py +1 -0
- rasa/shared/core/domain.py +58 -11
- rasa/shared/core/events.py +2 -0
- rasa/shared/core/flows/constants.py +5 -0
- rasa/shared/core/flows/flow_step.py +7 -1
- rasa/shared/core/flows/flows_list.py +6 -0
- rasa/shared/core/flows/steps/call.py +15 -12
- rasa/shared/core/flows/validation.py +238 -44
- rasa/shared/core/flows/yaml_flows_io.py +15 -6
- rasa/shared/core/slots.py +4 -0
- rasa/shared/exceptions.py +12 -0
- rasa/shared/importers/importer.py +6 -0
- rasa/shared/importers/utils.py +77 -1
- rasa/shared/nlu/training_data/schemas/responses.yml +3 -0
- rasa/shared/providers/_utils.py +60 -44
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +2 -0
- rasa/shared/providers/llm/_base_litellm_client.py +2 -2
- rasa/shared/providers/llm/default_litellm_llm_client.py +2 -0
- rasa/shared/providers/llm/llm_response.py +4 -4
- rasa/shared/utils/common.py +24 -0
- rasa/shared/utils/llm.py +2 -1
- rasa/shared/utils/mcp/server_connection.py +84 -23
- rasa/shared/utils/mcp/utils.py +20 -0
- rasa/studio/upload.py +16 -47
- rasa/telemetry.py +97 -23
- rasa/tracing/config.py +38 -12
- rasa/tracing/instrumentation/attribute_extractors.py +5 -1
- rasa/tracing/instrumentation/instrumentation.py +85 -8
- rasa/utils/common.py +1 -1
- rasa/utils/io.py +27 -9
- rasa/utils/json_utils.py +6 -1
- rasa/utils/log_utils.py +5 -1
- rasa/utils/openapi.py +144 -0
- rasa/utils/tensorflow/__init__.py +29 -0
- rasa/utils/tensorflow/callback.py +1 -1
- rasa/utils/tensorflow/crf.py +1 -1
- rasa/utils/tensorflow/data_generator.py +21 -8
- rasa/utils/tensorflow/layers.py +11 -4
- rasa/utils/tensorflow/metrics.py +7 -3
- rasa/utils/tensorflow/models.py +41 -6
- rasa/utils/tensorflow/rasa_layers.py +6 -4
- rasa/utils/tensorflow/transformer.py +2 -3
- rasa/utils/train_utils.py +68 -38
- rasa/validator.py +18 -16
- rasa/version.py +1 -1
- rasa_pro-3.14.0.dev8.dist-info/METADATA +199 -0
- {rasa_pro-3.14.0.dev7.dist-info → rasa_pro-3.14.0.dev8.dist-info}/RECORD +466 -156
- rasa/core/channels/inspector/dist/assets/channel-858c2c20.js +0 -1
- rasa/core/channels/inspector/dist/assets/clone-4b80996c.js +0 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-16f09b7a.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-996fe816.js +0 -1353
- rasa_pro-3.14.0.dev7.dist-info/METADATA +0 -190
- {rasa_pro-3.14.0.dev7.dist-info → rasa_pro-3.14.0.dev8.dist-info}/NOTICE +0 -0
- {rasa_pro-3.14.0.dev7.dist-info → rasa_pro-3.14.0.dev8.dist-info}/WHEEL +0 -0
- {rasa_pro-3.14.0.dev7.dist-info → rasa_pro-3.14.0.dev8.dist-info}/entry_points.txt +0 -0
rasa/utils/io.py
CHANGED
|
@@ -26,6 +26,7 @@ from typing_extensions import Protocol
|
|
|
26
26
|
|
|
27
27
|
import rasa.shared.constants
|
|
28
28
|
import rasa.shared.utils.io
|
|
29
|
+
from rasa.shared.exceptions import RasaException
|
|
29
30
|
|
|
30
31
|
if TYPE_CHECKING:
|
|
31
32
|
from prompt_toolkit.validation import Validator
|
|
@@ -124,9 +125,7 @@ def create_path(file_path: Text) -> None:
|
|
|
124
125
|
def file_type_validator(
|
|
125
126
|
valid_file_types: List[Text], error_message: Text
|
|
126
127
|
) -> Type["Validator"]:
|
|
127
|
-
"""Creates a
|
|
128
|
-
file paths.
|
|
129
|
-
"""
|
|
128
|
+
"""Creates a file type validator class for the questionary package."""
|
|
130
129
|
|
|
131
130
|
def is_valid(path: Text) -> bool:
|
|
132
131
|
return path is not None and any(
|
|
@@ -137,9 +136,7 @@ def file_type_validator(
|
|
|
137
136
|
|
|
138
137
|
|
|
139
138
|
def not_empty_validator(error_message: Text) -> Type["Validator"]:
|
|
140
|
-
"""Creates a
|
|
141
|
-
that the user entered something other than whitespace.
|
|
142
|
-
"""
|
|
139
|
+
"""Creates a not empty validator class for the questionary package."""
|
|
143
140
|
|
|
144
141
|
def is_valid(input: Text) -> bool:
|
|
145
142
|
return input is not None and input.strip() != ""
|
|
@@ -150,9 +147,7 @@ def not_empty_validator(error_message: Text) -> Type["Validator"]:
|
|
|
150
147
|
def create_validator(
|
|
151
148
|
function: Callable[[Text], bool], error_message: Text
|
|
152
149
|
) -> Type["Validator"]:
|
|
153
|
-
"""Helper method to create
|
|
154
|
-
removed when questionary supports `Validator` objects.
|
|
155
|
-
"""
|
|
150
|
+
"""Helper method to create a validator class from a callable function."""
|
|
156
151
|
from prompt_toolkit.document import Document
|
|
157
152
|
from prompt_toolkit.validation import ValidationError, Validator
|
|
158
153
|
|
|
@@ -250,3 +245,26 @@ def write_yaml(
|
|
|
250
245
|
|
|
251
246
|
with Path(target).open("w", encoding="utf-8") as outfile:
|
|
252
247
|
dumper.dump(data, outfile, transform=transform)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class InvalidPathException(RasaException):
|
|
251
|
+
"""Raised if a path is invalid - e.g. path traversal is detected."""
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def subpath(parent: str, child: str) -> str:
    """Return the absolute path of `child` resolved inside `parent`.

    Ensures that `child` does not navigate out of `parent` (for example via
    `..` segments or by being an absolute path), which prevents path
    traversal. Paths are normalized with `os.path.abspath`, so symbolic links
    are not resolved.

    Based on Snyk's directory traversal mitigation:
    https://learn.snyk.io/lesson/directory-traversal/

    Args:
        parent: Directory the resulting path has to stay within.
        child: Path that is joined onto `parent`.

    Returns:
        The normalized absolute path of `child` below `parent`.

    Raises:
        InvalidPathException: If the resulting path lies outside of `parent`.
    """
    safe_path = os.path.abspath(os.path.join(parent, child))
    parent = os.path.abspath(parent)

    try:
        common_base = os.path.commonpath([parent, safe_path])
    except ValueError as error:
        # `commonpath` refuses to compare paths that share no common root
        # (e.g. different drives on Windows). Such a path is necessarily
        # outside of `parent`, so report it like any other traversal attempt.
        raise InvalidPathException(f"Invalid path: {safe_path}") from error

    if common_base != parent:
        raise InvalidPathException(f"Invalid path: {safe_path}")

    return safe_path
|
rasa/utils/json_utils.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
2
|
from decimal import Decimal
|
|
3
|
-
from typing import Any, Text
|
|
3
|
+
from typing import Any, Dict, List, Text
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class DecimalEncoder(json.JSONEncoder):
|
|
@@ -58,3 +58,8 @@ def replace_decimals_with_floats(obj: Any) -> Any:
|
|
|
58
58
|
Input `obj` with all `Decimal` types replaced by `float`s.
|
|
59
59
|
"""
|
|
60
60
|
return json.loads(json.dumps(obj, cls=DecimalEncoder))
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def extract_values(data: Dict, keys: List[Text]) -> Dict:
    """Collect the truthy values stored in `data` under the requested `keys`.

    Keys that are missing from `data`, or whose value is falsy (for example
    `None`, `0`, `""` or an empty container), are left out of the result.

    Args:
        data: Dictionary to read the values from.
        keys: Keys to look up in `data`.

    Returns:
        A new dictionary mapping each requested key to its truthy value.
    """
    extracted = {}
    for key in keys:
        value = data.get(key)
        if value:
            extracted[key] = value
    return extracted
|
rasa/utils/log_utils.py
CHANGED
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
import logging
|
|
4
4
|
import os
|
|
5
5
|
import sys
|
|
6
|
-
from typing import Any, Optional
|
|
6
|
+
from typing import Any, List, Optional
|
|
7
7
|
|
|
8
8
|
import structlog
|
|
9
9
|
from structlog.dev import ConsoleRenderer
|
|
@@ -37,6 +37,7 @@ class HumanConsoleRenderer(ConsoleRenderer):
|
|
|
37
37
|
def configure_structlog(
|
|
38
38
|
log_level: Optional[int] = None,
|
|
39
39
|
include_time: bool = False,
|
|
40
|
+
additional_processors: Optional[List[structlog.typing.Processor]] = None,
|
|
40
41
|
) -> None:
|
|
41
42
|
"""Configure logging of the server."""
|
|
42
43
|
if log_level is None: # Log level NOTSET is 0 so we use `is None` here
|
|
@@ -75,6 +76,9 @@ def configure_structlog(
|
|
|
75
76
|
if include_time:
|
|
76
77
|
shared_processors.append(structlog.processors.TimeStamper(fmt="iso"))
|
|
77
78
|
|
|
79
|
+
if additional_processors:
|
|
80
|
+
shared_processors.extend(additional_processors)
|
|
81
|
+
|
|
78
82
|
if not FORCE_JSON_LOGGING and sys.stderr.isatty():
|
|
79
83
|
# Pretty printing when we run in a terminal session.
|
|
80
84
|
# Automatically prints pretty tracebacks when "rich" is installed
|
rasa/utils/openapi.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Type
|
|
2
|
+
|
|
3
|
+
from pydantic.main import BaseModel
|
|
4
|
+
from sanic_openapi import openapi
|
|
5
|
+
from sanic_openapi.openapi3.types import Schema
|
|
6
|
+
|
|
7
|
+
_SUPPORTED_ATTRIBUTES = frozenset(["format", "enum", "required", "example"])
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _resolve_definition(
    definition: str,
    definition_stack: List[str],
    definitions: Dict[str, Any],
    **placeholder_kwargs: Any,
) -> Schema:
    """Expand a referenced definition, or return a titled placeholder object.

    The placeholder is used when the definition is unknown or when it is
    already being expanded further up the stack (a recursive model), which
    keeps the conversion from recursing without bound.
    """
    definition_data = definitions.get(definition)
    if definition_data is None or definition in definition_stack:
        return openapi.Object(title=definition, **placeholder_kwargs)

    return _to_schema(
        definition_stack=definition_stack + [definition],
        schema_def=definition_data,
        definitions=definitions,
    )


def _to_schema(
    definition_stack: List[str], schema_def: Dict[str, Any], definitions: Dict[str, Any]
) -> Schema:
    """Convert a JSON schema fragment into a sanic-openapi schema object.

    Args:
        definition_stack: Names of the definitions that are currently being
            expanded. Used to stop the expansion of recursive models.
        schema_def: The JSON schema fragment to convert.
        definitions: All named definitions of the root schema, keyed by name.

    Returns:
        The schema object describing `schema_def`.
    """
    schema_type = schema_def.get("type")
    description = schema_def.get("description")

    if schema_type == "object":
        properties = {
            key: _to_schema(
                definition_stack=definition_stack,
                schema_def=property_def,
                definitions=definitions,
            )
            for key, property_def in schema_def.get("properties", {}).items()
        }
        return openapi.Object(
            title=schema_def.get("title"),
            description=description,
            required=schema_def.get("required"),
            properties=properties,
        )

    if schema_type == "array":
        return openapi.Array(
            description=description,
            required=schema_def.get("required"),
            items=_to_schema(
                definition_stack=definition_stack,
                # Arrays without an "items" entry (e.g. tuple-like schemas)
                # are treated as arrays of arbitrary values.
                schema_def=schema_def.get("items") or {},
                definitions=definitions,
            ),
        )

    if schema_type is not None:
        # Primitive type: copy over the attributes we know how to render.
        schema_spec = {"type": schema_type, "description": description}
        for attribute in _SUPPORTED_ATTRIBUTES:
            if value := schema_def.get(attribute):
                schema_spec[attribute] = value
        return Schema(**schema_spec)

    if allof_spec := schema_def.get("allOf"):  # Model, Enum
        definition = allof_spec[0]["$ref"].split("/")[-1]
        return _resolve_definition(
            definition, definition_stack, definitions, description=description
        )

    if anyof_spec := schema_def.get("anyOf"):  # Union
        alternatives = []
        for alternative in anyof_spec:
            if alternative.get("type"):
                alternatives.append(
                    Schema(
                        **{
                            "type": alternative.get("type"),
                            "description": alternative.get("description"),
                        }
                    )
                )
                continue

            alternative_ref = alternative.get("$ref")
            if alternative_ref is None:
                # Neither a plain type nor a reference (e.g. a nested
                # composition): convert the fragment like any other schema.
                alternatives.append(
                    _to_schema(
                        definition_stack=definition_stack,
                        schema_def=alternative,
                        definitions=definitions,
                    )
                )
                continue

            definition = alternative_ref.split("/")[-1]
            alternatives.append(
                _resolve_definition(
                    definition,
                    definition_stack,
                    definitions,
                    description=schema_def.get("description", definition),
                    properties={},
                )
            )
        return Schema(anyOf=alternatives)

    if ref := schema_def.get("$ref"):  # $ref
        definition = ref.split("/")[-1]
        # Track the definition on the stack so that self-referential models
        # reached through a plain `$ref` terminate with a placeholder.
        return _resolve_definition(
            definition, definition_stack, definitions, description=description
        )

    # Any type
    return Schema(**{"type": "object", "description": description})
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def model_to_schema(model: Type[BaseModel]) -> Schema:
    """Convert a pydantic model class into a sanic-openapi schema object.

    Args:
        model: The pydantic model class to describe.

    Returns:
        The schema object built from the model's JSON schema.
    """
    json_schema = model.model_json_schema()
    # Handle both $defs (newer JSON Schema) and definitions (older JSON Schema)
    definitions = json_schema.get("$defs") or json_schema.get("definitions") or {}
    # The definitions are passed separately, so strip them from the root schema.
    root_schema = {
        name: value
        for name, value in json_schema.items()
        if name not in ("definitions", "$defs")
    }
    return _to_schema(
        definition_stack=[],
        schema_def=root_schema,
        definitions=definitions,
    )
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""TensorFlow utilities for Rasa components.
|
|
2
|
+
|
|
3
|
+
This module provides TensorFlow-related utilities and components.
|
|
4
|
+
TensorFlow is an optional dependency and should be installed separately
|
|
5
|
+
if you want to use components that require it. These are:
|
|
6
|
+
- DIETClassifier
|
|
7
|
+
- TEDPolicy
|
|
8
|
+
- UnexpecTEDIntentPolicy
|
|
9
|
+
- ResponseSelector
|
|
10
|
+
- ConveRTFeaturizer
|
|
11
|
+
- LanguageModelFeaturizer
|
|
12
|
+
|
|
13
|
+
To install Rasa with TensorFlow support:
|
|
14
|
+
`pip install "rasa[tensorflow]"`
|
|
15
|
+
|
|
16
|
+
To install it with poetry:
|
|
17
|
+
`poetry install --extras tensorflow`
|
|
18
|
+
|
|
19
|
+
For macOS with Apple Silicon (M1/M2) (platform-specific TensorFlow installation):
|
|
20
|
+
`pip install "rasa[tensorflow,metal]"`
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import importlib.util
|
|
24
|
+
import logging
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
# check if TensorFlow is available
|
|
29
|
+
TENSORFLOW_AVAILABLE = importlib.util.find_spec("tensorflow") is not None
|
|
@@ -55,7 +55,7 @@ class RasaModelCheckpoint(tf.keras.callbacks.Callback):
|
|
|
55
55
|
"""
|
|
56
56
|
super().__init__()
|
|
57
57
|
|
|
58
|
-
self.checkpoint_file = checkpoint_dir / "checkpoint.
|
|
58
|
+
self.checkpoint_file = checkpoint_dir / "checkpoint.weights.h5"
|
|
59
59
|
self.best_metrics_so_far: Dict[Text, Any] = {}
|
|
60
60
|
|
|
61
61
|
def on_epoch_end(self, epoch: int, logs: Optional[Dict[Text, Any]] = None) -> None:
|
rasa/utils/tensorflow/crf.py
CHANGED
|
@@ -9,7 +9,7 @@ from tensorflow.types.experimental import TensorLike
|
|
|
9
9
|
# (modified to our needs)
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
class CrfDecodeForwardRnnCell(tf.keras.layers.
|
|
12
|
+
class CrfDecodeForwardRnnCell(tf.keras.layers.Layer):
|
|
13
13
|
"""Computes the forward decoding in a linear-chain CRF."""
|
|
14
14
|
|
|
15
15
|
def __init__(self, transition_params: TensorLike, **kwargs: Any) -> None:
|
|
@@ -71,13 +71,22 @@ class RasaDataGenerator(Sequence):
|
|
|
71
71
|
# balancing on the next epoch
|
|
72
72
|
return data
|
|
73
73
|
|
|
74
|
+
@staticmethod
|
|
75
|
+
def _create_default_array() -> np.ndarray:
|
|
76
|
+
"""Create a default empty array for missing features.
|
|
77
|
+
|
|
78
|
+
Returns:
|
|
79
|
+
A default empty array with shape (0, 1) and dtype float32.
|
|
80
|
+
"""
|
|
81
|
+
return np.zeros((0, 1), dtype=np.float32)
|
|
82
|
+
|
|
74
83
|
@staticmethod
|
|
75
84
|
def prepare_batch(
|
|
76
85
|
data: Data,
|
|
77
86
|
start: Optional[int] = None,
|
|
78
87
|
end: Optional[int] = None,
|
|
79
88
|
tuple_sizes: Optional[Dict[Text, int]] = None,
|
|
80
|
-
) -> Tuple[
|
|
89
|
+
) -> Tuple[np.ndarray, ...]:
|
|
81
90
|
"""Slices model data into batch using given start and end value.
|
|
82
91
|
|
|
83
92
|
Args:
|
|
@@ -85,8 +94,8 @@ class RasaDataGenerator(Sequence):
|
|
|
85
94
|
start: The start index of the batch
|
|
86
95
|
end: The end index of the batch
|
|
87
96
|
tuple_sizes: In case the feature is not present we propagate the batch with
|
|
88
|
-
|
|
89
|
-
what kind of feature.
|
|
97
|
+
default arrays. Tuple sizes contains the number of how many default values
|
|
98
|
+
to add for what kind of feature.
|
|
90
99
|
|
|
91
100
|
Returns:
|
|
92
101
|
The features of the batch.
|
|
@@ -95,12 +104,14 @@ class RasaDataGenerator(Sequence):
|
|
|
95
104
|
|
|
96
105
|
for key, attribute_data in data.items():
|
|
97
106
|
for sub_key, f_data in attribute_data.items():
|
|
98
|
-
# add
|
|
107
|
+
# add default arrays for not present values during processing
|
|
99
108
|
if not f_data:
|
|
100
109
|
if tuple_sizes:
|
|
101
|
-
batch_data += [
|
|
110
|
+
batch_data += [
|
|
111
|
+
RasaDataGenerator._create_default_array()
|
|
112
|
+
] * tuple_sizes[key]
|
|
102
113
|
else:
|
|
103
|
-
batch_data.append(
|
|
114
|
+
batch_data.append(RasaDataGenerator._create_default_array())
|
|
104
115
|
continue
|
|
105
116
|
|
|
106
117
|
for v in f_data:
|
|
@@ -409,8 +420,10 @@ class RasaBatchDataGenerator(RasaDataGenerator):
|
|
|
409
420
|
end = start + self._current_batch_size
|
|
410
421
|
|
|
411
422
|
# return input and target data, as our target data is inside the input
|
|
412
|
-
# data return
|
|
413
|
-
return self.prepare_batch(
|
|
423
|
+
# data, return a default array for the target data
|
|
424
|
+
return self.prepare_batch(
|
|
425
|
+
self._data, start, end
|
|
426
|
+
), RasaDataGenerator._create_default_array()
|
|
414
427
|
|
|
415
428
|
def on_epoch_end(self) -> None:
|
|
416
429
|
"""Update the data after every epoch."""
|
rasa/utils/tensorflow/layers.py
CHANGED
|
@@ -3,9 +3,7 @@ from typing import Any, Callable, List, Optional, Text, Tuple, Union
|
|
|
3
3
|
|
|
4
4
|
import tensorflow as tf
|
|
5
5
|
import tensorflow.keras.backend as K
|
|
6
|
-
|
|
7
|
-
# TODO: The following is not (yet) available via tf.keras
|
|
8
|
-
from keras.src.utils.control_flow_util import smart_cond
|
|
6
|
+
from tensorflow.python.keras.utils.control_flow_util import smart_cond
|
|
9
7
|
|
|
10
8
|
import rasa.utils.tensorflow.crf
|
|
11
9
|
import rasa.utils.tensorflow.layers_utils as layers_utils
|
|
@@ -278,6 +276,7 @@ class RandomlyConnectedDense(tf.keras.layers.Dense):
|
|
|
278
276
|
kernel_constraint: Constraint function applied to
|
|
279
277
|
the `kernel` weights matrix.
|
|
280
278
|
bias_constraint: Constraint function applied to the bias vector.
|
|
279
|
+
**kwargs: Additional keyword arguments passed to the parent class.
|
|
281
280
|
"""
|
|
282
281
|
super().__init__(**kwargs)
|
|
283
282
|
|
|
@@ -367,7 +366,12 @@ class RandomlyConnectedDense(tf.keras.layers.Dense):
|
|
|
367
366
|
Returns:
|
|
368
367
|
The processed inputs.
|
|
369
368
|
"""
|
|
370
|
-
if
|
|
369
|
+
# Apply kernel masking if needed (Keras 3.x compatibility check)
|
|
370
|
+
if (
|
|
371
|
+
self.density < 1.0
|
|
372
|
+
and hasattr(self, "kernel_mask")
|
|
373
|
+
and self.kernel_mask is not None
|
|
374
|
+
):
|
|
371
375
|
# Set fraction of the `kernel` weights to zero according to precomputed mask
|
|
372
376
|
self.kernel.assign(self.kernel * self.kernel_mask)
|
|
373
377
|
return super().call(inputs)
|
|
@@ -724,6 +728,7 @@ class DotProductLoss(tf.keras.layers.Layer):
|
|
|
724
728
|
Currently, the only possible value is `SOFTMAX`.
|
|
725
729
|
similarity_type: Similarity measure to use, either `cosine` or `inner`.
|
|
726
730
|
name: Optional name of the layer.
|
|
731
|
+
**kwargs: Additional keyword arguments passed to the parent class.
|
|
727
732
|
|
|
728
733
|
Raises:
|
|
729
734
|
TFLayerConfigException: When `similarity_type` is not one of `COSINE` or
|
|
@@ -883,6 +888,7 @@ class SingleLabelDotProductLoss(DotProductLoss):
|
|
|
883
888
|
values are approximately bounded.
|
|
884
889
|
model_confidence: Normalization of confidence values during inference.
|
|
885
890
|
Currently, the only possible value is `SOFTMAX`.
|
|
891
|
+
**kwargs: Additional keyword arguments passed to the parent class.
|
|
886
892
|
"""
|
|
887
893
|
super().__init__(
|
|
888
894
|
num_candidates,
|
|
@@ -1244,6 +1250,7 @@ class MultiLabelDotProductLoss(DotProductLoss):
|
|
|
1244
1250
|
Used inside _loss_cross_entropy() only.
|
|
1245
1251
|
model_confidence: Normalization of confidence values during inference.
|
|
1246
1252
|
Currently, the only possible value is `SOFTMAX`.
|
|
1253
|
+
**kwargs: Additional keyword arguments passed to the parent class.
|
|
1247
1254
|
"""
|
|
1248
1255
|
super().__init__(
|
|
1249
1256
|
num_candidates,
|
rasa/utils/tensorflow/metrics.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from typing import Any, Dict, Optional
|
|
2
2
|
|
|
3
3
|
import tensorflow as tf
|
|
4
|
-
from tensorflow.keras import backend as K
|
|
5
4
|
from tensorflow.types.experimental import TensorLike
|
|
6
5
|
|
|
7
6
|
# original code taken from
|
|
@@ -118,7 +117,7 @@ class FBetaScore(tf.keras.metrics.Metric):
|
|
|
118
117
|
|
|
119
118
|
def _zero_wt_init(name: Any) -> Any:
|
|
120
119
|
return self.add_weight(
|
|
121
|
-
name, shape=self.init_shape, initializer="zeros", dtype=self.dtype
|
|
120
|
+
name=name, shape=self.init_shape, initializer="zeros", dtype=self.dtype
|
|
122
121
|
)
|
|
123
122
|
|
|
124
123
|
self.true_positives = _zero_wt_init("true_positives")
|
|
@@ -197,7 +196,12 @@ class FBetaScore(tf.keras.metrics.Metric):
|
|
|
197
196
|
|
|
198
197
|
def reset_state(self) -> None:
|
|
199
198
|
reset_value = tf.zeros(self.init_shape, dtype=self.dtype)
|
|
200
|
-
|
|
199
|
+
# In Keras 3.x, self.variables contains string names, not variable objects,
|
|
200
|
+
# so each metric variable is reset using assign() instead of K.batch_set_value()
|
|
201
|
+
self.true_positives.assign(reset_value)
|
|
202
|
+
self.false_positives.assign(reset_value)
|
|
203
|
+
self.false_negatives.assign(reset_value)
|
|
204
|
+
self.weights_intermediate.assign(reset_value)
|
|
201
205
|
|
|
202
206
|
def reset_states(self) -> None:
|
|
203
207
|
# Backwards compatibility alias of `reset_state`. New classes should
|
rasa/utils/tensorflow/models.py
CHANGED
|
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Tuple, Union
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tensorflow as tf
|
|
10
10
|
from keras import Model
|
|
11
|
-
from keras.src.utils import tf_utils
|
|
12
11
|
|
|
13
12
|
import rasa.utils.train_utils
|
|
14
13
|
from rasa.shared.constants import DIAGNOSTIC_DATA
|
|
@@ -78,6 +77,7 @@ class RasaModel(Model):
|
|
|
78
77
|
|
|
79
78
|
Args:
|
|
80
79
|
random_seed: set the random seed to get reproducible results
|
|
80
|
+
**kwargs: Additional keyword arguments passed to the parent class
|
|
81
81
|
"""
|
|
82
82
|
# make sure that keras releases resources from previously trained model
|
|
83
83
|
tf.keras.backend.clear_session()
|
|
@@ -273,7 +273,8 @@ class RasaModel(Model):
|
|
|
273
273
|
if self._run_eagerly:
|
|
274
274
|
# Once we take advantage of TF's distributed training, this is where
|
|
275
275
|
# scheduled functions will be forced to execute and return actual values.
|
|
276
|
-
|
|
276
|
+
step_output = self.predict_step(batch_in)
|
|
277
|
+
outputs = self._convert_tensors_to_numpy(step_output)
|
|
277
278
|
if DIAGNOSTIC_DATA in outputs:
|
|
278
279
|
outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
|
|
279
280
|
outputs[DIAGNOSTIC_DATA]
|
|
@@ -287,9 +288,8 @@ class RasaModel(Model):
|
|
|
287
288
|
|
|
288
289
|
# Once we take advantage of TF's distributed training, this is where
|
|
289
290
|
# scheduled functions will be forced to execute and return actual values.
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
)
|
|
291
|
+
step_output = self._tf_predict_step(list(batch_in))
|
|
292
|
+
outputs = self._convert_tensors_to_numpy(step_output)
|
|
293
293
|
if DIAGNOSTIC_DATA in outputs:
|
|
294
294
|
outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
|
|
295
295
|
outputs[DIAGNOSTIC_DATA]
|
|
@@ -388,6 +388,23 @@ class RasaModel(Model):
|
|
|
388
388
|
|
|
389
389
|
return {k: _recurse(v) for k, v in input_dict.items()}
|
|
390
390
|
|
|
391
|
+
def _convert_tensors_to_numpy(
|
|
392
|
+
self, step_output: Dict[Text, Any]
|
|
393
|
+
) -> Dict[Text, Any]:
|
|
394
|
+
"""Convert TensorFlow tensors to numpy arrays for Keras 3.x compatibility.
|
|
395
|
+
|
|
396
|
+
Replaces the deprecated tf_utils.sync_to_numpy_or_python_type() function.
|
|
397
|
+
Converts tensors (objects with 'numpy' method) to numpy arrays,
|
|
398
|
+
leaves others unchanged.
|
|
399
|
+
"""
|
|
400
|
+
outputs = {}
|
|
401
|
+
for key, value in step_output.items():
|
|
402
|
+
if hasattr(value, "numpy"):
|
|
403
|
+
outputs[key] = value.numpy()
|
|
404
|
+
else:
|
|
405
|
+
outputs[key] = value
|
|
406
|
+
return outputs
|
|
407
|
+
|
|
391
408
|
def _get_metric_results(self, prefix: Optional[Text] = "") -> Dict[Text, float]:
|
|
392
409
|
return {
|
|
393
410
|
f"{prefix}{metric.name}": metric.result()
|
|
@@ -403,7 +420,18 @@ class RasaModel(Model):
|
|
|
403
420
|
overwrite: If 'True' an already existing model with the same file name will
|
|
404
421
|
be overwritten.
|
|
405
422
|
"""
|
|
406
|
-
|
|
423
|
+
# Ensure filename ends with .weights.h5 and model is built for Keras 3.x
|
|
424
|
+
# compatibility
|
|
425
|
+
model_file_name = str(model_file_name)
|
|
426
|
+
if not model_file_name.endswith(".weights.h5"):
|
|
427
|
+
model_file_name += ".weights.h5"
|
|
428
|
+
|
|
429
|
+
if not self.built:
|
|
430
|
+
import tensorflow as tf
|
|
431
|
+
|
|
432
|
+
_ = self(tf.zeros((1, 1)))
|
|
433
|
+
|
|
434
|
+
self.save_weights(model_file_name, overwrite=overwrite)
|
|
407
435
|
|
|
408
436
|
@classmethod
|
|
409
437
|
def load(
|
|
@@ -444,6 +472,13 @@ class RasaModel(Model):
|
|
|
444
472
|
)
|
|
445
473
|
data_generator = RasaBatchDataGenerator(model_data_example, batch_size=1)
|
|
446
474
|
model.fit(data_generator, verbose=False)
|
|
475
|
+
|
|
476
|
+
# Ensure model is built before loading weights (Keras 3 compatibility)
|
|
477
|
+
if not model.built:
|
|
478
|
+
# Build the model by calling it on sample data
|
|
479
|
+
sample_batch = next(iter(data_generator))
|
|
480
|
+
_ = model(sample_batch)
|
|
481
|
+
|
|
447
482
|
# load trained weights
|
|
448
483
|
model.load_weights(model_file_name)
|
|
449
484
|
|
|
@@ -301,12 +301,12 @@ class ConcatenateSparseDenseFeatures(RasaCustomLayer):
|
|
|
301
301
|
) -> tf.Tensor:
|
|
302
302
|
"""Turns sparse tensor into dense, possibly adds dropout before and/or after."""
|
|
303
303
|
if self.SPARSE_DROPOUT in self._tf_layers:
|
|
304
|
-
feature = self._tf_layers[self.SPARSE_DROPOUT](feature, training)
|
|
304
|
+
feature = self._tf_layers[self.SPARSE_DROPOUT](feature, training=training)
|
|
305
305
|
|
|
306
306
|
feature = self._tf_layers[self.SPARSE_TO_DENSE](feature)
|
|
307
307
|
|
|
308
308
|
if self.DENSE_DROPOUT in self._tf_layers:
|
|
309
|
-
feature = self._tf_layers[self.DENSE_DROPOUT](feature, training)
|
|
309
|
+
feature = self._tf_layers[self.DENSE_DROPOUT](feature, training=training)
|
|
310
310
|
|
|
311
311
|
return feature
|
|
312
312
|
|
|
@@ -1002,7 +1002,9 @@ class RasaSequenceLayer(RasaCustomLayer):
|
|
|
1002
1002
|
]((sequence_features, sentence_features, sequence_feature_lengths))
|
|
1003
1003
|
|
|
1004
1004
|
# Apply one or more dense layers.
|
|
1005
|
-
seq_sent_features = self._tf_layers[self.FFNN](
|
|
1005
|
+
seq_sent_features = self._tf_layers[self.FFNN](
|
|
1006
|
+
seq_sent_features, training=training
|
|
1007
|
+
)
|
|
1006
1008
|
|
|
1007
1009
|
# If using masked language modeling, mask the transformer inputs and get labels
|
|
1008
1010
|
# for the masked tokens and a boolean mask. Note that TED does not use MLM loss,
|
|
@@ -1031,7 +1033,7 @@ class RasaSequenceLayer(RasaCustomLayer):
|
|
|
1031
1033
|
if self._has_transformer:
|
|
1032
1034
|
mask_padding = 1 - mask_combined_sequence_sentence
|
|
1033
1035
|
outputs, attention_weights = self._tf_layers[self.TRANSFORMER](
|
|
1034
|
-
seq_sent_features_masked, mask_padding, training
|
|
1036
|
+
seq_sent_features_masked, mask_padding, training=training
|
|
1035
1037
|
)
|
|
1036
1038
|
outputs = tf.nn.gelu(outputs)
|
|
1037
1039
|
else:
|
|
@@ -2,10 +2,8 @@ from typing import Optional, Text, Tuple, Union
|
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import tensorflow as tf
|
|
5
|
-
|
|
6
|
-
# TODO: The following is not (yet) available via tf.keras
|
|
7
|
-
from keras.src.utils.control_flow_util import smart_cond
|
|
8
5
|
from tensorflow.keras import backend as K
|
|
6
|
+
from tensorflow.python.keras.utils.control_flow_util import smart_cond
|
|
9
7
|
|
|
10
8
|
from rasa.utils.tensorflow.exceptions import TFLayerConfigException
|
|
11
9
|
from rasa.utils.tensorflow.layers import RandomlyConnectedDense
|
|
@@ -280,6 +278,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
|
|
|
280
278
|
value: A tensor with shape (..., length, depth).
|
|
281
279
|
pad_mask: Float tensor with shape broadcastable
|
|
282
280
|
to (..., length, length). Defaults to None.
|
|
281
|
+
training: A tensor
|
|
283
282
|
|
|
284
283
|
Returns:
|
|
285
284
|
output: A tensor with shape (..., length, depth).
|