rasa-pro 3.13.12__py3-none-any.whl → 3.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. See the package's registry page for more details.
- rasa/__main__.py +15 -3
- rasa/agents/__init__.py +0 -0
- rasa/agents/agent_factory.py +122 -0
- rasa/agents/agent_manager.py +213 -0
- rasa/agents/constants.py +43 -0
- rasa/agents/core/__init__.py +0 -0
- rasa/agents/core/agent_protocol.py +107 -0
- rasa/agents/core/types.py +81 -0
- rasa/agents/exceptions.py +38 -0
- rasa/agents/protocol/__init__.py +5 -0
- rasa/agents/protocol/a2a/__init__.py +0 -0
- rasa/agents/protocol/a2a/a2a_agent.py +889 -0
- rasa/agents/protocol/mcp/__init__.py +0 -0
- rasa/agents/protocol/mcp/mcp_base_agent.py +778 -0
- rasa/agents/protocol/mcp/mcp_open_agent.py +327 -0
- rasa/agents/protocol/mcp/mcp_task_agent.py +522 -0
- rasa/agents/schemas/__init__.py +13 -0
- rasa/agents/schemas/agent_input.py +38 -0
- rasa/agents/schemas/agent_output.py +26 -0
- rasa/agents/schemas/agent_tool_result.py +65 -0
- rasa/agents/schemas/agent_tool_schema.py +186 -0
- rasa/agents/templates/__init__.py +0 -0
- rasa/agents/templates/mcp_open_agent_prompt_template.jinja2 +20 -0
- rasa/agents/templates/mcp_task_agent_prompt_template.jinja2 +22 -0
- rasa/agents/utils.py +228 -0
- rasa/agents/validation.py +538 -0
- rasa/api.py +23 -9
- rasa/builder/README.md +120 -0
- rasa/builder/__init__.py +0 -0
- rasa/builder/auth.py +176 -0
- rasa/builder/config.py +96 -0
- rasa/builder/copilot/__init__.py +0 -0
- rasa/builder/copilot/constants.py +38 -0
- rasa/builder/copilot/copilot.py +562 -0
- rasa/builder/copilot/copilot_response_handler.py +522 -0
- rasa/builder/copilot/copilot_templated_message_provider.py +81 -0
- rasa/builder/copilot/exceptions.py +32 -0
- rasa/builder/copilot/models.py +690 -0
- rasa/builder/copilot/prompts/__init__.py +0 -0
- rasa/builder/copilot/prompts/copilot_system_prompt.jinja2 +787 -0
- rasa/builder/copilot/prompts/copilot_training_error_handler_prompt.jinja2 +53 -0
- rasa/builder/copilot/prompts/latest_user_message_context_prompt.jinja2 +91 -0
- rasa/builder/copilot/signing.py +305 -0
- rasa/builder/copilot/telemetry.py +234 -0
- rasa/builder/copilot/templated_messages/__init__.py +0 -0
- rasa/builder/copilot/templated_messages/copilot_internal_messages_templates.yml +16 -0
- rasa/builder/copilot/templated_messages/copilot_templated_responses.yml +41 -0
- rasa/builder/copilot/templated_messages/copilot_welcome_messages.yml +56 -0
- rasa/builder/document_retrieval/__init__.py +0 -0
- rasa/builder/document_retrieval/constants.py +15 -0
- rasa/builder/document_retrieval/inkeep-rag-response-schema.json +64 -0
- rasa/builder/document_retrieval/inkeep_document_retrieval.py +238 -0
- rasa/builder/document_retrieval/models.py +62 -0
- rasa/builder/download.py +140 -0
- rasa/builder/exceptions.py +91 -0
- rasa/builder/guardrails/__init__.py +1 -0
- rasa/builder/guardrails/clients.py +256 -0
- rasa/builder/guardrails/constants.py +12 -0
- rasa/builder/guardrails/exceptions.py +4 -0
- rasa/builder/guardrails/models.py +266 -0
- rasa/builder/guardrails/policy_checker.py +324 -0
- rasa/builder/guardrails/store.py +238 -0
- rasa/builder/guardrails/utils.py +94 -0
- rasa/builder/job_manager.py +87 -0
- rasa/builder/jobs.py +609 -0
- rasa/builder/llm_service.py +273 -0
- rasa/builder/logging_utils.py +265 -0
- rasa/builder/main.py +234 -0
- rasa/builder/models.py +229 -0
- rasa/builder/project_generator.py +463 -0
- rasa/builder/project_info.py +72 -0
- rasa/builder/service.py +1367 -0
- rasa/builder/shared/tracker_context.py +212 -0
- rasa/builder/skill_to_bot_prompt.jinja2 +164 -0
- rasa/builder/template_cache.py +69 -0
- rasa/builder/training_service.py +188 -0
- rasa/builder/validation_service.py +101 -0
- rasa/cli/arguments/data.py +9 -0
- rasa/cli/arguments/default_arguments.py +12 -0
- rasa/cli/arguments/run.py +2 -0
- rasa/cli/arguments/train.py +2 -0
- rasa/cli/data.py +78 -10
- rasa/cli/dialogue_understanding_test.py +11 -7
- rasa/cli/e2e_test.py +10 -6
- rasa/cli/evaluate.py +4 -2
- rasa/cli/export.py +5 -2
- rasa/cli/inspect.py +9 -4
- rasa/cli/interactive.py +8 -4
- rasa/cli/llm_fine_tuning.py +12 -6
- rasa/cli/project_templates/basic/README.md +23 -0
- rasa/cli/project_templates/basic/actions/__init__ +0 -0
- rasa/cli/project_templates/basic/actions/action_human_handoff.py +40 -0
- rasa/cli/project_templates/basic/actions/actions.md +10 -0
- rasa/cli/project_templates/basic/config.yml +29 -0
- rasa/cli/project_templates/basic/credentials.yml +33 -0
- rasa/cli/project_templates/basic/data/data.md +8 -0
- rasa/cli/project_templates/basic/data/general/feedback.yml +21 -0
- rasa/cli/project_templates/basic/data/general/goodbye.yml +6 -0
- rasa/cli/project_templates/basic/data/general/hello.yml +6 -0
- rasa/cli/project_templates/basic/data/general/help.yml +6 -0
- rasa/cli/project_templates/basic/data/general/human_handoff.yml +16 -0
- rasa/cli/project_templates/basic/data/general/show_faqs.yml +6 -0
- rasa/cli/project_templates/basic/data/system/patterns/pattern_cannot_handle.yml +7 -0
- rasa/cli/project_templates/basic/data/system/patterns/pattern_completed.yml +7 -0
- rasa/cli/project_templates/basic/data/system/patterns/pattern_correction.yml +7 -0
- rasa/cli/project_templates/basic/data/system/patterns/pattern_search.yml +8 -0
- rasa/cli/project_templates/basic/data/system/patterns/pattern_session_start.yml +8 -0
- rasa/cli/project_templates/basic/docs/docs.md +5 -0
- rasa/cli/project_templates/basic/docs/template.txt +28 -0
- rasa/cli/project_templates/basic/domain/domain.md +11 -0
- rasa/cli/project_templates/basic/domain/general/feedback.yml +25 -0
- rasa/cli/project_templates/basic/domain/general/goodbye.yml +9 -0
- rasa/cli/project_templates/basic/domain/general/hello.yml +7 -0
- rasa/cli/project_templates/basic/domain/general/help.yml +21 -0
- rasa/cli/project_templates/basic/domain/general/human_handoff.yml +32 -0
- rasa/cli/project_templates/basic/domain/general/show_faqs.yml +14 -0
- rasa/cli/project_templates/basic/domain/system/patterns/pattern_cannot_handle.yml +5 -0
- rasa/cli/project_templates/basic/domain/system/patterns/pattern_session_start.yml +19 -0
- rasa/cli/project_templates/basic/endpoints.yml +67 -0
- rasa/cli/project_templates/basic/prompts/rephraser_demo_personality_prompt.jinja2 +38 -0
- rasa/cli/project_templates/basic/tests/e2e_test_cases/without_stub/general/feedback.yml +46 -0
- rasa/cli/project_templates/basic/tests/e2e_test_cases/without_stub/general/goodbye.yml +9 -0
- rasa/cli/project_templates/basic/tests/e2e_test_cases/without_stub/general/hello.yml +8 -0
- rasa/cli/project_templates/basic/tests/e2e_test_cases/without_stub/general/help.yml +8 -0
- rasa/cli/project_templates/basic/tests/e2e_test_cases/without_stub/general/human_handoff.yml +41 -0
- rasa/cli/project_templates/basic/tests/e2e_test_cases/without_stub/general/patterns.yml +32 -0
- rasa/cli/project_templates/basic/tests/e2e_test_cases/without_stub/general/show_faqs.yml +8 -0
- rasa/cli/project_templates/default/config.yml +4 -0
- rasa/cli/project_templates/default/endpoints.yml +4 -0
- rasa/cli/project_templates/defaults.py +1 -0
- rasa/cli/project_templates/finance/README.md +26 -0
- rasa/cli/project_templates/finance/actions/__init__.py +0 -0
- rasa/cli/project_templates/finance/actions/accounts/__init__.py +0 -0
- rasa/cli/project_templates/finance/actions/accounts/check_balance.py +18 -0
- rasa/cli/project_templates/finance/actions/actions.md +15 -0
- rasa/cli/project_templates/finance/actions/cards/__init__.py +0 -0
- rasa/cli/project_templates/finance/actions/cards/check_that_card_exists.py +21 -0
- rasa/cli/project_templates/finance/actions/cards/list_cards.py +22 -0
- rasa/cli/project_templates/finance/actions/contacts/__init__.py +0 -0
- rasa/cli/project_templates/finance/actions/contacts/add_contact.py +30 -0
- rasa/cli/project_templates/finance/actions/contacts/list_contacts.py +22 -0
- rasa/cli/project_templates/finance/actions/contacts/remove_contact.py +35 -0
- rasa/cli/project_templates/finance/actions/db.py +117 -0
- rasa/cli/project_templates/finance/actions/general/__init__.py +0 -0
- rasa/cli/project_templates/finance/actions/general/action_human_handoff.py +49 -0
- rasa/cli/project_templates/finance/actions/transfers/__init__.py +0 -0
- rasa/cli/project_templates/finance/actions/transfers/check_transfer_funds.py +27 -0
- rasa/cli/project_templates/finance/actions/transfers/check_transfer_limit.py +36 -0
- rasa/cli/project_templates/finance/actions/transfers/execute_recurrent_payment.py +20 -0
- rasa/cli/project_templates/finance/actions/transfers/execute_transfer.py +45 -0
- rasa/cli/project_templates/finance/actions/transfers/list_transactions.py +32 -0
- rasa/cli/project_templates/finance/config.yml +29 -0
- rasa/cli/project_templates/finance/credentials.yml +33 -0
- rasa/cli/project_templates/finance/data/accounts/check_balance.yml +9 -0
- rasa/cli/project_templates/finance/data/accounts/download_statements.yml +26 -0
- rasa/cli/project_templates/finance/data/bills/bill_pay_reminder.yml +25 -0
- rasa/cli/project_templates/finance/data/cards/activate_card.yml +35 -0
- rasa/cli/project_templates/finance/data/cards/block_card.yml +45 -0
- rasa/cli/project_templates/finance/data/cards/list_cards.yml +14 -0
- rasa/cli/project_templates/finance/data/cards/replace_card.yml +16 -0
- rasa/cli/project_templates/finance/data/cards/replace_eligible_card.yml +29 -0
- rasa/cli/project_templates/finance/data/contacts/add_contact.yml +33 -0
- rasa/cli/project_templates/finance/data/contacts/list_contacts.yml +14 -0
- rasa/cli/project_templates/finance/data/contacts/remove_contact.yml +31 -0
- rasa/cli/project_templates/finance/data/data.md +14 -0
- rasa/cli/project_templates/finance/data/general/bot_challenge.yml +6 -0
- rasa/cli/project_templates/finance/data/general/feedback.yml +20 -0
- rasa/cli/project_templates/finance/data/general/goodbye.yml +6 -0
- rasa/cli/project_templates/finance/data/general/hello.yml +6 -0
- rasa/cli/project_templates/finance/data/general/help.yml +9 -0
- rasa/cli/project_templates/finance/data/general/human_handoff.yml +16 -0
- rasa/cli/project_templates/finance/data/general/welcome.yml +9 -0
- rasa/cli/project_templates/finance/data/system/patterns/pattern_completed.yml +7 -0
- rasa/cli/project_templates/finance/data/system/patterns/pattern_correction.yml +7 -0
- rasa/cli/project_templates/finance/data/system/patterns/pattern_search.yml +8 -0
- rasa/cli/project_templates/finance/data/system/patterns/pattern_session_start.yml +8 -0
- rasa/cli/project_templates/finance/data/transfers/check_transfer_limit.yml +18 -0
- rasa/cli/project_templates/finance/data/transfers/list_transactions.yml +46 -0
- rasa/cli/project_templates/finance/data/transfers/move_money_between_accounts.yml +51 -0
- rasa/cli/project_templates/finance/data/transfers/transfer_money.yml +34 -0
- rasa/cli/project_templates/finance/data/transfers/transfer_money_to_a_third_party.yml +175 -0
- rasa/cli/project_templates/finance/db/cards.json +18 -0
- rasa/cli/project_templates/finance/db/contacts.json +10 -0
- rasa/cli/project_templates/finance/db/my_account.json +6 -0
- rasa/cli/project_templates/finance/db/transactions.json +22 -0
- rasa/cli/project_templates/finance/docs/docs.md +8 -0
- rasa/cli/project_templates/finance/docs/fenlo_banking_faq/account_features/budgeting_analytics.txt +22 -0
- rasa/cli/project_templates/finance/docs/fenlo_banking_faq/account_features/multi_currency_accounts.txt +19 -0
- rasa/cli/project_templates/finance/docs/fenlo_banking_faq/account_features/premium_benefits.txt +19 -0
- rasa/cli/project_templates/finance/docs/fenlo_banking_faq/card_management/contactless_limits.txt +16 -0
- rasa/cli/project_templates/finance/docs/fenlo_banking_faq/card_management/freeze_unfreeze_card.txt +16 -0
- rasa/cli/project_templates/finance/docs/fenlo_banking_faq/card_management/lost_stolen_card.txt +19 -0
- rasa/cli/project_templates/finance/docs/fenlo_banking_faq/money_transfers/instant_payments.txt +19 -0
- rasa/cli/project_templates/finance/docs/fenlo_banking_faq/money_transfers/international_transfers.txt +19 -0
- rasa/cli/project_templates/finance/docs/fenlo_banking_faq/security_fraud/fraud_protection.txt +22 -0
- rasa/cli/project_templates/finance/docs/fenlo_banking_faq/security_fraud/secure_payments.txt +22 -0
- rasa/cli/project_templates/finance/domain/accounts/check_balance.yml +15 -0
- rasa/cli/project_templates/finance/domain/accounts/download_statements.yml +40 -0
- rasa/cli/project_templates/finance/domain/bills/bill_pay_reminder.yml +49 -0
- rasa/cli/project_templates/finance/domain/cards/activate_card.yml +24 -0
- rasa/cli/project_templates/finance/domain/cards/block_card.yml +44 -0
- rasa/cli/project_templates/finance/domain/cards/list_cards.yml +16 -0
- rasa/cli/project_templates/finance/domain/cards/replace_card.yml +43 -0
- rasa/cli/project_templates/finance/domain/cards/shared.yml +15 -0
- rasa/cli/project_templates/finance/domain/contacts/add_contact.yml +37 -0
- rasa/cli/project_templates/finance/domain/contacts/list_contacts.yml +16 -0
- rasa/cli/project_templates/finance/domain/contacts/remove_contact.yml +32 -0
- rasa/cli/project_templates/finance/domain/domain.md +18 -0
- rasa/cli/project_templates/finance/domain/general/_shared.yml +39 -0
- rasa/cli/project_templates/finance/domain/general/bot_challenge.yml +4 -0
- rasa/cli/project_templates/finance/domain/general/cannot_handle.yml +8 -0
- rasa/cli/project_templates/finance/domain/general/feedback.yml +25 -0
- rasa/cli/project_templates/finance/domain/general/goodbye.yml +7 -0
- rasa/cli/project_templates/finance/domain/general/help.yml +0 -0
- rasa/cli/project_templates/finance/domain/general/human_handoff.yml +31 -0
- rasa/cli/project_templates/finance/domain/general/welcome.yml +39 -0
- rasa/cli/project_templates/finance/domain/transfers/check_transfer_limit.yml +32 -0
- rasa/cli/project_templates/finance/domain/transfers/list_transactions.yml +44 -0
- rasa/cli/project_templates/finance/domain/transfers/shared.yml +17 -0
- rasa/cli/project_templates/finance/domain/transfers/transfer_money.yml +221 -0
- rasa/cli/project_templates/finance/endpoints.yml +67 -0
- rasa/cli/project_templates/finance/prompts/rephraser_demo_personality_prompt.jinja2 +38 -0
- rasa/cli/project_templates/finance/tests/e2e_test_cases/without_stub/accounts/check_balance.yml +9 -0
- rasa/cli/project_templates/finance/tests/e2e_test_cases/without_stub/accounts/download_statements.yml +43 -0
- rasa/cli/project_templates/finance/tests/e2e_test_cases/without_stub/cards/block_card.yml +55 -0
- rasa/cli/project_templates/finance/tests/e2e_test_cases/without_stub/general/bot_challenge.yml +8 -0
- rasa/cli/project_templates/finance/tests/e2e_test_cases/without_stub/general/feedback.yml +46 -0
- rasa/cli/project_templates/finance/tests/e2e_test_cases/without_stub/general/goodbye.yml +9 -0
- rasa/cli/project_templates/finance/tests/e2e_test_cases/without_stub/general/hello.yml +8 -0
- rasa/cli/project_templates/finance/tests/e2e_test_cases/without_stub/general/human_handoff.yml +35 -0
- rasa/cli/project_templates/finance/tests/e2e_test_cases/without_stub/general/patterns.yml +22 -0
- rasa/cli/project_templates/finance/tests/e2e_test_cases/without_stub/transfers/transfer_money.yml +56 -0
- rasa/cli/project_templates/telco/README.md +25 -0
- rasa/cli/project_templates/telco/actions/__init__.py +0 -0
- rasa/cli/project_templates/telco/actions/actions.md +12 -0
- rasa/cli/project_templates/telco/actions/billing/__init__.py +0 -0
- rasa/cli/project_templates/telco/actions/billing/actions_billing.py +204 -0
- rasa/cli/project_templates/telco/actions/general/__init__.py +0 -0
- rasa/cli/project_templates/telco/actions/general/action_human_handoff.py +49 -0
- rasa/cli/project_templates/telco/actions/network/__init__.py +0 -0
- rasa/cli/project_templates/telco/actions/network/actions_get_data_from_db.py +48 -0
- rasa/cli/project_templates/telco/actions/network/actions_run_diagnostics.py +28 -0
- rasa/cli/project_templates/telco/actions/network/actions_session_start.py +18 -0
- rasa/cli/project_templates/telco/config.yml +29 -0
- rasa/cli/project_templates/telco/credentials.yml +33 -0
- rasa/cli/project_templates/telco/csvs/billing.csv +19 -0
- rasa/cli/project_templates/telco/csvs/customers.csv +5 -0
- rasa/cli/project_templates/telco/data/billing/flow_understand_bill.yml +45 -0
- rasa/cli/project_templates/telco/data/data.md +11 -0
- rasa/cli/project_templates/telco/data/general/bot_challenge.yml +6 -0
- rasa/cli/project_templates/telco/data/general/feedback.yml +20 -0
- rasa/cli/project_templates/telco/data/general/goodbye.yml +6 -0
- rasa/cli/project_templates/telco/data/general/hello.yml +6 -0
- rasa/cli/project_templates/telco/data/general/human_handoff.yml +16 -0
- rasa/cli/project_templates/telco/data/general/patterns.yml +30 -0
- rasa/cli/project_templates/telco/data/network/flow_reboot_router.yml +8 -0
- rasa/cli/project_templates/telco/data/network/flow_reset_router.yml +7 -0
- rasa/cli/project_templates/telco/data/network/flow_solve_internet_issue.yml +73 -0
- rasa/cli/project_templates/telco/docs/docs.md +8 -0
- rasa/cli/project_templates/telco/docs/network/reset_vs_rboot_router.txt +1 -0
- rasa/cli/project_templates/telco/docs/network/restart_router.txt +6 -0
- rasa/cli/project_templates/telco/docs/network/run_speed_test.txt +6 -0
- rasa/cli/project_templates/telco/domain/billing/understand_bill.yml +102 -0
- rasa/cli/project_templates/telco/domain/domain.md +13 -0
- rasa/cli/project_templates/telco/domain/general/bot_challenge.yml +4 -0
- rasa/cli/project_templates/telco/domain/general/feedback.yml +25 -0
- rasa/cli/project_templates/telco/domain/general/goodbye.yml +7 -0
- rasa/cli/project_templates/telco/domain/general/hello.yml +5 -0
- rasa/cli/project_templates/telco/domain/general/human_handoff.yml +26 -0
- rasa/cli/project_templates/telco/domain/general/patterns.yml +33 -0
- rasa/cli/project_templates/telco/domain/network/reboot_router.yml +21 -0
- rasa/cli/project_templates/telco/domain/network/reset_router.yml +12 -0
- rasa/cli/project_templates/telco/domain/network/run_speed_test.yml +25 -0
- rasa/cli/project_templates/telco/domain/network/solve_internet_issue.yml +74 -0
- rasa/cli/project_templates/telco/domain/shared.yml +129 -0
- rasa/cli/project_templates/telco/endpoints.yml +67 -0
- rasa/cli/project_templates/telco/prompts/rephraser_demo_personality_prompt.jinja2 +40 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/with_stub/network/solve_internet_not_slow.yml +33 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/with_stub/network/solve_internet_slow.yml +47 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/without_stub/billing/understand_bill.yml +67 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/without_stub/general/bot_challenge.yml +8 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/without_stub/general/feedback.yml +46 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/without_stub/general/goodbye.yml +9 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/without_stub/general/hello.yml +8 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/without_stub/general/human_handoff.yml +35 -0
- rasa/cli/project_templates/telco/tests/e2e_test_cases/without_stub/general/patterns.yml +23 -0
- rasa/cli/project_templates/tutorial/config.yml +2 -1
- rasa/cli/project_templates/tutorial/credentials.yml +10 -0
- rasa/cli/run.py +8 -10
- rasa/cli/scaffold.py +50 -6
- rasa/cli/shell.py +10 -5
- rasa/cli/studio/studio.py +1 -1
- rasa/cli/test.py +34 -14
- rasa/cli/train.py +44 -30
- rasa/cli/utils.py +1 -393
- rasa/cli/validation/__init__.py +0 -0
- rasa/cli/validation/bot_config.py +232 -0
- rasa/cli/validation/config_path_validation.py +257 -0
- rasa/cli/x.py +8 -4
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +53 -13
- rasa/core/actions/action_exceptions.py +1 -1
- rasa/core/actions/action_run_slot_rejections.py +1 -1
- rasa/core/actions/grpc_custom_action_executor.py +1 -1
- rasa/core/agent.py +22 -2
- rasa/core/available_agents.py +239 -0
- rasa/core/brokers/broker.py +1 -1
- rasa/core/brokers/kafka.py +56 -8
- rasa/core/channels/__init__.py +82 -35
- rasa/core/channels/channel.py +4 -3
- rasa/core/channels/constants.py +3 -0
- rasa/core/channels/development_inspector.py +29 -16
- rasa/core/channels/hangouts.py +2 -2
- rasa/core/channels/inspector/README.md +25 -13
- rasa/core/channels/inspector/dist/assets/{arc-0b11fe30.js → arc-6177260a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-9eef30a7.js → blockDiagram-38ab4fdb-b054f038.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-03e94f28.js → c4Diagram-3d4e48cf-f25427d5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/channel-bf9cbb34.js +1 -0
- rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-95c09eba.js → classDiagram-70f12bd4-c7a2af53.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-38e8446c.js → classDiagram-v2-f2320105-58db65c0.js} +1 -1
- rasa/core/channels/inspector/dist/assets/clone-8f9083bb.js +1 -0
- rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-57dc3038.js → createText-2e5e7dd3-088372e2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-4bac0545.js → edges-e0da2a9e-58676240.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-81795c90.js → erDiagram-9861fffd-0c14d7c6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-89489ae6.js → flowDb-956e92f1-ea63f85c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-cd152627.js → flowDiagram-66a62f08-a2af48cd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-9ecd5b59.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-3da369bc.js → flowchart-elk-definition-4a651766-6937abe7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-85ec16f8.js → ganttDiagram-c361ad54-7473f357.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-495bc140.js → gitGraphDiagram-72cf32ee-d0c9405e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{graph-1ec4d266.js → graph-0a6f8466.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-3862675e-0a0e97c9.js → index-3862675e-7610671a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-74e01d94.js +1354 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-4d54bcde.js → infoDiagram-f8f76790-be397dc7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-dc097114.js → journeyDiagram-49397b02-4cefbf62.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-1a08981e.js → layout-e7fbc2bf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-95f7f1d3.js → line-a8aa457c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-97e69543.js → linear-3351e0d2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-8c71ff03.js → mindmap-definition-fc14e90a-b8cbf605.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-f14c71c7.js → pieDiagram-8a3498a8-f327f774.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-f1d3c9ff.js → quadrantDiagram-120e2f19-2854c591.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-bfa2412f.js → requirementDiagram-deff3bca-964985d5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-53f2c97b.js → sankeyDiagram-04a897e0-edeb4f33.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-319d7c0e.js → sequenceDiagram-704730f1-fcf70125.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-76a09418.js → stateDiagram-587899a1-0e770395.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-a67f15d4.js → stateDiagram-v2-d93cdb3a-af8dcd22.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-0654e7c3.js → styles-6aaf32cf-36a9e70d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9a916d00-1394bb9d.js → styles-9a916d00-884a8b5b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-c10674c1-e4c5bdae.js → styles-c10674c1-dc097813.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-50957104.js → svgDrawCommon-08f97a94-5a2c7eed.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-b0885a6a.js → timeline-definition-85554ec2-e89c4f6e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-79e6541a.js → xychartDiagram-e933f94c-afb6fe56.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/package.json +18 -18
- rasa/core/channels/inspector/src/App.tsx +56 -12
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +1 -1
- rasa/core/channels/inspector/src/components/DialogueAgentStack.tsx +108 -0
- rasa/core/channels/inspector/src/components/{DialogueStack.tsx → DialogueHistoryStack.tsx} +4 -2
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +20 -3
- rasa/core/channels/inspector/src/components/LatencyDisplay.tsx +296 -0
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +6 -2
- rasa/core/channels/inspector/src/helpers/audio/audiostream.ts +26 -4
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +24 -3
- rasa/core/channels/inspector/src/helpers/utils.test.ts +127 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +66 -1
- rasa/core/channels/inspector/src/theme/base/styles.ts +19 -1
- rasa/core/channels/inspector/src/types.ts +55 -1
- rasa/core/channels/inspector/yarn.lock +336 -189
- rasa/core/channels/socketio.py +212 -51
- rasa/core/channels/studio_chat.py +82 -32
- rasa/core/channels/telegram.py +4 -9
- rasa/core/channels/voice_ready/twilio_voice.py +1 -1
- rasa/core/channels/voice_stream/asr/asr_event.py +1 -1
- rasa/core/channels/voice_stream/asr/azure.py +6 -3
- rasa/core/channels/voice_stream/asr/deepgram.py +1 -1
- rasa/core/channels/voice_stream/audiocodes.py +11 -6
- rasa/core/channels/voice_stream/browser_audio.py +91 -4
- rasa/core/channels/voice_stream/call_state.py +13 -2
- rasa/core/channels/voice_stream/genesys.py +19 -15
- rasa/core/channels/voice_stream/jambonz.py +22 -12
- rasa/core/channels/voice_stream/tts/deepgram.py +140 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +35 -14
- rasa/core/channels/voice_stream/util.py +11 -1
- rasa/core/channels/voice_stream/voice_channel.py +170 -32
- rasa/core/concurrent_lock_store.py +83 -16
- rasa/core/config/__init__.py +0 -0
- rasa/core/{available_endpoints.py → config/available_endpoints.py} +56 -18
- rasa/core/config/configuration.py +295 -0
- rasa/core/config/credentials.py +19 -0
- rasa/core/config/message_procesing_config.py +34 -0
- rasa/core/constants.py +17 -0
- rasa/core/exceptions.py +1 -1
- rasa/core/featurizers/tracker_featurizers.py +3 -2
- rasa/core/iam_credentials_providers/__init__.py +0 -0
- rasa/core/iam_credentials_providers/aws_iam_credentials_providers.py +291 -0
- rasa/core/iam_credentials_providers/credentials_provider_protocol.py +91 -0
- rasa/core/lock_store.py +50 -10
- rasa/core/nlg/contextual_response_rephraser.py +5 -0
- rasa/core/nlg/generator.py +1 -1
- rasa/core/persistor.py +7 -7
- rasa/core/policies/enterprise_search_policy.py +9 -10
- rasa/core/policies/flow_policy.py +4 -4
- rasa/core/policies/flows/agent_executor.py +720 -0
- rasa/core/policies/flows/flow_exceptions.py +5 -2
- rasa/core/policies/flows/flow_executor.py +146 -77
- rasa/core/policies/flows/mcp_tool_executor.py +304 -0
- rasa/core/policies/intentless_policy.py +1 -1
- rasa/core/policies/rule_policy.py +1 -1
- rasa/core/policies/ted_policy.py +20 -12
- rasa/core/policies/unexpected_intent_policy.py +6 -0
- rasa/core/processor.py +100 -44
- rasa/core/redis_connection_factory.py +474 -0
- rasa/core/run.py +49 -10
- rasa/core/test.py +4 -0
- rasa/core/tracker_stores/redis_tracker_store.py +36 -14
- rasa/core/tracker_stores/sql_tracker_store.py +59 -1
- rasa/core/tracker_stores/tracker_store.py +3 -7
- rasa/core/train.py +1 -1
- rasa/core/training/interactive.py +20 -18
- rasa/core/training/story_conflict.py +5 -5
- rasa/core/utils.py +22 -23
- rasa/dialogue_understanding/commands/__init__.py +8 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +20 -6
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +21 -2
- rasa/dialogue_understanding/commands/clarify_command.py +20 -2
- rasa/dialogue_understanding/commands/continue_agent_command.py +91 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +21 -2
- rasa/dialogue_understanding/commands/restart_agent_command.py +162 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +75 -7
- rasa/dialogue_understanding/commands/utils.py +135 -2
- rasa/dialogue_understanding/generator/command_parser.py +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +0 -9
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +52 -12
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +66 -0
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +66 -0
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_claude_3_5_sonnet_20240620_template.jinja2 +89 -0
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_gpt_4o_2024_11_20_template.jinja2 +88 -0
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +42 -7
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +40 -3
- rasa/dialogue_understanding/generator/single_step/single_step_based_llm_command_generator.py +20 -3
- rasa/dialogue_understanding/patterns/cancel.py +27 -6
- rasa/dialogue_understanding/patterns/clarify.py +3 -14
- rasa/dialogue_understanding/patterns/continue_interrupted.py +239 -6
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +49 -9
- rasa/dialogue_understanding/processor/command_processor.py +136 -15
- rasa/dialogue_understanding/stack/dialogue_stack.py +98 -2
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +57 -0
- rasa/dialogue_understanding/stack/utils.py +57 -3
- rasa/dialogue_understanding/utils.py +24 -4
- rasa/dialogue_understanding_test/du_test_runner.py +8 -3
- rasa/e2e_test/e2e_test_runner.py +13 -3
- rasa/engine/caching.py +2 -2
- rasa/engine/constants.py +1 -1
- rasa/engine/graph.py +5 -1
- rasa/engine/loader.py +12 -0
- rasa/engine/recipes/default_components.py +138 -49
- rasa/engine/recipes/default_recipe.py +108 -11
- rasa/engine/runner/dask.py +8 -5
- rasa/engine/validation.py +25 -8
- rasa/graph_components/validators/default_recipe_validator.py +86 -28
- rasa/hooks.py +5 -5
- rasa/llm_fine_tuning/utils.py +2 -2
- rasa/model_manager/model_api.py +4 -5
- rasa/model_manager/runner_service.py +2 -2
- rasa/model_manager/socket_bridge.py +21 -17
- rasa/model_manager/trainer_service.py +12 -9
- rasa/model_manager/utils.py +1 -29
- rasa/model_manager/warm_rasa_process.py +13 -3
- rasa/model_training.py +60 -47
- rasa/nlu/classifiers/diet_classifier.py +198 -98
- rasa/nlu/classifiers/logistic_regression_classifier.py +1 -4
- rasa/nlu/classifiers/mitie_intent_classifier.py +3 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +1 -3
- rasa/nlu/extractors/crf_entity_extractor.py +9 -10
- rasa/nlu/extractors/mitie_entity_extractor.py +3 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +3 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +4 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +5 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +2 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +3 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +4 -2
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +4 -0
- rasa/nlu/selectors/response_selector.py +10 -2
- rasa/nlu/tokenizers/jieba_tokenizer.py +3 -4
- rasa/nlu/tokenizers/mitie_tokenizer.py +3 -2
- rasa/nlu/tokenizers/spacy_tokenizer.py +3 -2
- rasa/nlu/utils/mitie_utils.py +3 -0
- rasa/nlu/utils/spacy_utils.py +3 -2
- rasa/plugin.py +8 -8
- rasa/privacy/privacy_config.py +1 -1
- rasa/privacy/privacy_manager.py +12 -3
- rasa/server.py +15 -3
- rasa/shared/agents/__init__.py +0 -0
- rasa/shared/agents/auth/__init__.py +0 -0
- rasa/shared/agents/auth/agent_auth_factory.py +105 -0
- rasa/shared/agents/auth/agent_auth_manager.py +92 -0
- rasa/shared/agents/auth/auth_strategy/__init__.py +19 -0
- rasa/shared/agents/auth/auth_strategy/agent_auth_strategy.py +52 -0
- rasa/shared/agents/auth/auth_strategy/api_key_auth_strategy.py +42 -0
- rasa/shared/agents/auth/auth_strategy/bearer_token_auth_strategy.py +28 -0
- rasa/shared/agents/auth/auth_strategy/oauth2_auth_strategy.py +170 -0
- rasa/shared/agents/auth/constants.py +13 -0
- rasa/shared/agents/auth/types.py +12 -0
- rasa/shared/agents/auth/utils.py +85 -0
- rasa/shared/agents/utils.py +35 -0
- rasa/shared/constants.py +11 -0
- rasa/shared/core/constants.py +17 -1
- rasa/shared/core/domain.py +62 -22
- rasa/shared/core/events.py +329 -0
- rasa/shared/core/flows/constants.py +5 -0
- rasa/shared/core/flows/flow.py +1 -1
- rasa/shared/core/flows/flow_step.py +7 -1
- rasa/shared/core/flows/flows_list.py +21 -5
- rasa/shared/core/flows/flows_yaml_schema.json +119 -184
- rasa/shared/core/flows/steps/call.py +57 -6
- rasa/shared/core/flows/steps/collect.py +98 -13
- rasa/shared/core/flows/validation.py +372 -8
- rasa/shared/core/flows/yaml_flows_io.py +19 -10
- rasa/shared/core/slots.py +6 -2
- rasa/shared/core/trackers.py +5 -2
- rasa/shared/core/training_data/story_reader/story_reader.py +1 -1
- rasa/shared/exceptions.py +39 -2
- rasa/shared/importers/importer.py +6 -0
- rasa/shared/importers/rasa.py +1 -1
- rasa/shared/importers/utils.py +86 -4
- rasa/shared/nlu/training_data/schemas/responses.yml +3 -0
- rasa/shared/providers/llm/_base_litellm_client.py +41 -9
- rasa/shared/providers/llm/litellm_router_llm_client.py +10 -6
- rasa/shared/providers/llm/llm_client.py +7 -3
- rasa/shared/providers/llm/llm_response.py +66 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +8 -4
- rasa/shared/utils/common.py +26 -1
- rasa/shared/utils/health_check/health_check.py +7 -3
- rasa/shared/utils/llm.py +92 -19
- rasa/shared/utils/mcp/__init__.py +0 -0
- rasa/shared/utils/mcp/server_connection.py +250 -0
- rasa/shared/utils/mcp/utils.py +20 -0
- rasa/shared/utils/schemas/events.py +42 -0
- rasa/shared/utils/yaml.py +3 -1
- rasa/studio/download.py +3 -0
- rasa/studio/prompts.py +1 -0
- rasa/studio/pull/pull.py +3 -2
- rasa/studio/train.py +8 -7
- rasa/studio/upload.py +19 -52
- rasa/telemetry.py +166 -28
- rasa/tracing/config.py +45 -12
- rasa/tracing/constants.py +14 -0
- rasa/tracing/instrumentation/attribute_extractors.py +142 -9
- rasa/tracing/instrumentation/instrumentation.py +626 -21
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +4 -4
- rasa/tracing/instrumentation/metrics.py +32 -0
- rasa/tracing/metric_instrument_provider.py +68 -0
- rasa/utils/common.py +92 -1
- rasa/utils/endpoints.py +11 -2
- rasa/utils/io.py +27 -9
- rasa/utils/json_utils.py +6 -1
- rasa/utils/log_utils.py +121 -7
- rasa/utils/ml_utils.py +1 -1
- rasa/utils/openapi.py +144 -0
- rasa/utils/plotting.py +1 -1
- rasa/utils/pypred.py +45 -0
- rasa/utils/tensorflow/__init__.py +7 -0
- rasa/utils/tensorflow/callback.py +136 -101
- rasa/utils/tensorflow/crf.py +1 -1
- rasa/utils/tensorflow/data_generator.py +21 -8
- rasa/utils/tensorflow/layers.py +21 -11
- rasa/utils/tensorflow/metrics.py +7 -3
- rasa/utils/tensorflow/models.py +56 -8
- rasa/utils/tensorflow/rasa_layers.py +8 -6
- rasa/utils/tensorflow/transformer.py +2 -3
- rasa/utils/train_utils.py +54 -24
- rasa/validator.py +149 -16
- rasa/version.py +1 -1
- rasa_pro-3.14.0.dist-info/METADATA +212 -0
- {rasa_pro-3.13.12.dist-info → rasa_pro-3.14.0.dist-info}/RECORD +581 -269
- rasa/core/channels/inspector/dist/assets/channel-51d02e9e.js +0 -1
- rasa/core/channels/inspector/dist/assets/clone-cc738fa6.js +0 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-0c716443.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-c804b295.js +0 -1335
- rasa_pro-3.13.12.dist-info/METADATA +0 -192
- {rasa_pro-3.13.12.dist-info → rasa_pro-3.14.0.dist-info}/NOTICE +0 -0
- {rasa_pro-3.13.12.dist-info → rasa_pro-3.14.0.dist-info}/WHEEL +0 -0
- {rasa_pro-3.13.12.dist-info → rasa_pro-3.14.0.dist-info}/entry_points.txt +0 -0
rasa/utils/openapi.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Type
|
|
2
|
+
|
|
3
|
+
from pydantic.main import BaseModel
|
|
4
|
+
from sanic_openapi import openapi
|
|
5
|
+
from sanic_openapi.openapi3.types import Schema
|
|
6
|
+
|
|
7
|
+
# JSON-schema keywords copied onto primitive (string, number, ...) schemas.
_SUPPORTED_ATTRIBUTES = frozenset(["format", "enum", "required", "example"])


def _definition_name(ref: str) -> str:
    """Return the definition name referenced by a JSON-schema ``$ref``.

    For example ``"#/$defs/Foo"`` and ``"#/definitions/Foo"`` both give ``"Foo"``.
    """
    return ref.split("/")[-1]


def _schema_for_definition(
    definition_stack: List[str],
    definition: str,
    definitions: Dict[str, Any],
    placeholder: Schema,
) -> Schema:
    """Convert a named definition, guarding against unknown and cyclic references.

    Args:
        definition_stack: Names of the definitions currently being expanded on
            the path from the root schema to this point.
        definition: Name of the definition to convert.
        definitions: All known definitions, keyed by name.
        placeholder: Returned when ``definition`` is unknown, or when it is
            already on ``definition_stack`` (a model that references itself,
            which would otherwise recurse forever).

    Returns:
        The converted definition, or ``placeholder``.
    """
    definition_data = definitions.get(definition)
    if definition_data is None or definition in definition_stack:
        return placeholder
    return _to_schema(
        definition_stack=definition_stack + [definition],
        schema_def={**definition_data},
        definitions=definitions,
    )


def _to_schema(
    definition_stack: List[str], schema_def: Dict[str, Any], definitions: Dict[str, Any]
) -> Schema:
    """Convert a JSON-schema fragment into a sanic-openapi ``Schema``.

    Args:
        definition_stack: Names of the definitions currently being expanded on
            the path from the root; used to stop infinite recursion on
            self-referencing models.
        schema_def: The JSON-schema fragment to convert (e.g. the output of
            pydantic's ``model_json_schema`` with its definitions removed).
        definitions: All named definitions (``$defs`` / ``definitions``) that
            ``$ref`` entries may point to.

    Returns:
        The equivalent sanic-openapi ``Schema``.
    """
    schema_type = schema_def.get("type")
    description = schema_def.get("description")

    if schema_type == "object":
        properties = {
            key: _to_schema(
                definition_stack=definition_stack,
                schema_def=property_def,
                definitions=definitions,
            )
            for key, property_def in schema_def.get("properties", {}).items()
        }
        return openapi.Object(
            title=schema_def.get("title"),
            description=description,
            required=schema_def.get("required"),
            properties=properties,
        )

    if schema_type == "array":
        return openapi.Array(
            description=description,
            required=schema_def.get("required"),
            # An untyped list has no ``items``; describe its elements as "any".
            items=_to_schema(
                definition_stack=definition_stack,
                schema_def=schema_def.get("items") or {},
                definitions=definitions,
            ),
        )

    if schema_type is not None:
        # Primitive type (string, integer, number, boolean, null, ...).
        schema_spec = {"type": schema_type, "description": description}
        # Sorted for a deterministic kwarg order; compared against None so that
        # legitimate falsy values (e.g. ``example: 0``) are kept.
        for attribute in sorted(_SUPPORTED_ATTRIBUTES):
            if schema_def.get(attribute) is not None:
                schema_spec[attribute] = schema_def[attribute]
        return Schema(**schema_spec)

    if allof_spec := schema_def.get("allOf"):  # Model, Enum
        definition = _definition_name(allof_spec[0]["$ref"])
        return _schema_for_definition(
            definition_stack,
            definition,
            definitions,
            placeholder=openapi.Object(title=definition, description=description),
        )

    if anyof_spec := schema_def.get("anyOf"):  # Union
        options = []
        for option in anyof_spec:
            if option.get("type"):
                options.append(
                    Schema(type=option.get("type"), description=option.get("description"))
                )
            elif "$ref" in option:
                definition = _definition_name(option["$ref"])
                options.append(
                    _schema_for_definition(
                        definition_stack,
                        definition,
                        definitions,
                        placeholder=openapi.Object(
                            title=definition,
                            description=schema_def.get("description", definition),
                            properties={},
                        ),
                    )
                )
            else:
                # Nested combinator or unconstrained entry without a ``$ref``.
                options.append(
                    _to_schema(
                        definition_stack=definition_stack,
                        schema_def=option,
                        definitions=definitions,
                    )
                )
        return Schema(anyOf=options)

    if ref := schema_def.get("$ref"):
        definition = _definition_name(ref)
        # Push the definition onto the stack so recursive models terminate.
        return _schema_for_definition(
            definition_stack,
            definition,
            definitions,
            placeholder=openapi.Object(title=definition, description=description),
        )

    # No type and no combinator: any value is allowed.
    return Schema(type="object", description=description)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def model_to_schema(model: Type[BaseModel]) -> Schema:
    """Build a sanic-openapi ``Schema`` describing a pydantic model.

    Shared sub-models are emitted under ``$defs`` (newer JSON Schema) or
    ``definitions`` (older JSON Schema). They are split off the root schema
    and handed to the converter separately so that ``$ref`` entries can be
    resolved.

    Args:
        model: The pydantic model class to describe.

    Returns:
        The converted schema of the model.
    """
    json_schema = model.model_json_schema()
    definitions = json_schema.get("$defs") or json_schema.get("definitions") or {}
    root_schema = {
        key: value
        for key, value in json_schema.items()
        if key not in ("definitions", "$defs")
    }
    return _to_schema(
        definition_stack=[], schema_def=root_schema, definitions=definitions
    )
|
rasa/utils/plotting.py
CHANGED
|
@@ -99,7 +99,7 @@ def plot_confusion_matrix(
|
|
|
99
99
|
zmax = confusion_matrix.max() if len(confusion_matrix) > 0 else 1
|
|
100
100
|
plt.clf()
|
|
101
101
|
if not color_map:
|
|
102
|
-
color_map = plt.cm.Blues
|
|
102
|
+
color_map = plt.cm.get_cmap("Blues")
|
|
103
103
|
plt.imshow(
|
|
104
104
|
confusion_matrix,
|
|
105
105
|
interpolation="nearest",
|
rasa/utils/pypred.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Wrapper module for pypred that provides a fixed Predicate class.
|
|
2
|
+
|
|
3
|
+
This module should be used instead of importing directly from pypred.
|
|
4
|
+
|
|
5
|
+
This patch fixes an issue where pypred creates excessive logs of being unable
|
|
6
|
+
to write a file when run in an environment with no write access is given.
|
|
7
|
+
|
|
8
|
+
https://rasahq.atlassian.net/browse/ATO-1925
|
|
9
|
+
|
|
10
|
+
The solution is based on https://github.com/FreeCAD/FreeCAD/issues/6315
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
import ply.yacc
|
|
17
|
+
import pypred.parser
|
|
18
|
+
from pypred import Predicate as OriginalPredicate # noqa: TID251
|
|
19
|
+
|
|
20
|
+
# Store the original yacc function
|
|
21
|
+
_original_yacc = ply.yacc.yacc
|
|
22
|
+
|
|
23
|
+
# Create a logger that suppresses warnings to avoid yacc table file version warnings
|
|
24
|
+
_yacc_logger = logging.getLogger("ply.yacc")
|
|
25
|
+
_yacc_logger.setLevel(logging.ERROR)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def patched_yacc(*args: Any, **kwargs: Any) -> Any:
|
|
29
|
+
# Disable generation of debug ('parser.out') and table
|
|
30
|
+
# cache ('parsetab.py'), as it requires a writable location.
|
|
31
|
+
kwargs["write_tables"] = False
|
|
32
|
+
kwargs["module"] = pypred.parser
|
|
33
|
+
# Suppress yacc warnings by using a logger that only shows errors
|
|
34
|
+
kwargs["errorlog"] = _yacc_logger
|
|
35
|
+
return _original_yacc(*args, **kwargs)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Apply the patch
|
|
39
|
+
ply.yacc.yacc = patched_yacc
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Predicate(OriginalPredicate):
|
|
43
|
+
"""Fixed version of pypred.Predicate that uses the patched yacc parser."""
|
|
44
|
+
|
|
45
|
+
pass
|
|
@@ -2,111 +2,146 @@ import logging
|
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
from typing import Any, Dict, Optional, Text
|
|
4
4
|
|
|
5
|
-
|
|
6
|
-
|
|
5
|
+
from rasa.utils.tensorflow import TENSORFLOW_AVAILABLE
|
|
6
|
+
|
|
7
|
+
if TENSORFLOW_AVAILABLE:
|
|
8
|
+
import tensorflow as tf
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
else:
|
|
11
|
+
# Placeholder values when TensorFlow is not available
|
|
12
|
+
tf = None
|
|
13
|
+
tqdm = None
|
|
7
14
|
|
|
8
15
|
import rasa.shared.utils.io
|
|
9
16
|
|
|
10
17
|
logger = logging.getLogger(__name__)
|
|
11
18
|
|
|
12
19
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
"""
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
epoch:
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
20
|
+
if TENSORFLOW_AVAILABLE:

    class RasaTrainingLogger(tf.keras.callbacks.Callback):
        """Callback for logging the status of training."""

        def __init__(self, epochs: int, silent: bool) -> None:
            """Initializes the callback.

            Args:
                epochs: Total number of epochs.
                silent: If 'True' the entire progressbar wrapper is disabled.
            """
            super().__init__()

            disable = silent or rasa.shared.utils.io.is_logging_disabled()
            self.progress_bar = tqdm(range(epochs), desc="Epochs", disable=disable)

        def on_epoch_end(
            self, epoch: int, logs: Optional[Dict[Text, Any]] = None
        ) -> None:
            """Updates the logging output on every epoch end.

            Args:
                epoch: The current epoch.
                logs: The training metrics.
            """
            self.progress_bar.update(1)
            self.progress_bar.set_postfix(logs)

        def on_train_end(self, logs: Optional[Dict[Text, Any]] = None) -> None:
            """Closes the progress bar after training.

            Args:
                logs: The training metrics.
            """
            self.progress_bar.close()

    class RasaModelCheckpoint(tf.keras.callbacks.Callback):
        """Callback for saving intermediate model checkpoints."""

        def __init__(self, checkpoint_dir: Path) -> None:
            """Initializes the callback.

            Args:
                checkpoint_dir: Directory to store checkpoints to.
            """
            super().__init__()

            self.checkpoint_file = checkpoint_dir / "checkpoint.weights.h5"
            self.best_metrics_so_far: Dict[Text, Any] = {}

        def on_epoch_end(
            self, epoch: int, logs: Optional[Dict[Text, Any]] = None
        ) -> None:
            """Save the model on epoch end if the model has improved.

            Args:
                epoch: The current epoch.
                logs: The training metrics.
            """
            if not self._does_model_improve(logs):
                return

            logger.debug(f"Creating model checkpoint at epoch={epoch + 1} ...")
            # Ensure model is built before saving weights.
            if not self.model.built:
                # Build the model with dummy data to ensure it's built.
                # NOTE(review): a (1, 1) zero tensor may not match the model's
                # real input signature - confirm this branch is reachable.
                dummy_input = tf.zeros((1, 1))
                _ = self.model(dummy_input)

            # Ensure the directory exists before saving.
            self.checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
            self.model.save_weights(self.checkpoint_file, overwrite=True)

        def _does_model_improve(
            self, curr_results: Optional[Dict[Text, Any]]
        ) -> bool:
            """Checks whether the current results are better than the best so far.

            Results are considered better if each metric is
            equal or better than the best so far, and at least one is better.

            Args:
                curr_results: The training metrics for this epoch. Keras may
                    pass ``None`` when no metrics are available.

            Returns:
                ``True`` if a checkpoint should be written for this epoch.
            """
            if not curr_results:
                # No metrics reported (``logs`` is optional in Keras callbacks).
                return False

            # The "val" prefix is prepended to metrics in fit if _should_eval
            # returns true for this particular epoch.
            curr_metric_names = [
                k
                for k in curr_results.keys()
                if k.startswith("val") and (k.endswith("_acc") or k.endswith("_f1"))
            ]
            if not curr_metric_names:
                # The metrics are not validation metrics.
                return False

            # Initialize best_metrics_so_far with the first results.
            if not self.best_metrics_so_far:
                for metric_name in curr_metric_names:
                    self.best_metrics_so_far[metric_name] = float(
                        curr_results[metric_name]
                    )
                return True

            at_least_one_improved = False
            improved_metrics = {}
            for metric_name in self.best_metrics_so_far.keys():
                current_value = float(curr_results[metric_name])
                if current_value < self.best_metrics_so_far[metric_name]:
                    # At least one of the values is worse.
                    return False
                if current_value > self.best_metrics_so_far[metric_name]:
                    at_least_one_improved = True
                    improved_metrics[metric_name] = current_value

            # All current values >= previous best and at least one is better.
            if at_least_one_improved:
                self.best_metrics_so_far.update(improved_metrics)
            return at_least_one_improved

else:
    # Placeholder classes when TensorFlow is not available
    RasaTrainingLogger = None  # type: ignore
    RasaModelCheckpoint = None  # type: ignore
|
rasa/utils/tensorflow/crf.py
CHANGED
|
@@ -9,7 +9,7 @@ from tensorflow.types.experimental import TensorLike
|
|
|
9
9
|
# (modified to our neeeds)
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
class CrfDecodeForwardRnnCell(tf.keras.layers.
|
|
12
|
+
class CrfDecodeForwardRnnCell(tf.keras.layers.Layer):
|
|
13
13
|
"""Computes the forward decoding in a linear-chain CRF."""
|
|
14
14
|
|
|
15
15
|
def __init__(self, transition_params: TensorLike, **kwargs: Any) -> None:
|
|
@@ -71,13 +71,22 @@ class RasaDataGenerator(Sequence):
|
|
|
71
71
|
# balancing on the next epoch
|
|
72
72
|
return data
|
|
73
73
|
|
|
74
|
+
@staticmethod
def _create_default_array() -> np.ndarray:
    """Build the placeholder used in a batch for a feature that is absent.

    Returns:
        An empty float32 array with shape (0, 1).
    """
    placeholder_shape = (0, 1)
    return np.zeros(shape=placeholder_shape, dtype=np.float32)
|
|
82
|
+
|
|
74
83
|
@staticmethod
|
|
75
84
|
def prepare_batch(
|
|
76
85
|
data: Data,
|
|
77
86
|
start: Optional[int] = None,
|
|
78
87
|
end: Optional[int] = None,
|
|
79
88
|
tuple_sizes: Optional[Dict[Text, int]] = None,
|
|
80
|
-
) -> Tuple[
|
|
89
|
+
) -> Tuple[np.ndarray, ...]:
|
|
81
90
|
"""Slices model data into batch using given start and end value.
|
|
82
91
|
|
|
83
92
|
Args:
|
|
@@ -85,8 +94,8 @@ class RasaDataGenerator(Sequence):
|
|
|
85
94
|
start: The start index of the batch
|
|
86
95
|
end: The end index of the batch
|
|
87
96
|
tuple_sizes: In case the feature is not present we propagate the batch with
|
|
88
|
-
|
|
89
|
-
what kind of feature.
|
|
97
|
+
default arrays. Tuple sizes contains the number of how many default values
|
|
98
|
+
to add for what kind of feature.
|
|
90
99
|
|
|
91
100
|
Returns:
|
|
92
101
|
The features of the batch.
|
|
@@ -95,12 +104,14 @@ class RasaDataGenerator(Sequence):
|
|
|
95
104
|
|
|
96
105
|
for key, attribute_data in data.items():
|
|
97
106
|
for sub_key, f_data in attribute_data.items():
|
|
98
|
-
# add
|
|
107
|
+
# add default arrays for not present values during processing
|
|
99
108
|
if not f_data:
|
|
100
109
|
if tuple_sizes:
|
|
101
|
-
batch_data += [
|
|
110
|
+
batch_data += [
|
|
111
|
+
RasaDataGenerator._create_default_array()
|
|
112
|
+
] * tuple_sizes[key]
|
|
102
113
|
else:
|
|
103
|
-
batch_data.append(
|
|
114
|
+
batch_data.append(RasaDataGenerator._create_default_array())
|
|
104
115
|
continue
|
|
105
116
|
|
|
106
117
|
for v in f_data:
|
|
@@ -409,8 +420,10 @@ class RasaBatchDataGenerator(RasaDataGenerator):
|
|
|
409
420
|
end = start + self._current_batch_size
|
|
410
421
|
|
|
411
422
|
# return input and target data, as our target data is inside the input
|
|
412
|
-
# data return
|
|
413
|
-
return self.prepare_batch(
|
|
423
|
+
# data return default array for the target data
|
|
424
|
+
return self.prepare_batch(
|
|
425
|
+
self._data, start, end
|
|
426
|
+
), RasaDataGenerator._create_default_array()
|
|
414
427
|
|
|
415
428
|
def on_epoch_end(self) -> None:
|
|
416
429
|
"""Update the data after every epoch."""
|
rasa/utils/tensorflow/layers.py
CHANGED
|
@@ -3,9 +3,7 @@ from typing import Any, Callable, List, Optional, Text, Tuple, Union
|
|
|
3
3
|
|
|
4
4
|
import tensorflow as tf
|
|
5
5
|
import tensorflow.keras.backend as K
|
|
6
|
-
|
|
7
|
-
# TODO: The following is not (yet) available via tf.keras
|
|
8
|
-
from keras.src.utils.control_flow_util import smart_cond
|
|
6
|
+
from tensorflow.python.keras.utils.control_flow_util import smart_cond
|
|
9
7
|
|
|
10
8
|
import rasa.utils.tensorflow.crf
|
|
11
9
|
import rasa.utils.tensorflow.layers_utils as layers_utils
|
|
@@ -278,6 +276,7 @@ class RandomlyConnectedDense(tf.keras.layers.Dense):
|
|
|
278
276
|
kernel_constraint: Constraint function applied to
|
|
279
277
|
the `kernel` weights matrix.
|
|
280
278
|
bias_constraint: Constraint function applied to the bias vector.
|
|
279
|
+
**kwargs: Additional keyword arguments passed to the parent class.
|
|
281
280
|
"""
|
|
282
281
|
super().__init__(**kwargs)
|
|
283
282
|
|
|
@@ -298,16 +297,19 @@ class RandomlyConnectedDense(tf.keras.layers.Dense):
|
|
|
298
297
|
self.kernel_mask = None
|
|
299
298
|
return
|
|
300
299
|
|
|
301
|
-
#
|
|
302
|
-
|
|
303
|
-
|
|
300
|
+
# Use callable initializer for TensorFlow 2.19.1 compatibility
|
|
301
|
+
def kernel_mask_initializer() -> tf.Tensor:
|
|
302
|
+
# Construct mask with given density and guarantee that every output is
|
|
303
|
+
# connected to at least one input
|
|
304
|
+
kernel_mask = self._minimal_mask() + self._random_mask()
|
|
304
305
|
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
306
|
+
# We might accidently have added a random connection on top of
|
|
307
|
+
# a fixed connection
|
|
308
|
+
kernel_mask = tf.clip_by_value(kernel_mask, 0, 1)
|
|
309
|
+
return kernel_mask
|
|
308
310
|
|
|
309
311
|
self.kernel_mask = tf.Variable(
|
|
310
|
-
initial_value=
|
|
312
|
+
initial_value=kernel_mask_initializer, trainable=False, name="kernel_mask"
|
|
311
313
|
)
|
|
312
314
|
|
|
313
315
|
def _random_mask(self) -> tf.Tensor:
|
|
@@ -367,7 +369,12 @@ class RandomlyConnectedDense(tf.keras.layers.Dense):
|
|
|
367
369
|
Returns:
|
|
368
370
|
The processed inputs.
|
|
369
371
|
"""
|
|
370
|
-
if
|
|
372
|
+
# Apply kernel masking if needed (Keras 3.x compatibility check)
|
|
373
|
+
if (
|
|
374
|
+
self.density < 1.0
|
|
375
|
+
and hasattr(self, "kernel_mask")
|
|
376
|
+
and self.kernel_mask is not None
|
|
377
|
+
):
|
|
371
378
|
# Set fraction of the `kernel` weights to zero according to precomputed mask
|
|
372
379
|
self.kernel.assign(self.kernel * self.kernel_mask)
|
|
373
380
|
return super().call(inputs)
|
|
@@ -724,6 +731,7 @@ class DotProductLoss(tf.keras.layers.Layer):
|
|
|
724
731
|
Currently, the only possible value is `SOFTMAX`.
|
|
725
732
|
similarity_type: Similarity measure to use, either `cosine` or `inner`.
|
|
726
733
|
name: Optional name of the layer.
|
|
734
|
+
**kwargs: Additional keyword arguments passed to the parent class.
|
|
727
735
|
|
|
728
736
|
Raises:
|
|
729
737
|
TFLayerConfigException: When `similarity_type` is not one of `COSINE` or
|
|
@@ -883,6 +891,7 @@ class SingleLabelDotProductLoss(DotProductLoss):
|
|
|
883
891
|
values are approximately bounded.
|
|
884
892
|
model_confidence: Normalization of confidence values during inference.
|
|
885
893
|
Currently, the only possible value is `SOFTMAX`.
|
|
894
|
+
**kwargs: Additional keyword arguments passed to the parent class.
|
|
886
895
|
"""
|
|
887
896
|
super().__init__(
|
|
888
897
|
num_candidates,
|
|
@@ -1244,6 +1253,7 @@ class MultiLabelDotProductLoss(DotProductLoss):
|
|
|
1244
1253
|
Used inside _loss_cross_entropy() only.
|
|
1245
1254
|
model_confidence: Normalization of confidence values during inference.
|
|
1246
1255
|
Currently, the only possible value is `SOFTMAX`.
|
|
1256
|
+
**kwargs: Additional keyword arguments passed to the parent class.
|
|
1247
1257
|
"""
|
|
1248
1258
|
super().__init__(
|
|
1249
1259
|
num_candidates,
|
rasa/utils/tensorflow/metrics.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from typing import Any, Dict, Optional
|
|
2
2
|
|
|
3
3
|
import tensorflow as tf
|
|
4
|
-
from tensorflow.keras import backend as K
|
|
5
4
|
from tensorflow.types.experimental import TensorLike
|
|
6
5
|
|
|
7
6
|
# original code taken from
|
|
@@ -118,7 +117,7 @@ class FBetaScore(tf.keras.metrics.Metric):
|
|
|
118
117
|
|
|
119
118
|
def _zero_wt_init(name: Any) -> Any:
|
|
120
119
|
return self.add_weight(
|
|
121
|
-
name, shape=self.init_shape, initializer="zeros", dtype=self.dtype
|
|
120
|
+
name=name, shape=self.init_shape, initializer="zeros", dtype=self.dtype
|
|
122
121
|
)
|
|
123
122
|
|
|
124
123
|
self.true_positives = _zero_wt_init("true_positives")
|
|
@@ -197,7 +196,12 @@ class FBetaScore(tf.keras.metrics.Metric):
|
|
|
197
196
|
|
|
198
197
|
def reset_state(self) -> None:
    """Reset every accumulated metric variable back to zeros."""
    zeros = tf.zeros(self.init_shape, dtype=self.dtype)
    # Each accumulator variable is reset explicitly via assign(), in a fixed order.
    accumulators = (
        self.true_positives,
        self.false_positives,
        self.false_negatives,
        self.weights_intermediate,
    )
    for accumulator in accumulators:
        accumulator.assign(zeros)
|
|
201
205
|
|
|
202
206
|
def reset_states(self) -> None:
|
|
203
207
|
# Backwards compatibility alias of `reset_state`. New classes should
|