rasa-pro 3.8.18__py3-none-any.whl → 3.9.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +6 -42
- rasa/__main__.py +14 -9
- rasa/anonymization/anonymization_pipeline.py +0 -1
- rasa/anonymization/anonymization_rule_executor.py +3 -3
- rasa/anonymization/utils.py +4 -3
- rasa/api.py +2 -2
- rasa/cli/arguments/default_arguments.py +1 -1
- rasa/cli/arguments/run.py +2 -2
- rasa/cli/arguments/test.py +1 -1
- rasa/cli/arguments/train.py +10 -10
- rasa/cli/e2e_test.py +27 -7
- rasa/cli/export.py +0 -1
- rasa/cli/license.py +3 -3
- rasa/cli/project_templates/calm/actions/action_template.py +1 -1
- rasa/cli/project_templates/calm/config.yml +1 -1
- rasa/cli/project_templates/calm/credentials.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +1 -1
- rasa/cli/project_templates/calm/domain/add_contact.yml +8 -2
- rasa/cli/project_templates/calm/domain/list_contacts.yml +3 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +9 -2
- rasa/cli/project_templates/calm/domain/shared.yml +5 -0
- rasa/cli/project_templates/calm/endpoints.yml +4 -4
- rasa/cli/project_templates/default/actions/actions.py +1 -1
- rasa/cli/project_templates/default/config.yml +5 -5
- rasa/cli/project_templates/default/credentials.yml +1 -1
- rasa/cli/project_templates/default/endpoints.yml +4 -4
- rasa/cli/project_templates/default/tests/test_stories.yml +1 -1
- rasa/cli/project_templates/tutorial/config.yml +1 -1
- rasa/cli/project_templates/tutorial/credentials.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
- rasa/cli/project_templates/tutorial/domain.yml +4 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +6 -6
- rasa/cli/run.py +0 -1
- rasa/cli/scaffold.py +3 -2
- rasa/cli/studio/download.py +11 -0
- rasa/cli/studio/studio.py +180 -24
- rasa/cli/studio/upload.py +0 -8
- rasa/cli/telemetry.py +18 -6
- rasa/cli/utils.py +21 -10
- rasa/cli/x.py +3 -2
- rasa/constants.py +1 -1
- rasa/core/actions/action.py +90 -315
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/constants.py +3 -0
- rasa/core/actions/custom_action_executor.py +188 -0
- rasa/core/actions/forms.py +11 -7
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +140 -0
- rasa/core/actions/loops.py +3 -0
- rasa/core/actions/two_stage_fallback.py +1 -1
- rasa/core/agent.py +2 -4
- rasa/core/brokers/pika.py +1 -2
- rasa/core/channels/audiocodes.py +1 -1
- rasa/core/channels/botframework.py +0 -1
- rasa/core/channels/callback.py +0 -1
- rasa/core/channels/console.py +6 -8
- rasa/core/channels/development_inspector.py +1 -1
- rasa/core/channels/facebook.py +0 -3
- rasa/core/channels/hangouts.py +0 -6
- rasa/core/channels/inspector/dist/assets/{arc-5623b6dc.js → arc-b6e548fe.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-685c106a.js → c4Diagram-d0fbc5ce-fa03ac9e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-8cbed007.js → classDiagram-936ed81e-ee67392a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-5889cf12.js → classDiagram-v2-c3cb15f1-9b283fae.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-24c249d7.js → createText-62fc7601-8b6fcc2a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-7dd06a75.js → edges-f2ad444c-22e77f4f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-62c1e54c.js → erDiagram-9d236eb7-60ffc87f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-ce49b86f.js → flowDb-1972c806-9dd802e4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4067e48f.js → flowDiagram-7ea5b25a-5fa1912f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-59fe4051.js → flowchart-elk-definition-abe16c3d-622a1fd2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-47e3a43b.js → ganttDiagram-9b5ea136-e285a63a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-5a2ac0d9.js → gitGraphDiagram-99d0ae7c-f237bdca.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-dfb8efc4.js → index-2c4b9a3b-4b03d70e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-268a75c0.js → index-a5d3e69d.js} +4 -4
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-b0c470f2.js → infoDiagram-736b4530-72a0fa5f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-2edb829a.js → journeyDiagram-df861f2b-82218c41.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-b6873d69.js → layout-78cff630.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-1efc5781.js → line-5038b469.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-661e9b94.js → linear-c4fc4098.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2d2e727f.js → mindmap-definition-beec6740-c33c8ea6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-9d3ea93d.js → pieDiagram-dbbf0591-a8d03059.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-06a178a2.js → quadrantDiagram-4d7f4fd6-6a0e56b2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-0bfedffc.js → requirementDiagram-6fc4c22a-2dc7c7bd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-d76d0a04.js → sankeyDiagram-8f13d901-2360fe39.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-37bb4341.js → sequenceDiagram-b655622a-41b9f9ad.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-f52f7f57.js → stateDiagram-59f0c015-0aad326f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-4a986a20.js → stateDiagram-v2-2b26beab-9847d984.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-7dd9ae12.js → styles-080da4f6-564d890e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-46e1ca14.js → styles-3dcbcfbf-38957613.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-4a97439a.js → styles-9c745c82-f0fc6921.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-823917a3.js → svgDrawCommon-4835440b-ef3c5a77.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-9ea72896.js → timeline-definition-5b62e21b-bf3e91c1.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-b631a8b6.js → xychartDiagram-2b33534f-4d4026c0.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -7
- rasa/core/channels/inspector/src/helpers/formatters.ts +3 -2
- rasa/core/channels/rest.py +36 -21
- rasa/core/channels/rocketchat.py +0 -1
- rasa/core/channels/socketio.py +1 -1
- rasa/core/channels/telegram.py +3 -3
- rasa/core/channels/webexteams.py +0 -1
- rasa/core/concurrent_lock_store.py +1 -1
- rasa/core/evaluation/marker_base.py +1 -3
- rasa/core/evaluation/marker_stats.py +1 -2
- rasa/core/featurizers/single_state_featurizer.py +3 -26
- rasa/core/featurizers/tracker_featurizers.py +18 -122
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +9 -4
- rasa/core/information_retrieval/information_retrieval.py +64 -7
- rasa/core/information_retrieval/milvus.py +7 -14
- rasa/core/information_retrieval/qdrant.py +8 -15
- rasa/core/lock_store.py +0 -1
- rasa/core/migrate.py +1 -2
- rasa/core/nlg/callback.py +3 -4
- rasa/core/policies/enterprise_search_policy.py +86 -22
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +4 -41
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flows/flow_executor.py +104 -2
- rasa/core/policies/intentless_policy.py +7 -9
- rasa/core/policies/memoization.py +3 -3
- rasa/core/policies/policy.py +18 -9
- rasa/core/policies/rule_policy.py +8 -11
- rasa/core/policies/ted_policy.py +61 -88
- rasa/core/policies/unexpected_intent_policy.py +8 -17
- rasa/core/processor.py +136 -47
- rasa/core/run.py +41 -25
- rasa/core/secrets_manager/endpoints.py +2 -2
- rasa/core/secrets_manager/vault.py +6 -8
- rasa/core/test.py +3 -5
- rasa/core/tracker_store.py +49 -14
- rasa/core/train.py +1 -3
- rasa/core/training/interactive.py +9 -6
- rasa/core/utils.py +5 -10
- rasa/dialogue_understanding/coexistence/intent_based_router.py +11 -4
- rasa/dialogue_understanding/coexistence/llm_based_router.py +2 -3
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +9 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +9 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +38 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/clarify_command.py +9 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +9 -0
- rasa/dialogue_understanding/commands/error_command.py +12 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +9 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +9 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/noop_command.py +9 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +38 -3
- rasa/dialogue_understanding/commands/skip_question_command.py +9 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +9 -0
- rasa/dialogue_understanding/generator/__init__.py +16 -1
- rasa/dialogue_understanding/generator/command_generator.py +92 -6
- rasa/dialogue_understanding/generator/constants.py +18 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +7 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +39 -609
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +69 -8
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +36 -31
- rasa/dialogue_understanding/processor/command_processor.py +112 -3
- rasa/e2e_test/constants.py +1 -0
- rasa/e2e_test/e2e_test_case.py +44 -0
- rasa/e2e_test/e2e_test_runner.py +114 -11
- rasa/e2e_test/e2e_test_schema.yml +18 -0
- rasa/engine/caching.py +0 -1
- rasa/engine/graph.py +18 -6
- rasa/engine/recipes/config_files/default_config.yml +3 -3
- rasa/engine/recipes/default_components.py +1 -1
- rasa/engine/recipes/default_recipe.py +4 -5
- rasa/engine/recipes/recipe.py +1 -1
- rasa/engine/runner/dask.py +3 -9
- rasa/engine/storage/local_model_storage.py +0 -2
- rasa/engine/validation.py +179 -145
- rasa/exceptions.py +2 -2
- rasa/graph_components/validators/default_recipe_validator.py +3 -5
- rasa/hooks.py +0 -1
- rasa/model.py +1 -1
- rasa/model_training.py +1 -0
- rasa/nlu/classifiers/diet_classifier.py +33 -52
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +54 -97
- rasa/nlu/extractors/duckling_entity_extractor.py +1 -1
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +1 -5
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +0 -4
- rasa/nlu/featurizers/featurizer.py +1 -1
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +18 -49
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +26 -64
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/persistor.py +68 -26
- rasa/nlu/selectors/response_selector.py +7 -10
- rasa/nlu/test.py +0 -3
- rasa/nlu/utils/hugging_face/registry.py +1 -1
- rasa/nlu/utils/spacy_utils.py +1 -3
- rasa/server.py +22 -7
- rasa/shared/constants.py +12 -1
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +4 -5
- rasa/shared/core/domain.py +57 -56
- rasa/shared/core/events.py +4 -7
- rasa/shared/core/flows/flow.py +9 -0
- rasa/shared/core/flows/flows_list.py +12 -0
- rasa/shared/core/flows/steps/action.py +7 -2
- rasa/shared/core/generator.py +12 -11
- rasa/shared/core/slot_mappings.py +315 -24
- rasa/shared/core/slots.py +4 -2
- rasa/shared/core/trackers.py +32 -14
- rasa/shared/core/training_data/loading.py +0 -1
- rasa/shared/core/training_data/story_reader/story_reader.py +3 -3
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +11 -11
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +5 -3
- rasa/shared/core/training_data/structures.py +1 -1
- rasa/shared/core/training_data/visualization.py +1 -1
- rasa/shared/data.py +58 -1
- rasa/shared/exceptions.py +36 -2
- rasa/shared/importers/importer.py +1 -2
- rasa/shared/importers/rasa.py +0 -1
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/nlu/training_data/entities_parser.py +1 -2
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/formats/dialogflow.py +3 -2
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -5
- rasa/shared/nlu/training_data/formats/readerwriter.py +0 -1
- rasa/shared/nlu/training_data/message.py +13 -0
- rasa/shared/nlu/training_data/training_data.py +0 -2
- rasa/shared/providers/openai/session_handler.py +2 -2
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/io.py +11 -1
- rasa/shared/utils/llm.py +1 -2
- rasa/shared/utils/pykwalify_extensions.py +1 -0
- rasa/shared/utils/schemas/domain.yml +3 -0
- rasa/shared/utils/yaml.py +44 -35
- rasa/studio/auth.py +26 -10
- rasa/studio/constants.py +2 -0
- rasa/studio/data_handler.py +114 -107
- rasa/studio/download.py +160 -27
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +6 -7
- rasa/studio/upload.py +159 -134
- rasa/telemetry.py +188 -34
- rasa/tracing/config.py +18 -3
- rasa/tracing/constants.py +26 -2
- rasa/tracing/instrumentation/attribute_extractors.py +50 -41
- rasa/tracing/instrumentation/instrumentation.py +290 -44
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +7 -5
- rasa/tracing/instrumentation/metrics.py +109 -21
- rasa/tracing/metric_instrument_provider.py +83 -3
- rasa/utils/cli.py +2 -1
- rasa/utils/common.py +1 -1
- rasa/utils/endpoints.py +1 -2
- rasa/utils/io.py +72 -6
- rasa/utils/licensing.py +246 -31
- rasa/utils/ml_utils.py +1 -1
- rasa/utils/tensorflow/data_generator.py +1 -1
- rasa/utils/tensorflow/environment.py +1 -1
- rasa/utils/tensorflow/model_data.py +201 -12
- rasa/utils/tensorflow/model_data_utils.py +499 -500
- rasa/utils/tensorflow/models.py +5 -6
- rasa/utils/tensorflow/rasa_layers.py +15 -15
- rasa/utils/train_utils.py +1 -1
- rasa/utils/url_tools.py +53 -0
- rasa/validator.py +305 -3
- rasa/version.py +1 -1
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/METADATA +25 -61
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/RECORD +276 -259
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +0 -1
- rasa/utils/tensorflow/feature_array.py +0 -370
- /rasa/dialogue_understanding/generator/{command_prompt_template.jinja2 → single_step/command_prompt_template.jinja2} +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/NOTICE +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/WHEEL +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/entry_points.txt +0 -0
rasa/utils/licensing.py
CHANGED
|
@@ -1,14 +1,27 @@
|
|
|
1
|
+
from asyncio import AbstractEventLoop
|
|
1
2
|
import hashlib
|
|
2
|
-
import logging
|
|
3
3
|
import os
|
|
4
|
+
import random
|
|
4
5
|
import re
|
|
5
6
|
import time
|
|
7
|
+
import typing
|
|
6
8
|
import uuid
|
|
7
9
|
from datetime import datetime, timezone
|
|
8
|
-
from typing import Any, Callable, Dict, Optional, Set, Text
|
|
10
|
+
from typing import Any, Callable, Dict, Optional, Set, Text, TypeVar
|
|
9
11
|
|
|
10
12
|
import jwt
|
|
11
13
|
from dotenv import dotenv_values
|
|
14
|
+
from sanic import Sanic
|
|
15
|
+
import structlog
|
|
16
|
+
from rasa import telemetry
|
|
17
|
+
|
|
18
|
+
from rasa.core import jobs
|
|
19
|
+
from rasa.shared.utils.cli import print_error_and_exit
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
if typing.TYPE_CHECKING:
|
|
23
|
+
from rasa.core.tracker_store import TrackerStore
|
|
24
|
+
|
|
12
25
|
|
|
13
26
|
LICENSE_ENV_VAR = "RASA_PRO_LICENSE"
|
|
14
27
|
ALGORITHM = "RS256"
|
|
@@ -32,8 +45,11 @@ JTI_BLOCKLIST: Set[Text] = set([])
|
|
|
32
45
|
# Generic scope handling.
SCOPE_DELIMITER = ":"
PRODUCT_AREA = "rasa:pro:plus"
VOICE_SCOPE = "rasa:voice"

# Champion scopes that allow server deployments with a capped number of
# conversations per month (limits are resolved in
# `conversation_limit_for_license`).
CHAMPION_SERVER_LIMITED_SCOPE = "rasa:pro:champion-server-limited"
CHAMPION_SERVER_INTERNAL_SCOPE = "rasa:pro:champion-server-internal"

# defines multiple of the soft limit that triggers the hard limit
HARD_LIMIT_FACTOR = 10

# Module-wide structured logger.
structlogger = structlog.get_logger()
|
|
37
53
|
|
|
38
54
|
|
|
39
55
|
class LicenseValidationException(Exception):
|
|
@@ -64,6 +80,10 @@ class LicenseNotYetValidException(LicenseValidationException):
|
|
|
64
80
|
"""Exception raised when a license is not valid yet (nbf)."""
|
|
65
81
|
|
|
66
82
|
|
|
83
|
+
class LicenseNotFoundException(LicenseValidationException):
    """Raised when no Rasa Pro license is available.

    Signalled by `retrieve_license_from_env` when neither the process
    environment nor the local `.env` file provides a license string.
    """
|
|
85
|
+
|
|
86
|
+
|
|
67
87
|
class License:
|
|
68
88
|
"""Represents a Rasa Pro license.
|
|
69
89
|
|
|
@@ -80,6 +100,37 @@ class License:
|
|
|
80
100
|
requires access to the private key to then encode) or for testing
|
|
81
101
|
purposes.
|
|
82
102
|
|
|
103
|
+
The scopes of the license define what it unlocks. The scopes are
|
|
104
|
+
hierarchical, meaning that a license with a scope `rasa:pro` also
|
|
105
|
+
unlocks the scopes `rasa:pro:plus`.
|
|
106
|
+
|
|
107
|
+
Scopes can be used to enable or disable features in Rasa, the following
|
|
108
|
+
scopes are used:
|
|
109
|
+
|
|
110
|
+
- `rasa:pro` - Rasa Pro features
|
|
111
|
+
- `rasa:pro:plus` - Rasa Pro Plus features
|
|
112
|
+
- `rasa:voice` - Voice features
|
|
113
|
+
|
|
114
|
+
In addition to scopes focused on feature flagging, there are also scopes
|
|
115
|
+
that limit how the license may be used, according to the agreement it is issued under:
|
|
116
|
+
|
|
117
|
+
- `rasa:pro:champion` - Champion license, can only be used for local
|
|
118
|
+
development
|
|
119
|
+
|
|
120
|
+
- `rasa:pro:champion-server-internal` - Champion license, can be used
|
|
121
|
+
to deploy an assistant on a server for an internal use case. The license
|
|
122
|
+
limits the number of conversations per month to 100.
|
|
123
|
+
|
|
124
|
+
- `rasa:pro:champion-server-limited` - Champion license, can be used
|
|
125
|
+
to deploy an assistant on a server for a limited external use case.
|
|
126
|
+
The license allows for a maximum of 1000 conversations per month.
|
|
127
|
+
|
|
128
|
+
The champion scopes on their own do not unlock any features, they are
|
|
129
|
+
used to limit the use of the license to the agreed upon number of
|
|
130
|
+
conversations. The champion scopes can be combined with the feature
|
|
131
|
+
scopes to enable the features for the champion license, e.g.
|
|
132
|
+
`rasa:pro:champion rasa:pro` would enable the champion license for
|
|
133
|
+
local development and the Rasa Pro features.
|
|
83
134
|
"""
|
|
84
135
|
|
|
85
136
|
__slots__ = ["jti", "iat", "nbf", "scope", "exp", "email", "company"]
|
|
@@ -215,39 +266,44 @@ class License:
|
|
|
215
266
|
return jwt.encode(self.as_dict(), key=private_key, algorithm=ALGORITHM)
|
|
216
267
|
|
|
217
268
|
|
|
218
|
-
def date_as_unix_timestamp(utc_date: Text) -> int:
|
|
219
|
-
"""Returns a date represented as a UNIX timestamp.
|
|
220
|
-
|
|
221
|
-
Args:
|
|
222
|
-
utc_date: Date as text (YYYY-MM-DD), UTC timezone.
|
|
223
|
-
|
|
224
|
-
Returns:
|
|
225
|
-
Date as UNIX timestamp.
|
|
226
|
-
"""
|
|
227
|
-
dt = datetime.strptime(utc_date, "%Y-%m-%d")
|
|
228
|
-
return int(dt.replace(tzinfo=timezone.utc).timestamp())
|
|
229
|
-
|
|
230
|
-
|
|
231
269
|
def retrieve_license_from_env() -> Text:
    """Look up the Rasa Pro license string.

    The process environment variable named by `LICENSE_ENV_VAR` takes
    precedence; the value stored in the `.env` file of the current working
    directory is used as a fallback. Note that `.env` is always read, even
    when the environment variable is set.

    Returns:
        The license string as stored, before any decoding.

    Raises:
        LicenseNotFoundException: If neither source yields a non-empty value.
    """
    env_file_values = dotenv_values(".env")
    license_text = os.environ.get(LICENSE_ENV_VAR) or env_file_values.get(
        LICENSE_ENV_VAR
    )
    if not license_text:
        raise LicenseNotFoundException()
    return license_text
|
|
243
277
|
|
|
244
278
|
|
|
279
|
+
def is_license_expiring_soon(license: License) -> bool:
    """Tell whether `license` expires within the next 30 days.

    Args:
        license: Decoded license; only its `exp` attribute (a UNIX
            timestamp) is read.

    Returns:
        True if fewer than 30 days remain until `exp`. A license whose `exp`
        already lies in the past also yields True, because the remaining
        time is negative.
    """
    thirty_days_in_seconds = 30 * 24 * 60 * 60
    remaining_seconds = license.exp - time.time()
    return remaining_seconds < thirty_days_in_seconds
|
|
282
|
+
|
|
283
|
+
|
|
245
284
|
def validate_license_from_env(product_area: Text = PRODUCT_AREA) -> None:
|
|
246
|
-
license = retrieve_license_from_env()
|
|
247
285
|
try:
|
|
248
|
-
|
|
286
|
+
license_text = retrieve_license_from_env()
|
|
287
|
+
license = License.decode(license_text, product_area=product_area)
|
|
288
|
+
|
|
289
|
+
if is_license_expiring_soon(license):
|
|
290
|
+
structlogger.warning(
|
|
291
|
+
"license.expiration.warning",
|
|
292
|
+
event_info=(
|
|
293
|
+
"Your license is about to expire. "
|
|
294
|
+
"Please contact Rasa for a renewal."
|
|
295
|
+
),
|
|
296
|
+
expiration_date=datetime.utcfromtimestamp(license.exp).isoformat(),
|
|
297
|
+
)
|
|
298
|
+
except LicenseNotFoundException:
|
|
299
|
+
structlogger.error("license.not_found.error")
|
|
300
|
+
raise SystemExit(
|
|
301
|
+
f"A Rasa Pro license is required. "
|
|
302
|
+
f"Please set the environmental variable "
|
|
303
|
+
f"`{LICENSE_ENV_VAR}` to a valid license string. "
|
|
304
|
+
)
|
|
249
305
|
except LicenseValidationException as e:
|
|
250
|
-
|
|
306
|
+
structlogger.error("license.validation.error", error=e)
|
|
251
307
|
raise SystemExit(
|
|
252
308
|
f"Failed to validate Rasa Pro license "
|
|
253
309
|
f"which was read from environmental variable `{LICENSE_ENV_VAR}`. "
|
|
@@ -295,21 +351,25 @@ def derive_scope_hierarchy(scope: Text) -> Set[Text]:
|
|
|
295
351
|
return set(required_scopes)
|
|
296
352
|
|
|
297
353
|
|
|
298
|
-
|
|
354
|
+
T = TypeVar("T")


def property_of_active_license(prop: Callable[[License], T]) -> Optional[T]:
    """Return a property for this installation based on license.

    Args:
        prop: Callable extracting the desired value from the decoded license.

    Returns:
        The property of the license if it exists, otherwise None. None is
        also returned when no license is configured, or when the configured
        license fails validation (the latter is logged as a warning).
    """
    try:
        retrieved_license = retrieve_license_from_env()
        if not retrieved_license:
            # Defensive only: retrieve_license_from_env raises instead of
            # returning an empty value.
            return None
        decoded = License.decode(retrieved_license)
    # LicenseNotFoundException subclasses LicenseValidationException, so it
    # must be handled first to stay silent.
    except LicenseNotFoundException:
        return None
    except LicenseValidationException as e:
        structlogger.warning("licensing.active_license.invalid", error=e)
        return None

    # Kept outside the try so errors raised by `prop` itself are not
    # mistaken for an invalid license and silently swallowed.
    return prop(decoded)
|
|
314
374
|
|
|
315
375
|
|
|
@@ -317,3 +377,158 @@ def get_license_hash() -> Optional[Text]:
|
|
|
317
377
|
"""Return the hash of the current active license."""
|
|
318
378
|
license_value = retrieve_license_from_env()
|
|
319
379
|
return hashlib.sha256(license_value.encode("utf-8")).hexdigest()
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def is_champion_server_license() -> bool:
    """Return whether the active license is a champion *server* license.

    A license qualifies when its scope contains
    `CHAMPION_SERVER_LIMITED_SCOPE` or `CHAMPION_SERVER_INTERNAL_SCOPE`.
    When no valid license is active, False is returned.
    """

    def _has_champion_server_scope(license: License) -> bool:
        # Checked in this order; stops at the first matching scope.
        return any(
            server_scope in license.scope
            for server_scope in (
                CHAMPION_SERVER_LIMITED_SCOPE,
                CHAMPION_SERVER_INTERNAL_SCOPE,
            )
        )

    # property_of_active_license yields None when no (valid) license is
    # available; coerce it so the result is always a bool.
    return bool(property_of_active_license(_has_champion_server_scope))
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def conversation_limit_for_license() -> Optional[int]:
    """Return the monthly conversation limit imposed by the active license.

    Returns:
        1000 if the license scope contains `CHAMPION_SERVER_LIMITED_SCOPE`,
        otherwise 100 if it contains `CHAMPION_SERVER_INTERNAL_SCOPE`.
        None if it contains neither, or if no valid license is active.
    """
    # Ordered: the limited-server scope wins when both scopes are present.
    scope_limits = (
        (CHAMPION_SERVER_LIMITED_SCOPE, 1000),
        (CHAMPION_SERVER_INTERNAL_SCOPE, 100),
    )

    def conversations_limit(license: License) -> Optional[int]:
        """Return maximum number of conversations per month for this license."""
        return next(
            (limit for scope, limit in scope_limits if scope in license.scope),
            None,
        )

    return property_of_active_license(conversations_limit)
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
async def validate_limited_server_license(app: Sanic, loop: AbstractEventLoop) -> None:
    """Validate a limited server license and schedule conversation counting job.

    When the active license imposes a conversation limit, conversations are
    counted once immediately and a recurring counting job is scheduled.

    Args:
        app: Sanic application; `app.ctx.agent` must be set for counting.
        loop: Event loop; unused here. NOTE(review): presumably kept so the
            function matches the listener signature it is registered with;
            confirm at the call site.
    """
    max_number_conversations = conversation_limit_for_license()

    agent = app.ctx.agent
    if not agent:
        structlogger.warn("licensing.validate_limited_server_license.no_agent")
        return

    if max_number_conversations is None:
        structlogger.debug("licensing.server_limit.unlimited")
        return

    store = agent.tracker_store
    await run_conversation_counting(store, max_number_conversations)
    await _schedule_conversation_counting(app, store, max_number_conversations)
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
async def handle_soft_limit_reached(
    conversation_count: int, max_number_conversations: int, month: datetime
) -> None:
    """Report that the conversation count exceeds the licensed (soft) limit.

    Although this is the *soft* limit, the event is logged at error level,
    not as a warning. Unlike the hard-limit handler, this function does not
    terminate the process: it only logs and records a telemetry event.

    Args:
        conversation_count: Number of conversations counted for `month`.
        max_number_conversations: Monthly conversation limit of the license.
        month: Start of the month the count refers to (callers in this
            module pass the value of `current_utc_month()`).
    """
    structlogger.error(
        "licensing.conversation_count.exceeded",
        event_info=(
            "The number of conversations has exceeded the limit granted "
            "by your license. Please contact us to upgrade your license."
        ),
        conversation_count=conversation_count,
        max_number_conversations=max_number_conversations,
    )
    telemetry.track_conversation_count_soft_limit(conversation_count, month)
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
async def handle_hard_limit_reached(
    conversation_count: int, max_number_conversations: int, month: datetime
) -> None:
    """Log an error when the number of conversations exceeds the hard limit.

    After logging and recording telemetry, the function ends with
    `print_error_and_exit`. NOTE(review): this presumably raises
    `SystemExit`; the scheduled counting job in this module catches
    `SystemExit` to stop the app.

    Args:
        conversation_count: Number of conversations counted for `month`.
        max_number_conversations: Monthly conversation limit of the license.
        month: Start of the month the count refers to.
    """
    limit_message = (
        "The number of conversations has exceeded the limit granted "
        "by your license. Please contact us to upgrade your license."
    )
    structlogger.error(
        "licensing.conversation_count.exceeded",
        event_info=limit_message,
        conversation_count=conversation_count,
        max_number_conversations=max_number_conversations,
    )
    telemetry.track_conversation_count_hard_limit(conversation_count, month)
    print_error_and_exit(limit_message)
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
def current_utc_month() -> datetime:
    """Return midnight (UTC) on the first day of the current month.

    The returned datetime is timezone-aware with `tzinfo=timezone.utc`.
    """
    now_utc = datetime.now(timezone.utc)
    return datetime(now_utc.year, now_utc.month, 1, tzinfo=timezone.utc)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
async def run_conversation_counting(
    tracker_store: Optional["TrackerStore"], max_number_conversations: int
) -> None:
    """Count the conversations of the current UTC month and apply the limits.

    Conversations after the start of the current month (`current_utc_month`)
    are counted. Then:

    * if the count is >= `max_number_conversations * HARD_LIMIT_FACTOR`, the
      hard-limit handler runs. It ends with `print_error_and_exit`, so
      presumably nothing below is reached in that case.
    * otherwise, if the count is > `max_number_conversations`, the soft-limit
      handler runs.

    Finally, the count is logged at debug level and sent to telemetry.

    Args:
        tracker_store: Store to count conversations in; None counts as zero.
        max_number_conversations: Monthly conversation limit of the license.
    """
    month_start = current_utc_month()
    conversation_count = await _count_conversations_after(
        tracker_store, month_start.timestamp()
    )

    hard_limit = max_number_conversations * HARD_LIMIT_FACTOR
    if conversation_count >= hard_limit:
        # user is above the hard limit
        await handle_hard_limit_reached(
            conversation_count, max_number_conversations, month_start
        )
    elif conversation_count > max_number_conversations:
        await handle_soft_limit_reached(
            conversation_count, max_number_conversations, month_start
        )

    structlogger.debug(
        "licensing.conversation_count",
        conversation_count=conversation_count,
        max_number_conversations=max_number_conversations,
    )
    telemetry.track_conversation_count(conversation_count, month_start)
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
async def _schedule_conversation_counting(
    app: Sanic, tracker_store: Optional["TrackerStore"], max_number_conversations: int
) -> None:
    """Schedule a job counting the number of conversations in the current month.

    The job re-runs `run_conversation_counting` on a roughly daily interval.
    If counting raises `SystemExit` (the hard conversation limit was
    reached), the job stops the Sanic app instead of letting the exception
    escape.

    Args:
        app: Application to stop once the hard limit is reached.
        tracker_store: Store whose conversations are counted.
        max_number_conversations: Monthly conversation limit of the license.
    """

    async def conversation_counting_job(
        app: Sanic,
        tracker_store: Optional["TrackerStore"],
        max_number_conversations: int,
    ) -> None:
        try:
            await run_conversation_counting(tracker_store, max_number_conversations)
        except SystemExit:
            # we've reached the conversation limit
            app.stop()

    scheduler = await jobs.scheduler()

    # Every 24 hours plus a random offset of at most one hour (drawn once,
    # here). Clusters tend to be started at the same time; the offset keeps
    # instances from doing the counting work at exactly the same moment.
    one_day_seconds = 24 * 60 * 60
    max_offset_seconds = 60 * 60
    interval_seconds = one_day_seconds + random.randint(0, max_offset_seconds)

    scheduler.add_job(
        conversation_counting_job,
        "interval",
        seconds=interval_seconds,
        args=[app, tracker_store, max_number_conversations],
        id="count-conversations",
        # re-scheduling replaces a job already registered under this id
        replace_existing=True,
    )
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
async def _count_conversations_after(
    tracker_store: Optional["TrackerStore"], after_timestamp: float
) -> int:
    """Return the number of conversations the tracker store counts after a timestamp.

    Args:
        tracker_store: Store to query; None means there is nothing to count.
        after_timestamp: UNIX timestamp forwarded unchanged to
            `tracker_store.count_conversations(after_timestamp=...)`.

    Returns:
        The store's count, or 0 when no tracker store is configured.
    """
    if tracker_store is not None:
        return await tracker_store.count_conversations(after_timestamp=after_timestamp)
    return 0
|
rasa/utils/ml_utils.py
CHANGED
|
@@ -217,7 +217,7 @@ class RasaDataGenerator(Sequence):
|
|
|
217
217
|
# we need to make sure that the matrices are coo_matrices otherwise the
|
|
218
218
|
# transformation does not work (e.g. you cannot access x.row, x.col)
|
|
219
219
|
if not isinstance(array_of_sparse[0], scipy.sparse.coo_matrix):
|
|
220
|
-
array_of_sparse = [x.tocoo() for x in array_of_sparse] # type: ignore[assignment]
|
|
220
|
+
array_of_sparse = [x.tocoo() for x in array_of_sparse] # type: ignore[assignment]
|
|
221
221
|
|
|
222
222
|
max_seq_len = max([x.shape[0] for x in array_of_sparse])
|
|
223
223
|
|
|
@@ -98,7 +98,7 @@ def _parse_gpu_config(gpu_memory_config: Text) -> Dict[int, int]:
|
|
|
98
98
|
# Helper explanation of where the error comes from
|
|
99
99
|
raise ValueError(
|
|
100
100
|
f"Error parsing GPU configuration. Please cross-check the format of "
|
|
101
|
-
f"'{ENV_GPU_CONFIG}' at https://rasa.com/docs/rasa/tuning-your-model"
|
|
101
|
+
f"'{ENV_GPU_CONFIG}' at https://rasa.com/docs/rasa-pro/nlu-based-assistants/tuning-your-model"
|
|
102
102
|
f"#restricting-absolute-gpu-memory-available ."
|
|
103
103
|
)
|
|
104
104
|
|
|
@@ -20,8 +20,6 @@ import numpy as np
|
|
|
20
20
|
import scipy.sparse
|
|
21
21
|
from sklearn.model_selection import train_test_split
|
|
22
22
|
|
|
23
|
-
from rasa.utils.tensorflow.feature_array import FeatureArray
|
|
24
|
-
|
|
25
23
|
logger = logging.getLogger(__name__)
|
|
26
24
|
|
|
27
25
|
|
|
@@ -39,6 +37,199 @@ def ragged_array_to_ndarray(ragged_array: Iterable[np.ndarray]) -> np.ndarray:
|
|
|
39
37
|
return np.array(ragged_array, dtype=object)
|
|
40
38
|
|
|
41
39
|
|
|
40
|
+
class FeatureArray(np.ndarray):
    """Stores any kind of features ready to be used by a RasaModel.

    Next to the input numpy array of features, it also receives the number of
    dimensions of the features.
    As our features can have 1 to 4 dimensions we might have different number of numpy
    arrays stacked. The number of dimensions helps us to figure out how to handle this
    particular feature array. Also, it is automatically determined whether the feature
    array is sparse or not and the number of units is determined as well.

    Instance attributes:
        units: size of the last axis of the innermost feature element.
        is_sparse: whether the innermost feature element is a
            `scipy.sparse.spmatrix`.
        number_of_dimensions: the dimensionality given at construction.

    Subclassing np.array: https://numpy.org/doc/stable/user/basics.subclassing.html
    """

    def __new__(
        cls, input_array: np.ndarray, number_of_dimensions: int
    ) -> "FeatureArray":
        """Create a FeatureArray viewing `input_array` and attach its metadata.

        Args:
            input_array: the array that contains the features.
            number_of_dimensions: number of dimensions in `input_array`.

        Returns:
            The new FeatureArray.

        Raises:
            ValueError: if `number_of_dimensions` does not match `input_array`
                or is greater than 4.
        """
        FeatureArray._validate_number_of_dimensions(number_of_dimensions, input_array)

        feature_array = np.asarray(input_array).view(cls)

        if number_of_dimensions <= 2:
            feature_array.units = input_array.shape[-1]
            feature_array.is_sparse = isinstance(input_array[0], scipy.sparse.spmatrix)
        elif number_of_dimensions == 3:
            feature_array.units = input_array[0].shape[-1]
            feature_array.is_sparse = isinstance(input_array[0], scipy.sparse.spmatrix)
        elif number_of_dimensions == 4:
            feature_array.units = input_array[0][0].shape[-1]
            feature_array.is_sparse = isinstance(
                input_array[0][0], scipy.sparse.spmatrix
            )
        else:
            raise ValueError(
                f"Number of dimensions '{number_of_dimensions}' currently not "
                f"supported."
            )

        feature_array.number_of_dimensions = number_of_dimensions

        return feature_array

    def __init__(
        self, input_array: Any, number_of_dimensions: int, **kwargs: Any
    ) -> None:
        """Initialize FeatureArray.

        Needed in order to avoid 'Invalid keyword argument number_of_dimensions
        to function FeatureArray.__init__ '
        Args:
            input_array: the array that contains features
            number_of_dimensions: number of dimensions in input_array
        """
        super().__init__(**kwargs)
        self.number_of_dimensions = number_of_dimensions

    def __array_finalize__(self, obj: Optional[np.ndarray]) -> None:
        """This method is called when the system allocates a new array from obj.

        Copies the FeatureArray metadata from `obj`. Plain `np.ndarray` sources
        carry no metadata, so the attributes default to `None`.

        Args:
            obj: A subclass (subtype) of ndarray.
        """
        if obj is None:
            return

        self.units = getattr(obj, "units", None)
        self.number_of_dimensions = getattr(obj, "number_of_dimensions", None)  # type: ignore[assignment]
        self.is_sparse = getattr(obj, "is_sparse", None)

    # pytype: disable=attribute-error
    def __array_ufunc__(
        self, ufunc: Any, method: Text, *inputs: Any, **kwargs: Any
    ) -> Any:
        """Overwrite this method as we are subclassing numpy array.

        Args:
            ufunc: The ufunc object that was called.
            method: A string indicating which Ufunc method was called
                (one of "__call__", "reduce", "reduceat", "accumulate", "outer",
                "at").
            *inputs: A tuple of the input arguments to the ufunc.
            **kwargs: Any additional arguments. An optional
                `number_of_dimensions` entry is consumed here (it is not a ufunc
                argument) and overrides the value attached to the result.

        Returns:
            The result of the operation. Array results come back as a
            FeatureArray carrying a copy of this array's attributes. In-place
            operations return the original `out` FeatureArray. Scalars, tuples
            and `None` are returned unchanged.
        """
        # Not a ufunc argument: strip it so numpy never sees it, and fall back to
        # this array's own dimensionality.
        number_of_dimensions = kwargs.pop(
            "number_of_dimensions", self.number_of_dimensions
        )

        def _as_plain(value: Any) -> Any:
            # View FeatureArrays as np.ndarray so numpy does not dispatch back
            # into this method (infinite recursion). Leave scalars etc. alone.
            if isinstance(value, FeatureArray):
                return value.view(np.ndarray)
            return value

        out = kwargs.get("out")
        if out is not None:
            kwargs["out"] = tuple(_as_plain(o) for o in out)

        operation = ufunc if method == "__call__" else getattr(ufunc, method)
        result = operation(*(_as_plain(i) for i in inputs), **kwargs)

        if out is not None and len(out) == 1 and isinstance(out[0], FeatureArray):
            # in-place operations (e.g. `a += 1`) hand back the original object
            return out[0]
        if not isinstance(result, np.ndarray):
            return result

        output = result.view(FeatureArray)
        # Copy rather than share the attribute dict, so later changes to one
        # array do not leak into the other.
        output.__dict__.update(self.__dict__)
        output.number_of_dimensions = number_of_dimensions
        return output

    def __reduce__(self) -> Tuple[Any, Any, Any]:
        """Needed in order to pickle this object.

        Appends `number_of_dimensions`, `is_sparse` and `units` to numpy's own
        pickled state.

        Returns:
            A tuple.
        """
        pickled_state = super(FeatureArray, self).__reduce__()
        if isinstance(pickled_state, str):
            raise TypeError("np array __reduce__ returned string instead of tuple.")
        new_state = pickled_state[2] + (
            self.number_of_dimensions,
            self.is_sparse,
            self.units,
        )
        return pickled_state[0], pickled_state[1], new_state

    def __setstate__(self, state: Any, **kwargs: Any) -> None:
        """Sets the state.

        Args:
            state: The state argument must be a sequence that contains the following
                elements version, shape, dtype, isFortran, rawdata, followed by
                the three FeatureArray attributes appended in `__reduce__`.
            **kwargs: Any additional parameter
        """
        # Needed in order to load the object
        self.number_of_dimensions = state[-3]
        self.is_sparse = state[-2]
        self.units = state[-1]
        super(FeatureArray, self).__setstate__(state[0:-3], **kwargs)

    # pytype: enable=attribute-error

    @staticmethod
    def _validate_number_of_dimensions(
        number_of_dimensions: int, input_array: np.ndarray
    ) -> None:
        """Validate that the input array has the given number of dimensions.

        Args:
            number_of_dimensions: number of dimensions
            input_array: input array

        Raises: ValueError in case the dimensions do not match
        """
        _sub_array = input_array
        dim = 0
        # Go number_of_dimensions into the given input_array
        for i in range(1, number_of_dimensions + 1):
            _sub_array = _sub_array[0]
            if isinstance(_sub_array, scipy.sparse.spmatrix):
                dim = i
                break
            if isinstance(_sub_array, np.ndarray) and _sub_array.shape[0] == 0:
                # sequence dimension is 0, we are dealing with "fake" features
                dim = i
                break

        # If the resulting sub_array is sparse, the remaining number of dimensions
        # should be at least 2
        if isinstance(_sub_array, scipy.sparse.spmatrix):
            if dim > 2:
                raise ValueError(
                    f"Given number of dimensions '{number_of_dimensions}' does not "
                    f"match dimensions of given input array: {input_array}."
                )
        elif isinstance(_sub_array, np.ndarray) and _sub_array.shape[0] == 0:
            # sequence dimension is 0, we are dealing with "fake" features,
            # but they should be of dim 2
            if dim > 2:
                raise ValueError(
                    f"Given number of dimensions '{number_of_dimensions}' does not "
                    f"match dimensions of given input array: {input_array}."
                )
        # If the resulting sub_array is dense, the sub_array should be a single number
        elif not np.issubdtype(type(_sub_array), np.integer) and not isinstance(
            _sub_array, (np.float32, np.float64)
        ):
            raise ValueError(
                f"Given number of dimensions '{number_of_dimensions}' does not match "
                f"dimensions of given input array: {input_array}."
            )
|
|
231
|
+
|
|
232
|
+
|
|
42
233
|
class FeatureSignature(NamedTuple):
|
|
43
234
|
"""Signature of feature arrays.
|
|
44
235
|
|
|
@@ -94,12 +285,10 @@ class RasaModelData:
|
|
|
94
285
|
self.sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]] = {}
|
|
95
286
|
|
|
96
287
|
@overload
|
|
97
|
-
def get(self, key: Text, sub_key: Text) -> List[FeatureArray]:
|
|
98
|
-
...
|
|
288
|
+
def get(self, key: Text, sub_key: Text) -> List[FeatureArray]: ...
|
|
99
289
|
|
|
100
290
|
@overload
|
|
101
|
-
def get(self, key: Text, sub_key: None = ...) -> Dict[Text, List[FeatureArray]]:
|
|
102
|
-
...
|
|
291
|
+
def get(self, key: Text, sub_key: None = ...) -> Dict[Text, List[FeatureArray]]: ...
|
|
103
292
|
|
|
104
293
|
def get(
|
|
105
294
|
self, key: Text, sub_key: Optional[Text] = None
|
|
@@ -548,9 +737,9 @@ class RasaModelData:
|
|
|
548
737
|
# if a label was skipped in current batch
|
|
549
738
|
skipped = [False] * num_label_ids
|
|
550
739
|
|
|
551
|
-
new_data: DefaultDict[
|
|
552
|
-
|
|
553
|
-
|
|
740
|
+
new_data: DefaultDict[Text, DefaultDict[Text, List[List[FeatureArray]]]] = (
|
|
741
|
+
defaultdict(lambda: defaultdict(list))
|
|
742
|
+
)
|
|
554
743
|
|
|
555
744
|
while min(num_data_cycles) == 0:
|
|
556
745
|
if shuffle:
|
|
@@ -701,9 +890,9 @@ class RasaModelData:
|
|
|
701
890
|
Returns:
|
|
702
891
|
The test and train RasaModelData
|
|
703
892
|
"""
|
|
704
|
-
data_train: DefaultDict[
|
|
705
|
-
|
|
706
|
-
|
|
893
|
+
data_train: DefaultDict[Text, DefaultDict[Text, List[FeatureArray]]] = (
|
|
894
|
+
defaultdict(lambda: defaultdict(list))
|
|
895
|
+
)
|
|
707
896
|
data_val: DefaultDict[Text, DefaultDict[Text, List[Any]]] = defaultdict(
|
|
708
897
|
lambda: defaultdict(list)
|
|
709
898
|
)
|