rasa-pro 3.8.18__py3-none-any.whl → 3.9.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +6 -42
- rasa/__main__.py +14 -9
- rasa/anonymization/anonymization_pipeline.py +0 -1
- rasa/anonymization/anonymization_rule_executor.py +3 -3
- rasa/anonymization/utils.py +4 -3
- rasa/api.py +2 -2
- rasa/cli/arguments/default_arguments.py +1 -1
- rasa/cli/arguments/run.py +2 -2
- rasa/cli/arguments/test.py +1 -1
- rasa/cli/arguments/train.py +10 -10
- rasa/cli/e2e_test.py +27 -7
- rasa/cli/export.py +0 -1
- rasa/cli/license.py +3 -3
- rasa/cli/project_templates/calm/actions/action_template.py +1 -1
- rasa/cli/project_templates/calm/config.yml +1 -1
- rasa/cli/project_templates/calm/credentials.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +1 -1
- rasa/cli/project_templates/calm/domain/add_contact.yml +8 -2
- rasa/cli/project_templates/calm/domain/list_contacts.yml +3 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +9 -2
- rasa/cli/project_templates/calm/domain/shared.yml +5 -0
- rasa/cli/project_templates/calm/endpoints.yml +4 -4
- rasa/cli/project_templates/default/actions/actions.py +1 -1
- rasa/cli/project_templates/default/config.yml +5 -5
- rasa/cli/project_templates/default/credentials.yml +1 -1
- rasa/cli/project_templates/default/endpoints.yml +4 -4
- rasa/cli/project_templates/default/tests/test_stories.yml +1 -1
- rasa/cli/project_templates/tutorial/config.yml +1 -1
- rasa/cli/project_templates/tutorial/credentials.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
- rasa/cli/project_templates/tutorial/domain.yml +4 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +6 -6
- rasa/cli/run.py +0 -1
- rasa/cli/scaffold.py +3 -2
- rasa/cli/studio/download.py +11 -0
- rasa/cli/studio/studio.py +180 -24
- rasa/cli/studio/upload.py +0 -8
- rasa/cli/telemetry.py +18 -6
- rasa/cli/utils.py +21 -10
- rasa/cli/x.py +3 -2
- rasa/constants.py +1 -1
- rasa/core/actions/action.py +90 -315
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/constants.py +3 -0
- rasa/core/actions/custom_action_executor.py +188 -0
- rasa/core/actions/forms.py +11 -7
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +140 -0
- rasa/core/actions/loops.py +3 -0
- rasa/core/actions/two_stage_fallback.py +1 -1
- rasa/core/agent.py +2 -4
- rasa/core/brokers/pika.py +1 -2
- rasa/core/channels/audiocodes.py +1 -1
- rasa/core/channels/botframework.py +0 -1
- rasa/core/channels/callback.py +0 -1
- rasa/core/channels/console.py +6 -8
- rasa/core/channels/development_inspector.py +1 -1
- rasa/core/channels/facebook.py +0 -3
- rasa/core/channels/hangouts.py +0 -6
- rasa/core/channels/inspector/dist/assets/{arc-5623b6dc.js → arc-b6e548fe.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-685c106a.js → c4Diagram-d0fbc5ce-fa03ac9e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-8cbed007.js → classDiagram-936ed81e-ee67392a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-5889cf12.js → classDiagram-v2-c3cb15f1-9b283fae.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-24c249d7.js → createText-62fc7601-8b6fcc2a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-7dd06a75.js → edges-f2ad444c-22e77f4f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-62c1e54c.js → erDiagram-9d236eb7-60ffc87f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-ce49b86f.js → flowDb-1972c806-9dd802e4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4067e48f.js → flowDiagram-7ea5b25a-5fa1912f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-59fe4051.js → flowchart-elk-definition-abe16c3d-622a1fd2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-47e3a43b.js → ganttDiagram-9b5ea136-e285a63a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-5a2ac0d9.js → gitGraphDiagram-99d0ae7c-f237bdca.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-dfb8efc4.js → index-2c4b9a3b-4b03d70e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-268a75c0.js → index-a5d3e69d.js} +4 -4
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-b0c470f2.js → infoDiagram-736b4530-72a0fa5f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-2edb829a.js → journeyDiagram-df861f2b-82218c41.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-b6873d69.js → layout-78cff630.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-1efc5781.js → line-5038b469.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-661e9b94.js → linear-c4fc4098.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2d2e727f.js → mindmap-definition-beec6740-c33c8ea6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-9d3ea93d.js → pieDiagram-dbbf0591-a8d03059.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-06a178a2.js → quadrantDiagram-4d7f4fd6-6a0e56b2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-0bfedffc.js → requirementDiagram-6fc4c22a-2dc7c7bd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-d76d0a04.js → sankeyDiagram-8f13d901-2360fe39.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-37bb4341.js → sequenceDiagram-b655622a-41b9f9ad.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-f52f7f57.js → stateDiagram-59f0c015-0aad326f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-4a986a20.js → stateDiagram-v2-2b26beab-9847d984.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-7dd9ae12.js → styles-080da4f6-564d890e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-46e1ca14.js → styles-3dcbcfbf-38957613.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-4a97439a.js → styles-9c745c82-f0fc6921.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-823917a3.js → svgDrawCommon-4835440b-ef3c5a77.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-9ea72896.js → timeline-definition-5b62e21b-bf3e91c1.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-b631a8b6.js → xychartDiagram-2b33534f-4d4026c0.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -7
- rasa/core/channels/inspector/src/helpers/formatters.ts +3 -2
- rasa/core/channels/rest.py +36 -21
- rasa/core/channels/rocketchat.py +0 -1
- rasa/core/channels/socketio.py +1 -1
- rasa/core/channels/telegram.py +3 -3
- rasa/core/channels/webexteams.py +0 -1
- rasa/core/concurrent_lock_store.py +1 -1
- rasa/core/evaluation/marker_base.py +1 -3
- rasa/core/evaluation/marker_stats.py +1 -2
- rasa/core/featurizers/single_state_featurizer.py +3 -26
- rasa/core/featurizers/tracker_featurizers.py +18 -122
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +9 -4
- rasa/core/information_retrieval/information_retrieval.py +64 -7
- rasa/core/information_retrieval/milvus.py +7 -14
- rasa/core/information_retrieval/qdrant.py +8 -15
- rasa/core/lock_store.py +0 -1
- rasa/core/migrate.py +1 -2
- rasa/core/nlg/callback.py +3 -4
- rasa/core/policies/enterprise_search_policy.py +86 -22
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +4 -41
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flows/flow_executor.py +104 -2
- rasa/core/policies/intentless_policy.py +7 -9
- rasa/core/policies/memoization.py +3 -3
- rasa/core/policies/policy.py +18 -9
- rasa/core/policies/rule_policy.py +8 -11
- rasa/core/policies/ted_policy.py +61 -88
- rasa/core/policies/unexpected_intent_policy.py +8 -17
- rasa/core/processor.py +136 -47
- rasa/core/run.py +41 -25
- rasa/core/secrets_manager/endpoints.py +2 -2
- rasa/core/secrets_manager/vault.py +6 -8
- rasa/core/test.py +3 -5
- rasa/core/tracker_store.py +49 -14
- rasa/core/train.py +1 -3
- rasa/core/training/interactive.py +9 -6
- rasa/core/utils.py +5 -10
- rasa/dialogue_understanding/coexistence/intent_based_router.py +11 -4
- rasa/dialogue_understanding/coexistence/llm_based_router.py +2 -3
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +9 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +9 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +38 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/clarify_command.py +9 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +9 -0
- rasa/dialogue_understanding/commands/error_command.py +12 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +9 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +9 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/noop_command.py +9 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +38 -3
- rasa/dialogue_understanding/commands/skip_question_command.py +9 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +9 -0
- rasa/dialogue_understanding/generator/__init__.py +16 -1
- rasa/dialogue_understanding/generator/command_generator.py +92 -6
- rasa/dialogue_understanding/generator/constants.py +18 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +7 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +39 -609
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +69 -8
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +36 -31
- rasa/dialogue_understanding/processor/command_processor.py +112 -3
- rasa/e2e_test/constants.py +1 -0
- rasa/e2e_test/e2e_test_case.py +44 -0
- rasa/e2e_test/e2e_test_runner.py +114 -11
- rasa/e2e_test/e2e_test_schema.yml +18 -0
- rasa/engine/caching.py +0 -1
- rasa/engine/graph.py +18 -6
- rasa/engine/recipes/config_files/default_config.yml +3 -3
- rasa/engine/recipes/default_components.py +1 -1
- rasa/engine/recipes/default_recipe.py +4 -5
- rasa/engine/recipes/recipe.py +1 -1
- rasa/engine/runner/dask.py +3 -9
- rasa/engine/storage/local_model_storage.py +0 -2
- rasa/engine/validation.py +179 -145
- rasa/exceptions.py +2 -2
- rasa/graph_components/validators/default_recipe_validator.py +3 -5
- rasa/hooks.py +0 -1
- rasa/model.py +1 -1
- rasa/model_training.py +1 -0
- rasa/nlu/classifiers/diet_classifier.py +33 -52
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +54 -97
- rasa/nlu/extractors/duckling_entity_extractor.py +1 -1
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +1 -5
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +0 -4
- rasa/nlu/featurizers/featurizer.py +1 -1
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +18 -49
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +26 -64
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/persistor.py +68 -26
- rasa/nlu/selectors/response_selector.py +7 -10
- rasa/nlu/test.py +0 -3
- rasa/nlu/utils/hugging_face/registry.py +1 -1
- rasa/nlu/utils/spacy_utils.py +1 -3
- rasa/server.py +22 -7
- rasa/shared/constants.py +12 -1
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +4 -5
- rasa/shared/core/domain.py +57 -56
- rasa/shared/core/events.py +4 -7
- rasa/shared/core/flows/flow.py +9 -0
- rasa/shared/core/flows/flows_list.py +12 -0
- rasa/shared/core/flows/steps/action.py +7 -2
- rasa/shared/core/generator.py +12 -11
- rasa/shared/core/slot_mappings.py +315 -24
- rasa/shared/core/slots.py +4 -2
- rasa/shared/core/trackers.py +32 -14
- rasa/shared/core/training_data/loading.py +0 -1
- rasa/shared/core/training_data/story_reader/story_reader.py +3 -3
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +11 -11
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +5 -3
- rasa/shared/core/training_data/structures.py +1 -1
- rasa/shared/core/training_data/visualization.py +1 -1
- rasa/shared/data.py +58 -1
- rasa/shared/exceptions.py +36 -2
- rasa/shared/importers/importer.py +1 -2
- rasa/shared/importers/rasa.py +0 -1
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/nlu/training_data/entities_parser.py +1 -2
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/formats/dialogflow.py +3 -2
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -5
- rasa/shared/nlu/training_data/formats/readerwriter.py +0 -1
- rasa/shared/nlu/training_data/message.py +13 -0
- rasa/shared/nlu/training_data/training_data.py +0 -2
- rasa/shared/providers/openai/session_handler.py +2 -2
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/io.py +11 -1
- rasa/shared/utils/llm.py +1 -2
- rasa/shared/utils/pykwalify_extensions.py +1 -0
- rasa/shared/utils/schemas/domain.yml +3 -0
- rasa/shared/utils/yaml.py +44 -35
- rasa/studio/auth.py +26 -10
- rasa/studio/constants.py +2 -0
- rasa/studio/data_handler.py +114 -107
- rasa/studio/download.py +160 -27
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +6 -7
- rasa/studio/upload.py +159 -134
- rasa/telemetry.py +188 -34
- rasa/tracing/config.py +18 -3
- rasa/tracing/constants.py +26 -2
- rasa/tracing/instrumentation/attribute_extractors.py +50 -41
- rasa/tracing/instrumentation/instrumentation.py +290 -44
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +7 -5
- rasa/tracing/instrumentation/metrics.py +109 -21
- rasa/tracing/metric_instrument_provider.py +83 -3
- rasa/utils/cli.py +2 -1
- rasa/utils/common.py +1 -1
- rasa/utils/endpoints.py +1 -2
- rasa/utils/io.py +72 -6
- rasa/utils/licensing.py +246 -31
- rasa/utils/ml_utils.py +1 -1
- rasa/utils/tensorflow/data_generator.py +1 -1
- rasa/utils/tensorflow/environment.py +1 -1
- rasa/utils/tensorflow/model_data.py +201 -12
- rasa/utils/tensorflow/model_data_utils.py +499 -500
- rasa/utils/tensorflow/models.py +5 -6
- rasa/utils/tensorflow/rasa_layers.py +15 -15
- rasa/utils/train_utils.py +1 -1
- rasa/utils/url_tools.py +53 -0
- rasa/validator.py +305 -3
- rasa/version.py +1 -1
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/METADATA +25 -61
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/RECORD +276 -259
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +0 -1
- rasa/utils/tensorflow/feature_array.py +0 -370
- /rasa/dialogue_understanding/generator/{command_prompt_template.jinja2 → single_step/command_prompt_template.jinja2} +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/NOTICE +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/WHEEL +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/entry_points.txt +0 -0
|
@@ -1,17 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
3
2
|
import copy
|
|
4
3
|
import logging
|
|
5
4
|
from collections import defaultdict
|
|
6
5
|
from pathlib import Path
|
|
7
|
-
|
|
6
|
+
|
|
7
|
+
from rasa.exceptions import ModelNotFound
|
|
8
|
+
from rasa.nlu.featurizers.featurizer import Featurizer
|
|
8
9
|
|
|
9
10
|
import numpy as np
|
|
10
11
|
import scipy.sparse
|
|
11
12
|
import tensorflow as tf
|
|
12
13
|
|
|
13
|
-
from
|
|
14
|
-
|
|
14
|
+
from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
|
|
15
|
+
|
|
15
16
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
16
17
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
17
18
|
from rasa.engine.storage.resource import Resource
|
|
@@ -19,21 +20,18 @@ from rasa.engine.storage.storage import ModelStorage
|
|
|
19
20
|
from rasa.nlu.extractors.extractor import EntityExtractorMixin
|
|
20
21
|
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
21
22
|
import rasa.shared.utils.io
|
|
23
|
+
import rasa.utils.io as io_utils
|
|
22
24
|
import rasa.nlu.utils.bilou_utils as bilou_utils
|
|
23
25
|
from rasa.shared.constants import DIAGNOSTIC_DATA
|
|
24
26
|
from rasa.nlu.extractors.extractor import EntityTagSpec
|
|
25
27
|
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
26
28
|
from rasa.utils import train_utils
|
|
27
29
|
from rasa.utils.tensorflow import rasa_layers
|
|
28
|
-
from rasa.utils.tensorflow.feature_array import (
|
|
29
|
-
FeatureArray,
|
|
30
|
-
serialize_nested_feature_arrays,
|
|
31
|
-
deserialize_nested_feature_arrays,
|
|
32
|
-
)
|
|
33
30
|
from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
|
|
34
31
|
from rasa.utils.tensorflow.model_data import (
|
|
35
32
|
RasaModelData,
|
|
36
33
|
FeatureSignature,
|
|
34
|
+
FeatureArray,
|
|
37
35
|
)
|
|
38
36
|
from rasa.nlu.constants import TOKENS_NAMES, DEFAULT_TRANSFORMER_SIZE
|
|
39
37
|
from rasa.shared.nlu.constants import (
|
|
@@ -120,6 +118,7 @@ LABEL_SUB_KEY = IDS
|
|
|
120
118
|
|
|
121
119
|
POSSIBLE_TAGS = [ENTITY_ATTRIBUTE_TYPE, ENTITY_ATTRIBUTE_ROLE, ENTITY_ATTRIBUTE_GROUP]
|
|
122
120
|
|
|
121
|
+
|
|
123
122
|
DIETClassifierT = TypeVar("DIETClassifierT", bound="DIETClassifier")
|
|
124
123
|
|
|
125
124
|
|
|
@@ -511,7 +510,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
511
510
|
def _extract_features(
|
|
512
511
|
self, message: Message, attribute: Text
|
|
513
512
|
) -> Dict[Text, Union[scipy.sparse.spmatrix, np.ndarray]]:
|
|
514
|
-
|
|
515
513
|
(
|
|
516
514
|
sparse_sequence_features,
|
|
517
515
|
sparse_sentence_features,
|
|
@@ -781,7 +779,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
781
779
|
sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
|
|
782
780
|
label_attribute: Optional[Text] = None,
|
|
783
781
|
) -> Dict[Text, Dict[Text, List[int]]]:
|
|
784
|
-
|
|
785
782
|
if label_attribute in sparse_feature_sizes:
|
|
786
783
|
del sparse_feature_sizes[label_attribute]
|
|
787
784
|
return sparse_feature_sizes
|
|
@@ -1086,24 +1083,18 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1086
1083
|
|
|
1087
1084
|
self.model.save(str(tf_model_file))
|
|
1088
1085
|
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
self._data_example,
|
|
1092
|
-
model_path / f"{file_name}.data_example.st",
|
|
1093
|
-
model_path / f"{file_name}.data_example_metadata.json",
|
|
1086
|
+
io_utils.pickle_dump(
|
|
1087
|
+
model_path / f"{file_name}.data_example.pkl", self._data_example
|
|
1094
1088
|
)
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
dict(self._label_data.data) if self._label_data is not None else {},
|
|
1098
|
-
model_path / f"{file_name}.label_data.st",
|
|
1099
|
-
model_path / f"{file_name}.label_data_metadata.json",
|
|
1100
|
-
)
|
|
1101
|
-
|
|
1102
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
1103
|
-
model_path / f"{file_name}.sparse_feature_sizes.json",
|
|
1089
|
+
io_utils.pickle_dump(
|
|
1090
|
+
model_path / f"{file_name}.sparse_feature_sizes.pkl",
|
|
1104
1091
|
self._sparse_feature_sizes,
|
|
1105
1092
|
)
|
|
1106
|
-
|
|
1093
|
+
io_utils.pickle_dump(
|
|
1094
|
+
model_path / f"{file_name}.label_data.pkl",
|
|
1095
|
+
dict(self._label_data.data) if self._label_data is not None else {},
|
|
1096
|
+
)
|
|
1097
|
+
io_utils.json_pickle(
|
|
1107
1098
|
model_path / f"{file_name}.index_label_id_mapping.json",
|
|
1108
1099
|
self.index_label_id_mapping,
|
|
1109
1100
|
)
|
|
@@ -1192,22 +1183,15 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1192
1183
|
]:
|
|
1193
1184
|
file_name = cls.__name__
|
|
1194
1185
|
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
str(model_path / f"{file_name}.data_example.st"),
|
|
1198
|
-
str(model_path / f"{file_name}.data_example_metadata.json"),
|
|
1186
|
+
data_example = io_utils.pickle_load(
|
|
1187
|
+
model_path / f"{file_name}.data_example.pkl"
|
|
1199
1188
|
)
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
)
|
|
1205
|
-
label_data = RasaModelData(data=loaded_label_data)
|
|
1206
|
-
|
|
1207
|
-
sparse_feature_sizes = rasa.shared.utils.io.read_json_file(
|
|
1208
|
-
model_path / f"{file_name}.sparse_feature_sizes.json"
|
|
1189
|
+
label_data = io_utils.pickle_load(model_path / f"{file_name}.label_data.pkl")
|
|
1190
|
+
label_data = RasaModelData(data=label_data)
|
|
1191
|
+
sparse_feature_sizes = io_utils.pickle_load(
|
|
1192
|
+
model_path / f"{file_name}.sparse_feature_sizes.pkl"
|
|
1209
1193
|
)
|
|
1210
|
-
index_label_id_mapping =
|
|
1194
|
+
index_label_id_mapping = io_utils.json_unpickle(
|
|
1211
1195
|
model_path / f"{file_name}.index_label_id_mapping.json"
|
|
1212
1196
|
)
|
|
1213
1197
|
entity_tag_specs = rasa.shared.utils.io.read_json_file(
|
|
@@ -1227,6 +1211,7 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1227
1211
|
for tag_spec in entity_tag_specs
|
|
1228
1212
|
]
|
|
1229
1213
|
|
|
1214
|
+
# jsonpickle converts dictionary keys to strings
|
|
1230
1215
|
index_label_id_mapping = {
|
|
1231
1216
|
int(key): value for key, value in index_label_id_mapping.items()
|
|
1232
1217
|
}
|
|
@@ -1280,7 +1265,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1280
1265
|
config: Dict[Text, Any],
|
|
1281
1266
|
finetune_mode: bool,
|
|
1282
1267
|
) -> "RasaModel":
|
|
1283
|
-
|
|
1284
1268
|
predict_data_example = RasaModelData(
|
|
1285
1269
|
label_key=model_data_example.label_key,
|
|
1286
1270
|
data={
|
|
@@ -1467,10 +1451,10 @@ class DIET(TransformerRasaModel):
|
|
|
1467
1451
|
# everything using a transformer and optionally also do masked language
|
|
1468
1452
|
# modeling.
|
|
1469
1453
|
self.text_name = TEXT
|
|
1470
|
-
self._tf_layers[
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
|
|
1454
|
+
self._tf_layers[f"sequence_layer.{self.text_name}"] = (
|
|
1455
|
+
rasa_layers.RasaSequenceLayer(
|
|
1456
|
+
self.text_name, self.data_signature[self.text_name], self.config
|
|
1457
|
+
)
|
|
1474
1458
|
)
|
|
1475
1459
|
if self.config[MASKED_LM]:
|
|
1476
1460
|
self._prepare_mask_lm_loss(self.text_name)
|
|
@@ -1488,10 +1472,10 @@ class DIET(TransformerRasaModel):
|
|
|
1488
1472
|
{SPARSE_INPUT_DROPOUT: False, DENSE_INPUT_DROPOUT: False}
|
|
1489
1473
|
)
|
|
1490
1474
|
|
|
1491
|
-
self._tf_layers[
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1475
|
+
self._tf_layers[f"feature_combining_layer.{self.label_name}"] = (
|
|
1476
|
+
rasa_layers.RasaFeatureCombiningLayer(
|
|
1477
|
+
self.label_name, self.label_signature[self.label_name], label_config
|
|
1478
|
+
)
|
|
1495
1479
|
)
|
|
1496
1480
|
|
|
1497
1481
|
self._prepare_ffnn_layer(
|
|
@@ -1523,7 +1507,6 @@ class DIET(TransformerRasaModel):
|
|
|
1523
1507
|
sequence_feature_lengths: tf.Tensor,
|
|
1524
1508
|
name: Text,
|
|
1525
1509
|
) -> tf.Tensor:
|
|
1526
|
-
|
|
1527
1510
|
x, _ = self._tf_layers[f"feature_combining_layer.{name}"](
|
|
1528
1511
|
(sequence_features, sentence_features, sequence_feature_lengths),
|
|
1529
1512
|
training=self._training,
|
|
@@ -1705,7 +1688,6 @@ class DIET(TransformerRasaModel):
|
|
|
1705
1688
|
return loss
|
|
1706
1689
|
|
|
1707
1690
|
def _update_label_metrics(self, loss: tf.Tensor, acc: tf.Tensor) -> None:
|
|
1708
|
-
|
|
1709
1691
|
self.intent_loss.update_state(loss)
|
|
1710
1692
|
self.intent_acc.update_state(acc)
|
|
1711
1693
|
|
|
@@ -1864,7 +1846,6 @@ class DIET(TransformerRasaModel):
|
|
|
1864
1846
|
combined_sequence_sentence_feature_lengths: tf.Tensor,
|
|
1865
1847
|
text_transformed: tf.Tensor,
|
|
1866
1848
|
) -> Dict[Text, tf.Tensor]:
|
|
1867
|
-
|
|
1868
1849
|
if self.all_labels_embed is None:
|
|
1869
1850
|
raise ValueError(
|
|
1870
1851
|
"The model was not prepared for prediction. "
|
|
@@ -1,21 +1,22 @@
|
|
|
1
1
|
from typing import Any, Text, Dict, List, Type, Tuple
|
|
2
2
|
|
|
3
|
+
import joblib
|
|
3
4
|
import structlog
|
|
4
5
|
from scipy.sparse import hstack, vstack, csr_matrix
|
|
5
6
|
from sklearn.exceptions import NotFittedError
|
|
6
7
|
from sklearn.linear_model import LogisticRegression
|
|
7
8
|
from sklearn.utils.validation import check_is_fitted
|
|
8
9
|
|
|
9
|
-
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
10
|
-
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
11
10
|
from rasa.engine.storage.resource import Resource
|
|
12
11
|
from rasa.engine.storage.storage import ModelStorage
|
|
12
|
+
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
13
|
+
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
13
14
|
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
14
|
-
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
15
15
|
from rasa.nlu.featurizers.featurizer import Featurizer
|
|
16
|
-
from rasa.
|
|
17
|
-
from rasa.shared.nlu.training_data.message import Message
|
|
16
|
+
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
18
17
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
18
|
+
from rasa.shared.nlu.training_data.message import Message
|
|
19
|
+
from rasa.shared.nlu.constants import TEXT, INTENT
|
|
19
20
|
from rasa.utils.tensorflow.constants import RANKING_LENGTH
|
|
20
21
|
|
|
21
22
|
structlogger = structlog.get_logger()
|
|
@@ -183,11 +184,9 @@ class LogisticRegressionClassifier(IntentClassifier, GraphComponent):
|
|
|
183
184
|
|
|
184
185
|
def persist(self) -> None:
|
|
185
186
|
"""Persist this model into the passed directory."""
|
|
186
|
-
import skops.io as sio
|
|
187
|
-
|
|
188
187
|
with self._model_storage.write_to(self._resource) as model_dir:
|
|
189
|
-
path = model_dir / f"{self._resource.name}.
|
|
190
|
-
|
|
188
|
+
path = model_dir / f"{self._resource.name}.joblib"
|
|
189
|
+
joblib.dump(self.clf, path)
|
|
191
190
|
structlogger.debug(
|
|
192
191
|
"logistic_regression_classifier.persist",
|
|
193
192
|
event_info=f"Saved intent classifier to '{path}'.",
|
|
@@ -203,21 +202,9 @@ class LogisticRegressionClassifier(IntentClassifier, GraphComponent):
|
|
|
203
202
|
**kwargs: Any,
|
|
204
203
|
) -> "LogisticRegressionClassifier":
|
|
205
204
|
"""Loads trained component (see parent class for full docstring)."""
|
|
206
|
-
import skops.io as sio
|
|
207
|
-
|
|
208
205
|
try:
|
|
209
206
|
with model_storage.read_from(resource) as model_dir:
|
|
210
|
-
|
|
211
|
-
unknown_types = sio.get_untrusted_types(file=classifier_file)
|
|
212
|
-
|
|
213
|
-
if unknown_types:
|
|
214
|
-
structlogger.error(
|
|
215
|
-
f"Untrusted types found when loading {classifier_file}!",
|
|
216
|
-
unknown_types=unknown_types,
|
|
217
|
-
)
|
|
218
|
-
raise ValueError()
|
|
219
|
-
|
|
220
|
-
classifier = sio.load(classifier_file, trusted=unknown_types)
|
|
207
|
+
classifier = joblib.load(model_dir / f"{resource.name}.joblib")
|
|
221
208
|
component = cls(
|
|
222
209
|
config, execution_context.node_name, model_storage, resource
|
|
223
210
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
3
2
|
import logging
|
|
3
|
+
from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
|
|
4
4
|
import typing
|
|
5
5
|
import warnings
|
|
6
6
|
from typing import Any, Dict, List, Optional, Text, Tuple, Type
|
|
@@ -8,18 +8,18 @@ from typing import Any, Dict, List, Optional, Text, Tuple, Type
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
10
|
import rasa.shared.utils.io
|
|
11
|
+
import rasa.utils.io as io_utils
|
|
11
12
|
from rasa.engine.graph import GraphComponent, ExecutionContext
|
|
12
13
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
13
14
|
from rasa.engine.storage.resource import Resource
|
|
14
15
|
from rasa.engine.storage.storage import ModelStorage
|
|
15
|
-
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
16
|
-
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
17
|
-
from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
|
|
18
16
|
from rasa.shared.constants import DOCS_URL_TRAINING_DATA_NLU
|
|
17
|
+
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
19
18
|
from rasa.shared.exceptions import RasaException
|
|
20
19
|
from rasa.shared.nlu.constants import TEXT
|
|
21
|
-
from rasa.
|
|
20
|
+
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
22
21
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
22
|
+
from rasa.shared.nlu.training_data.message import Message
|
|
23
23
|
from rasa.utils.tensorflow.constants import FEATURIZERS
|
|
24
24
|
|
|
25
25
|
logger = logging.getLogger(__name__)
|
|
@@ -266,20 +266,14 @@ class SklearnIntentClassifier(GraphComponent, IntentClassifier):
|
|
|
266
266
|
|
|
267
267
|
def persist(self) -> None:
|
|
268
268
|
"""Persist this model into the passed directory."""
|
|
269
|
-
import skops.io as sio
|
|
270
|
-
|
|
271
269
|
with self._model_storage.write_to(self._resource) as model_dir:
|
|
272
270
|
file_name = self.__class__.__name__
|
|
273
|
-
classifier_file_name = model_dir / f"{file_name}_classifier.
|
|
274
|
-
encoder_file_name = model_dir / f"{file_name}_encoder.
|
|
271
|
+
classifier_file_name = model_dir / f"{file_name}_classifier.pkl"
|
|
272
|
+
encoder_file_name = model_dir / f"{file_name}_encoder.pkl"
|
|
275
273
|
|
|
276
274
|
if self.clf and self.le:
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
280
|
-
encoder_file_name, list(self.le.classes_)
|
|
281
|
-
)
|
|
282
|
-
sio.dump(self.clf.best_estimator_, classifier_file_name)
|
|
275
|
+
io_utils.json_pickle(encoder_file_name, self.le.classes_)
|
|
276
|
+
io_utils.json_pickle(classifier_file_name, self.clf.best_estimator_)
|
|
283
277
|
|
|
284
278
|
@classmethod
|
|
285
279
|
def load(
|
|
@@ -292,36 +286,21 @@ class SklearnIntentClassifier(GraphComponent, IntentClassifier):
|
|
|
292
286
|
) -> SklearnIntentClassifier:
|
|
293
287
|
"""Loads trained component (see parent class for full docstring)."""
|
|
294
288
|
from sklearn.preprocessing import LabelEncoder
|
|
295
|
-
import skops.io as sio
|
|
296
289
|
|
|
297
290
|
try:
|
|
298
291
|
with model_storage.read_from(resource) as model_dir:
|
|
299
292
|
file_name = cls.__name__
|
|
300
|
-
classifier_file = model_dir / f"{file_name}_classifier.
|
|
293
|
+
classifier_file = model_dir / f"{file_name}_classifier.pkl"
|
|
301
294
|
|
|
302
295
|
if classifier_file.exists():
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
if unknown_types:
|
|
306
|
-
logger.error(
|
|
307
|
-
f"Untrusted types ({unknown_types}) found when "
|
|
308
|
-
f"loading {classifier_file}!"
|
|
309
|
-
)
|
|
310
|
-
raise ValueError()
|
|
311
|
-
else:
|
|
312
|
-
classifier = sio.load(classifier_file, trusted=unknown_types)
|
|
313
|
-
|
|
314
|
-
encoder_file = model_dir / f"{file_name}_encoder.json"
|
|
315
|
-
classes = rasa.shared.utils.io.read_json_file(encoder_file)
|
|
296
|
+
classifier = io_utils.json_unpickle(classifier_file)
|
|
316
297
|
|
|
298
|
+
encoder_file = model_dir / f"{file_name}_encoder.pkl"
|
|
299
|
+
classes = io_utils.json_unpickle(encoder_file)
|
|
317
300
|
encoder = LabelEncoder()
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
)
|
|
321
|
-
# convert list of strings (class labels) back to numpy array of
|
|
322
|
-
# strings
|
|
323
|
-
intent_classifier.transform_labels_str2num(classes)
|
|
324
|
-
return intent_classifier
|
|
301
|
+
encoder.classes_ = classes
|
|
302
|
+
|
|
303
|
+
return cls(config, model_storage, resource, classifier, encoder)
|
|
325
304
|
except ValueError:
|
|
326
305
|
logger.debug(
|
|
327
306
|
f"Failed to load '{cls.__name__}' from model storage. Resource "
|
|
@@ -4,9 +4,9 @@ from collections import OrderedDict
|
|
|
4
4
|
from enum import Enum
|
|
5
5
|
import logging
|
|
6
6
|
import typing
|
|
7
|
-
from typing import Any, Dict, List, Optional, Text, Tuple, Callable, Type
|
|
8
7
|
|
|
9
8
|
import numpy as np
|
|
9
|
+
from typing import Any, Dict, List, Optional, Text, Tuple, Callable, Type
|
|
10
10
|
|
|
11
11
|
import rasa.nlu.utils.bilou_utils as bilou_utils
|
|
12
12
|
import rasa.shared.utils.io
|
|
@@ -41,9 +41,6 @@ if typing.TYPE_CHECKING:
|
|
|
41
41
|
from sklearn_crfsuite import CRF
|
|
42
42
|
|
|
43
43
|
|
|
44
|
-
CONFIG_FEATURES = "features"
|
|
45
|
-
|
|
46
|
-
|
|
47
44
|
class CRFToken:
|
|
48
45
|
def __init__(
|
|
49
46
|
self,
|
|
@@ -63,29 +60,6 @@ class CRFToken:
|
|
|
63
60
|
self.entity_role_tag = entity_role_tag
|
|
64
61
|
self.entity_group_tag = entity_group_tag
|
|
65
62
|
|
|
66
|
-
def to_dict(self) -> Dict[str, Any]:
|
|
67
|
-
return {
|
|
68
|
-
"text": self.text,
|
|
69
|
-
"pos_tag": self.pos_tag,
|
|
70
|
-
"pattern": self.pattern,
|
|
71
|
-
"dense_features": [str(x) for x in list(self.dense_features)],
|
|
72
|
-
"entity_tag": self.entity_tag,
|
|
73
|
-
"entity_role_tag": self.entity_role_tag,
|
|
74
|
-
"entity_group_tag": self.entity_group_tag,
|
|
75
|
-
}
|
|
76
|
-
|
|
77
|
-
@classmethod
|
|
78
|
-
def create_from_dict(cls, data: Dict[str, Any]) -> "CRFToken":
|
|
79
|
-
return cls(
|
|
80
|
-
data["text"],
|
|
81
|
-
data["pos_tag"],
|
|
82
|
-
data["pattern"],
|
|
83
|
-
np.array([float(x) for x in data["dense_features"]]),
|
|
84
|
-
data["entity_tag"],
|
|
85
|
-
data["entity_role_tag"],
|
|
86
|
-
data["entity_group_tag"],
|
|
87
|
-
)
|
|
88
|
-
|
|
89
63
|
|
|
90
64
|
class CRFEntityExtractorOptions(str, Enum):
|
|
91
65
|
"""Features that can be used for the 'CRFEntityExtractor'."""
|
|
@@ -114,6 +88,8 @@ class CRFEntityExtractorOptions(str, Enum):
|
|
|
114
88
|
class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
|
|
115
89
|
"""Implements conditional random fields (CRF) to do named entity recognition."""
|
|
116
90
|
|
|
91
|
+
CONFIG_FEATURES = "features"
|
|
92
|
+
|
|
117
93
|
function_dict: Dict[Text, Callable[[CRFToken], Any]] = { # noqa: RUF012
|
|
118
94
|
CRFEntityExtractorOptions.LOW: lambda crf_token: crf_token.text.lower(),
|
|
119
95
|
CRFEntityExtractorOptions.TITLE: lambda crf_token: crf_token.text.istitle(),
|
|
@@ -132,7 +108,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
|
|
|
132
108
|
CRFEntityExtractorOptions.DIGIT: lambda crf_token: crf_token.text.isdigit(),
|
|
133
109
|
CRFEntityExtractorOptions.PATTERN: lambda crf_token: crf_token.pattern,
|
|
134
110
|
CRFEntityExtractorOptions.TEXT_DENSE_FEATURES: (
|
|
135
|
-
lambda crf_token: CRFEntityExtractor._convert_dense_features_for_crfsuite(
|
|
111
|
+
lambda crf_token: CRFEntityExtractor._convert_dense_features_for_crfsuite(
|
|
136
112
|
crf_token
|
|
137
113
|
)
|
|
138
114
|
),
|
|
@@ -161,7 +137,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
|
|
|
161
137
|
# "is the preceding token in title case?"
|
|
162
138
|
# POS features require SpacyTokenizer
|
|
163
139
|
# pattern feature require RegexFeaturizer
|
|
164
|
-
CONFIG_FEATURES: [
|
|
140
|
+
CRFEntityExtractor.CONFIG_FEATURES: [
|
|
165
141
|
[
|
|
166
142
|
CRFEntityExtractorOptions.LOW,
|
|
167
143
|
CRFEntityExtractorOptions.TITLE,
|
|
@@ -224,7 +200,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
|
|
|
224
200
|
)
|
|
225
201
|
|
|
226
202
|
def _validate_configuration(self) -> None:
|
|
227
|
-
if len(self.component_config.get(CONFIG_FEATURES, [])) % 2 != 1:
|
|
203
|
+
if len(self.component_config.get(self.CONFIG_FEATURES, [])) % 2 != 1:
|
|
228
204
|
raise ValueError(
|
|
229
205
|
"Need an odd number of crf feature lists to have a center word."
|
|
230
206
|
)
|
|
@@ -275,11 +251,9 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
|
|
|
275
251
|
]
|
|
276
252
|
dataset = [self._convert_to_crf_tokens(example) for example in entity_examples]
|
|
277
253
|
|
|
278
|
-
self.
|
|
279
|
-
dataset, self.component_config, self.crf_order
|
|
280
|
-
)
|
|
254
|
+
self._train_model(dataset)
|
|
281
255
|
|
|
282
|
-
self.persist(
|
|
256
|
+
self.persist()
|
|
283
257
|
|
|
284
258
|
return self._resource
|
|
285
259
|
|
|
@@ -325,9 +299,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
|
|
|
325
299
|
if include_tag_features:
|
|
326
300
|
self._add_tag_to_crf_token(crf_tokens, predictions)
|
|
327
301
|
|
|
328
|
-
features = self._crf_tokens_to_features(
|
|
329
|
-
crf_tokens, self.component_config, include_tag_features
|
|
330
|
-
)
|
|
302
|
+
features = self._crf_tokens_to_features(crf_tokens, include_tag_features)
|
|
331
303
|
predictions[tag_name] = entity_tagger.predict_marginals_single(features)
|
|
332
304
|
|
|
333
305
|
# convert predictions into a list of tags and a list of confidences
|
|
@@ -417,25 +389,27 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
|
|
|
417
389
|
**kwargs: Any,
|
|
418
390
|
) -> CRFEntityExtractor:
|
|
419
391
|
"""Loads trained component (see parent class for full docstring)."""
|
|
392
|
+
import joblib
|
|
393
|
+
|
|
420
394
|
try:
|
|
395
|
+
entity_taggers = OrderedDict()
|
|
421
396
|
with model_storage.read_from(resource) as model_dir:
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
)
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
]
|
|
397
|
+
# We have to load in the same order as we persisted things as otherwise
|
|
398
|
+
# the predictions might be off
|
|
399
|
+
file_names = sorted(model_dir.glob("**/*.pkl"))
|
|
400
|
+
if not file_names:
|
|
401
|
+
logger.debug(
|
|
402
|
+
"Failed to load model for 'CRFEntityExtractor'. "
|
|
403
|
+
"Maybe you did not provide enough training data and "
|
|
404
|
+
"no model was trained."
|
|
405
|
+
)
|
|
406
|
+
return cls(config, model_storage, resource)
|
|
433
407
|
|
|
434
|
-
|
|
408
|
+
for file_name in file_names:
|
|
409
|
+
name = file_name.stem[1:]
|
|
410
|
+
entity_taggers[name] = joblib.load(file_name)
|
|
435
411
|
|
|
436
|
-
|
|
437
|
-
entity_extractor.crf_order = crf_order
|
|
438
|
-
return entity_extractor
|
|
412
|
+
return cls(config, model_storage, resource, entity_taggers)
|
|
439
413
|
except ValueError:
|
|
440
414
|
logger.warning(
|
|
441
415
|
f"Failed to load {cls.__name__} from model storage. Resource "
|
|
@@ -443,29 +417,23 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
|
|
|
443
417
|
)
|
|
444
418
|
return cls(config, model_storage, resource)
|
|
445
419
|
|
|
446
|
-
def persist(self
|
|
420
|
+
def persist(self) -> None:
|
|
447
421
|
"""Persist this model into the passed directory."""
|
|
448
|
-
|
|
449
|
-
data_to_store = [
|
|
450
|
-
[token.to_dict() for token in sub_list] for sub_list in dataset
|
|
451
|
-
]
|
|
422
|
+
import joblib
|
|
452
423
|
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
424
|
+
with self._model_storage.write_to(self._resource) as model_dir:
|
|
425
|
+
if self.entity_taggers:
|
|
426
|
+
for idx, (name, entity_tagger) in enumerate(
|
|
427
|
+
self.entity_taggers.items()
|
|
428
|
+
):
|
|
429
|
+
model_file_name = model_dir / f"{idx}{name}.pkl"
|
|
430
|
+
joblib.dump(entity_tagger, model_file_name)
|
|
459
431
|
|
|
460
|
-
@classmethod
|
|
461
432
|
def _crf_tokens_to_features(
|
|
462
|
-
|
|
463
|
-
crf_tokens: List[CRFToken],
|
|
464
|
-
config: Dict[str, Any],
|
|
465
|
-
include_tag_features: bool = False,
|
|
433
|
+
self, crf_tokens: List[CRFToken], include_tag_features: bool = False
|
|
466
434
|
) -> List[Dict[Text, Any]]:
|
|
467
435
|
"""Convert the list of tokens into discrete features."""
|
|
468
|
-
configured_features =
|
|
436
|
+
configured_features = self.component_config[self.CONFIG_FEATURES]
|
|
469
437
|
sentence_features = []
|
|
470
438
|
|
|
471
439
|
for token_idx in range(len(crf_tokens)):
|
|
@@ -476,31 +444,28 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
|
|
|
476
444
|
half_window_size = window_size // 2
|
|
477
445
|
window_range = range(-half_window_size, half_window_size + 1)
|
|
478
446
|
|
|
479
|
-
token_features =
|
|
447
|
+
token_features = self._create_features_for_token(
|
|
480
448
|
crf_tokens,
|
|
481
449
|
token_idx,
|
|
482
450
|
half_window_size,
|
|
483
451
|
window_range,
|
|
484
452
|
include_tag_features,
|
|
485
|
-
config,
|
|
486
453
|
)
|
|
487
454
|
|
|
488
455
|
sentence_features.append(token_features)
|
|
489
456
|
|
|
490
457
|
return sentence_features
|
|
491
458
|
|
|
492
|
-
@classmethod
|
|
493
459
|
def _create_features_for_token(
|
|
494
|
-
|
|
460
|
+
self,
|
|
495
461
|
crf_tokens: List[CRFToken],
|
|
496
462
|
token_idx: int,
|
|
497
463
|
half_window_size: int,
|
|
498
464
|
window_range: range,
|
|
499
465
|
include_tag_features: bool,
|
|
500
|
-
config: Dict[str, Any],
|
|
501
466
|
) -> Dict[Text, Any]:
|
|
502
467
|
"""Convert a token into discrete features including words before and after."""
|
|
503
|
-
configured_features =
|
|
468
|
+
configured_features = self.component_config[self.CONFIG_FEATURES]
|
|
504
469
|
prefixes = [str(i) for i in window_range]
|
|
505
470
|
|
|
506
471
|
token_features = {}
|
|
@@ -540,13 +505,13 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
|
|
|
540
505
|
# set in the training data, 'matched' is either 'True' or
|
|
541
506
|
# 'False' depending on whether the token actually matches the
|
|
542
507
|
# pattern or not
|
|
543
|
-
regex_patterns =
|
|
508
|
+
regex_patterns = self.function_dict[feature](token)
|
|
544
509
|
for pattern_name, matched in regex_patterns.items():
|
|
545
|
-
token_features[
|
|
546
|
-
|
|
547
|
-
|
|
510
|
+
token_features[f"{prefix}:{feature}:{pattern_name}"] = (
|
|
511
|
+
matched
|
|
512
|
+
)
|
|
548
513
|
else:
|
|
549
|
-
value =
|
|
514
|
+
value = self.function_dict[feature](token)
|
|
550
515
|
token_features[f"{prefix}:{feature}"] = value
|
|
551
516
|
|
|
552
517
|
return token_features
|
|
@@ -670,46 +635,38 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
|
|
|
670
635
|
|
|
671
636
|
return tags
|
|
672
637
|
|
|
673
|
-
|
|
674
|
-
def train_model(
|
|
675
|
-
cls,
|
|
676
|
-
df_train: List[List[CRFToken]],
|
|
677
|
-
config: Dict[str, Any],
|
|
678
|
-
crf_order: List[str],
|
|
679
|
-
) -> OrderedDict[str, CRF]:
|
|
638
|
+
def _train_model(self, df_train: List[List[CRFToken]]) -> None:
|
|
680
639
|
"""Train the crf tagger based on the training data."""
|
|
681
640
|
import sklearn_crfsuite
|
|
682
641
|
|
|
683
|
-
entity_taggers = OrderedDict()
|
|
642
|
+
self.entity_taggers = OrderedDict()
|
|
684
643
|
|
|
685
|
-
for tag_name in crf_order:
|
|
644
|
+
for tag_name in self.crf_order:
|
|
686
645
|
logger.debug(f"Training CRF for '{tag_name}'.")
|
|
687
646
|
|
|
688
647
|
# add entity tag features for second level CRFs
|
|
689
648
|
include_tag_features = tag_name != ENTITY_ATTRIBUTE_TYPE
|
|
690
649
|
X_train = (
|
|
691
|
-
|
|
650
|
+
self._crf_tokens_to_features(sentence, include_tag_features)
|
|
692
651
|
for sentence in df_train
|
|
693
652
|
)
|
|
694
653
|
y_train = (
|
|
695
|
-
|
|
654
|
+
self._crf_tokens_to_tags(sentence, tag_name) for sentence in df_train
|
|
696
655
|
)
|
|
697
656
|
|
|
698
657
|
entity_tagger = sklearn_crfsuite.CRF(
|
|
699
658
|
algorithm="lbfgs",
|
|
700
659
|
# coefficient for L1 penalty
|
|
701
|
-
c1=
|
|
660
|
+
c1=self.component_config["L1_c"],
|
|
702
661
|
# coefficient for L2 penalty
|
|
703
|
-
c2=
|
|
662
|
+
c2=self.component_config["L2_c"],
|
|
704
663
|
# stop earlier
|
|
705
|
-
max_iterations=
|
|
664
|
+
max_iterations=self.component_config["max_iterations"],
|
|
706
665
|
# include transitions that are possible, but not observed
|
|
707
666
|
all_possible_transitions=True,
|
|
708
667
|
)
|
|
709
668
|
entity_tagger.fit(X_train, y_train)
|
|
710
669
|
|
|
711
|
-
entity_taggers[tag_name] = entity_tagger
|
|
670
|
+
self.entity_taggers[tag_name] = entity_tagger
|
|
712
671
|
|
|
713
672
|
logger.debug("Training finished.")
|
|
714
|
-
|
|
715
|
-
return entity_taggers
|