rasa-pro 3.11.0a4.dev2__py3-none-any.whl → 3.11.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic; see the package's registry page for more details.
- rasa/__main__.py +22 -12
- rasa/api.py +1 -1
- rasa/cli/arguments/default_arguments.py +1 -2
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +4 -4
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +8 -0
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +4 -2
- rasa/cli/utils.py +5 -0
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -1
- rasa/core/channels/channel.py +3 -0
- rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
- rasa/core/channels/socketio.py +2 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/jambonz.py +2 -2
- rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
- rasa/core/channels/voice_stream/asr/azure.py +122 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
- rasa/core/channels/voice_stream/audio_bytes.py +1 -0
- rasa/core/channels/voice_stream/browser_audio.py +31 -8
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/azure.py +6 -2
- rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +177 -39
- rasa/core/featurizers/single_state_featurizer.py +22 -1
- rasa/core/featurizers/tracker_featurizers.py +115 -18
- rasa/core/nlg/contextual_response_rephraser.py +16 -22
- rasa/core/persistor.py +86 -39
- rasa/core/policies/enterprise_search_policy.py +159 -60
- rasa/core/policies/flows/flow_executor.py +7 -4
- rasa/core/policies/intentless_policy.py +120 -22
- rasa/core/policies/ted_policy.py +58 -33
- rasa/core/policies/unexpected_intent_policy.py +15 -7
- rasa/core/processor.py +25 -0
- rasa/core/training/interactive.py +34 -35
- rasa/core/utils.py +8 -3
- rasa/dialogue_understanding/coexistence/llm_based_router.py +58 -16
- rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +65 -3
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +68 -26
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -8
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +64 -7
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/e2e_test/e2e_test_runner.py +4 -2
- rasa/e2e_test/utils/io.py +1 -1
- rasa/engine/validation.py +297 -7
- rasa/model_manager/config.py +17 -3
- rasa/model_manager/model_api.py +16 -8
- rasa/model_manager/runner_service.py +8 -6
- rasa/model_manager/socket_bridge.py +6 -3
- rasa/model_manager/trainer_service.py +7 -5
- rasa/model_manager/utils.py +28 -7
- rasa/model_service.py +7 -5
- rasa/model_training.py +2 -0
- rasa/nlu/classifiers/diet_classifier.py +38 -25
- rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
- rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
- rasa/nlu/extractors/crf_entity_extractor.py +93 -50
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
- rasa/shared/constants.py +36 -3
- rasa/shared/core/constants.py +7 -0
- rasa/shared/core/domain.py +26 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +96 -0
- rasa/shared/core/slots.py +5 -0
- rasa/shared/nlu/training_data/features.py +120 -2
- rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
- rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +12 -15
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +31 -30
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +8 -0
- rasa/shared/utils/health_check.py +533 -0
- rasa/shared/utils/io.py +28 -6
- rasa/shared/utils/llm.py +350 -46
- rasa/shared/utils/yaml.py +11 -13
- rasa/studio/upload.py +64 -20
- rasa/telemetry.py +80 -17
- rasa/tracing/instrumentation/attribute_extractors.py +74 -17
- rasa/utils/io.py +0 -66
- rasa/utils/log_utils.py +9 -2
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/model_data.py +2 -193
- rasa/validator.py +70 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/METADATA +10 -10
- {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/RECORD +162 -146
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
- {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0a4.dev2.dist-info → rasa_pro-3.11.0rc1.dist-info}/entry_points.txt +0 -0
rasa/engine/validation.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import dataclasses
|
|
2
2
|
import inspect
|
|
3
|
-
import re
|
|
4
3
|
import logging
|
|
4
|
+
import re
|
|
5
5
|
import sys
|
|
6
6
|
import typing
|
|
7
7
|
from typing import (
|
|
@@ -31,9 +31,17 @@ from rasa.dialogue_understanding.coexistence.constants import (
|
|
|
31
31
|
STICKY,
|
|
32
32
|
NON_STICKY,
|
|
33
33
|
)
|
|
34
|
+
from rasa.dialogue_understanding.coexistence.intent_based_router import (
|
|
35
|
+
IntentBasedRouter,
|
|
36
|
+
)
|
|
37
|
+
from rasa.dialogue_understanding.coexistence.llm_based_router import LLMBasedRouter
|
|
34
38
|
from rasa.dialogue_understanding.generator import (
|
|
35
39
|
LLMBasedCommandGenerator,
|
|
36
40
|
)
|
|
41
|
+
from rasa.dialogue_understanding.generator.constants import (
|
|
42
|
+
LLM_CONFIG_KEY,
|
|
43
|
+
FLOW_RETRIEVAL_KEY,
|
|
44
|
+
)
|
|
37
45
|
from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
|
|
38
46
|
from rasa.engine.constants import RESERVED_PLACEHOLDERS
|
|
39
47
|
from rasa.engine.exceptions import GraphSchemaValidationException
|
|
@@ -47,18 +55,31 @@ from rasa.engine.graph import (
|
|
|
47
55
|
from rasa.engine.storage.resource import Resource
|
|
48
56
|
from rasa.engine.storage.storage import ModelStorage
|
|
49
57
|
from rasa.engine.training.fingerprinting import Fingerprintable
|
|
50
|
-
from rasa.shared.constants import
|
|
58
|
+
from rasa.shared.constants import (
|
|
59
|
+
DOCS_URL_GRAPH_COMPONENTS,
|
|
60
|
+
ROUTE_TO_CALM_SLOT,
|
|
61
|
+
EMBEDDINGS_CONFIG_KEY,
|
|
62
|
+
API_BASE_CONFIG_KEY,
|
|
63
|
+
DEPLOYMENT_CONFIG_KEY,
|
|
64
|
+
API_VERSION_CONFIG_KEY,
|
|
65
|
+
API_KEY,
|
|
66
|
+
AWS_REGION_NAME_CONFIG_KEY,
|
|
67
|
+
MODEL_GROUP_ID_CONFIG_KEY,
|
|
68
|
+
ROUTER_CONFIG_KEY,
|
|
69
|
+
MODELS_CONFIG_KEY,
|
|
70
|
+
ROUTER_STRATEGY_CONFIG_KEY,
|
|
71
|
+
VALID_ROUTER_STRATEGIES,
|
|
72
|
+
ROUTER_STRATEGIES_REQUIRING_REDIS_CACHE,
|
|
73
|
+
ROUTER_STRATEGIES_NOT_REQUIRING_CACHE,
|
|
74
|
+
REDIS_HOST_CONFIG_KEY,
|
|
75
|
+
)
|
|
51
76
|
from rasa.shared.core.constants import ACTION_RESET_ROUTING, ACTION_TRIGGER_CHITCHAT
|
|
52
77
|
from rasa.shared.core.domain import Domain
|
|
53
78
|
from rasa.shared.core.flows import FlowsList, Flow
|
|
54
79
|
from rasa.shared.core.slots import Slot
|
|
55
80
|
from rasa.shared.exceptions import RasaException
|
|
56
81
|
from rasa.shared.nlu.training_data.message import Message
|
|
57
|
-
|
|
58
|
-
from rasa.dialogue_understanding.coexistence.intent_based_router import (
|
|
59
|
-
IntentBasedRouter,
|
|
60
|
-
)
|
|
61
|
-
from rasa.dialogue_understanding.coexistence.llm_based_router import LLMBasedRouter
|
|
82
|
+
from rasa.shared.utils.cli import print_error_and_exit
|
|
62
83
|
|
|
63
84
|
# Type alias: the union of TypeVar, Text, Type and Optional[AvailableEndpoints].
# NOTE(review): its usages are outside this chunk — confirm what it annotates.
TypeAnnotation = Union[TypeVar, Text, Type, Optional[AvailableEndpoints]]
|
|
64
85
|
|
|
@@ -845,6 +866,275 @@ def validate_coexistance_routing_setup(
|
|
|
845
866
|
)
|
|
846
867
|
|
|
847
868
|
|
|
869
|
+
def _validate_component_model_client_config(
    component_config: Dict[str, Any],
    key: str,
    model_group_syntax_used: List[bool],
    model_group_ids: List[str],
    component_name: Optional[str] = None,
) -> None:
    """Validate the LLM / embeddings client configuration of a component.

    Checks whether the client under `key` is defined using the new syntax
    (a reference to a model group via `MODELS_CONFIG_KEY`) or the old syntax
    (inline client parameters) and records the result for the caller:

    - new syntax: no other parameters are allowed next to the model group
      reference;
    - any syntax: an API key must not be part of the config.

    Invalid configurations are reported via `print_error_and_exit`.

    Args:
        component_config: The config of the component (or of a nested
            section such as flow retrieval).
        key: Either 'llm' or 'embeddings'.
        model_group_syntax_used: Accumulator; one boolean is appended per
            validated config, `True` if the new syntax is used.
        model_group_ids: Accumulator; the referenced model group id is
            appended when the new syntax is used.
        component_name: Name used in error messages. Falls back to the
            component's 'name' entry.
    """
    if key not in component_config:
        # no llm / embeddings configuration present
        return

    client_config = component_config[key]
    # Only used in error messages; `.get` avoids a KeyError hiding the real
    # validation error for configs without a 'name' entry.
    display_name = component_name or component_config.get("name")

    if MODELS_CONFIG_KEY in client_config:
        model_group_syntax_used.append(True)
        model_group_ids.append(client_config[MODELS_CONFIG_KEY])

        if len(client_config) > 1:
            print_error_and_exit(
                f"You specified a '{MODELS_CONFIG_KEY}' for the '{key}' "
                "config key for the component "
                f"'{display_name}'. "
                "No other parameters are allowed under the "
                f"'{key}' key in that case. Please update your config."
            )
    else:
        model_group_syntax_used.append(False)

    # API keys must come from environment variables, never from config.yml.
    if API_KEY in client_config:
        print_error_and_exit(
            f"You specified '{API_KEY}' in the config for "
            f"'{display_name}', which "
            "is not allowed. Set API keys through "
            "environment variables."
        )
|
|
916
|
+
|
|
917
|
+
|
|
918
|
+
def validate_model_client_configuration_setup(config: Dict[str, Any]) -> None:
    """Validates the model client configuration setup.

    Checks the LLM and embeddings configuration of every component in the
    pipeline, including the embeddings configured for flow retrieval.
    Validation fails, if

    - the LLM/embeddings is/are defined using the old and the new syntax at
      the same time (either at component level itself or across different
      components)
    - the LLM/embeddings is/are defined using the new syntax, but no model
      group is defined or the referenced model group does not exist

    When only the old syntax is used, a deprecation warning is logged.

    Args:
        config: The config dictionary
    """
    pipeline = config.get("pipeline")
    if not pipeline:
        # No pipeline, or an empty `pipeline:` entry that YAML loads as None
        # and that could not be iterated below.
        return

    model_group_syntax_used: List[bool] = []
    model_group_ids: List[str] = []

    for component in pipeline:
        for key in (LLM_CONFIG_KEY, EMBEDDINGS_CONFIG_KEY):
            _validate_component_model_client_config(
                component, key, model_group_syntax_used, model_group_ids
            )

        # Flow retrieval is not a component itself, so the embeddings config
        # nested under it has to be checked separately. A non-dict value
        # (e.g. `flow_retrieval:` loaded as None) carries no client config.
        flow_retrieval_config = component.get(FLOW_RETRIEVAL_KEY)
        if (
            isinstance(flow_retrieval_config, dict)
            and EMBEDDINGS_CONFIG_KEY in flow_retrieval_config
        ):
            _validate_component_model_client_config(
                flow_retrieval_config,
                EMBEDDINGS_CONFIG_KEY,
                model_group_syntax_used,
                model_group_ids,
                f"{component['name']}.{FLOW_RETRIEVAL_KEY}",
            )

    # Both syntaxes must not be mixed: the recorded flags have to be all True
    # or all False.
    if any(model_group_syntax_used) and not all(model_group_syntax_used):
        print_error_and_exit(
            "Some of your components refer to an LLM using the "
            f"'{MODELS_CONFIG_KEY}' parameter, other components directly "
            f"define the LLM under the '{LLM_CONFIG_KEY}' or the "
            f"'{EMBEDDINGS_CONFIG_KEY}' key. You cannot use "
            "both types of definition. Please choose one syntax "
            "and update your config."
        )

    # Print a deprecation warning in case the old syntax is used (after the
    # check above, all flags share the value of the first one).
    if model_group_syntax_used and not model_group_syntax_used[0]:
        structlogger.warning(
            "validate_llm_configuration_setup",
            event_info=(
                "Defining the LLM configuration in the config.yml file itself is"
                " deprecated and will be removed in Rasa 4.0.0. "
                "Please use the new syntax and define your LLM configuration "
                "in the endpoints.yml file."
            ),
        )

    endpoints = AvailableEndpoints.get_instance()
    if endpoints.model_groups is None:
        if model_group_ids:
            print_error_and_exit(
                "You are referring to (a) model group(s) in your "
                "config.yml file, but no model group was defined in "
                "the endpoints.yml file. Please define the model "
                "group(s)."
            )
        return

    existing_model_group_ids = [
        model_group[MODEL_GROUP_ID_CONFIG_KEY] for model_group in endpoints.model_groups
    ]

    for model_group_id in model_group_ids:
        if model_group_id not in existing_model_group_ids:
            print_error_and_exit(
                "One of your components is referring to the model group "
                f"'{model_group_id}', but this model group does not exist in the "
                "endpoints.yml file. Please choose one of the existing "
                f"model groups ({existing_model_group_ids}) or define "
                f"a model group for '{model_group_id}'."
            )
|
|
1007
|
+
|
|
1008
|
+
|
|
1009
|
+
def _validate_unique_model_group_ids(model_groups: List[Dict[str, Any]]) -> None:
    """Report an error when two model groups share the same id.

    Args:
        model_groups: The model groups defined in endpoints.yml.
    """
    group_ids = [group[MODEL_GROUP_ID_CONFIG_KEY] for group in model_groups]

    # A set collapses repeated ids, so any shrinkage means a duplicate exists.
    if len(set(group_ids)) == len(group_ids):
        return

    print_error_and_exit(
        "Each model group id must be unique. Please make sure that "
        "the model group ids are unique in your endpoints.yml file."
    )
|
|
1017
|
+
|
|
1018
|
+
|
|
1019
|
+
def _validate_model_group_with_multiple_models(
    model_groups: List[Dict[str, Any]],
) -> None:
    """Require a router for every model group that lists more than one model.

    Offending groups are reported via `print_error_and_exit`.

    Args:
        model_groups: The model groups defined in endpoints.yml.
    """
    for group in model_groups:
        lists_several_models = len(group[MODELS_CONFIG_KEY]) > 1
        if not lists_several_models or ROUTER_CONFIG_KEY in group:
            continue

        group_id = group[MODEL_GROUP_ID_CONFIG_KEY]
        print_error_and_exit(
            "You defined multiple models for the model group "
            f"'{group_id}', but no router. "
            "If a model group contains "
            "multiple models, a router must be defined. Please define a router "
            f"for the model group '{group_id}'."
        )
|
|
1035
|
+
|
|
1036
|
+
|
|
1037
|
+
def _validate_model_group_router_setting(
    model_groups: List[Dict[str, Any]],
) -> None:
    """Validate the router settings of the model groups in endpoints.yml.

    For every model group whose router sets a routing strategy:

    - a strategy not in `VALID_ROUTER_STRATEGIES` is reported via
      `print_error_and_exit`;
    - a strategy from `ROUTER_STRATEGIES_REQUIRING_REDIS_CACHE` without a
      configured Redis host only logs a warning, because the router then
      defaults to 'in-memory' caching.

    Args:
        model_groups: The model groups defined in endpoints.yml.
    """
    for model_group in model_groups:
        router_config = model_group.get(ROUTER_CONFIG_KEY)
        # Skip groups without a router, with an empty `router:` entry (loaded
        # as None by YAML), or whose router does not set a strategy.
        if not router_config or ROUTER_STRATEGY_CONFIG_KEY not in router_config:
            continue

        router_strategy = router_config[ROUTER_STRATEGY_CONFIG_KEY]
        if router_strategy and router_strategy not in VALID_ROUTER_STRATEGIES:
            print_error_and_exit(
                "The router strategy you defined for the model group "
                f"'{model_group[MODEL_GROUP_ID_CONFIG_KEY]}' is not valid. "
                "Valid router strategies are categorized as follows:\n"
                "- Strategies requiring Redis caching: "
                f"{', '.join(ROUTER_STRATEGIES_REQUIRING_REDIS_CACHE)}\n"
                "- Strategies not requiring caching: "
                f"{', '.join(ROUTER_STRATEGIES_NOT_REQUIRING_CACHE)}"
            )

        if (
            router_strategy in ROUTER_STRATEGIES_REQUIRING_REDIS_CACHE
            and REDIS_HOST_CONFIG_KEY not in router_config
        ):
            structlogger.warning(
                "validation.router_strategy.redis_host_not_defined",
                event_info=(
                    f"The router strategy '{router_strategy}' requires a Redis host"
                    " to be defined. Without a Redis host, the system defaults to "
                    f"'in-memory' caching. Please add the '{REDIS_HOST_CONFIG_KEY}'"
                    " to the router configuration for the model group "
                    f"'{model_group[MODEL_GROUP_ID_CONFIG_KEY]}'."
                ),
            )
|
|
1072
|
+
|
|
1073
|
+
|
|
1074
|
+
def _validate_usage_of_environment_variables_in_model_group_config(
    model_groups: List[Dict[str, Any]],
) -> None:
    """Restrict `${ENV_VAR}` references in model configs to an allow-list of keys.

    Only deployment, api_base, api_key, api_version and aws_region_name may
    reference environment variables. Any other string value that contains a
    `${...}` reference is reported via `print_error_and_exit`.

    Args:
        model_groups: The model groups defined in endpoints.yml.
    """
    allowed_env_vars = {
        DEPLOYMENT_CONFIG_KEY,
        API_BASE_CONFIG_KEY,
        API_KEY,
        API_VERSION_CONFIG_KEY,
        AWS_REGION_NAME_CONFIG_KEY,
    }
    # `search` (not `match`): a reference embedded inside a value, e.g.
    # "prefix-${VAR}", must be detected too, not only one at the start.
    env_var_pattern = re.compile(r"\$\{(\w+)\}")

    for model_group in model_groups:
        for model_config in model_group[MODELS_CONFIG_KEY]:
            for key, value in model_config.items():
                if key in allowed_env_vars or not isinstance(value, str):
                    continue
                if env_var_pattern.search(value):
                    print_error_and_exit(
                        f"You defined '{key}' as environment variable in model "
                        f"group '{model_group[MODEL_GROUP_ID_CONFIG_KEY]}', "
                        "which is not allowed. "
                        "You can only use environment variables for the following "
                        # sorted: a deterministic message regardless of set order
                        f"keys: {', '.join(sorted(allowed_env_vars))}. "
                        "Please update your config."
                    )
|
|
1100
|
+
|
|
1101
|
+
|
|
1102
|
+
def _validate_api_key_is_an_environment_variable(
    model_groups: List[Dict[str, Any]],
) -> None:
    """Reject API keys that are written into endpoints.yml as literal strings.

    A string-valued API key has to start with a `${ENV_VAR}` reference;
    otherwise the problem is reported via `print_error_and_exit`. Non-string
    values are not checked here.

    Args:
        model_groups: The model groups defined in endpoints.yml.
    """
    env_var_reference = re.compile(r"\${(\w+)}")

    for group in model_groups:
        for model_config in group[MODELS_CONFIG_KEY]:
            api_key = model_config.get(API_KEY)
            if not isinstance(api_key, str):
                continue
            if env_var_reference.match(api_key) is not None:
                continue

            print_error_and_exit(
                f"You defined the '{API_KEY}' in model group "
                f"'{group[MODEL_GROUP_ID_CONFIG_KEY]}' as a string. "
                f"The '{API_KEY}' must be set as an environment variable. "
                "Please update your config."
            )
|
|
1120
|
+
|
|
1121
|
+
|
|
1122
|
+
def validate_model_group_configuration_setup() -> None:
    """Run every model group check on the model groups from endpoints.yml.

    Nothing is validated when no model groups are configured.
    """
    model_groups = AvailableEndpoints.get_instance().model_groups
    if model_groups is None:
        return

    # Order matters only for which error is reported first.
    checks = (
        _validate_unique_model_group_ids,
        _validate_model_group_with_multiple_models,
        _validate_usage_of_environment_variables_in_model_group_config,
        _validate_api_key_is_an_environment_variable,
        _validate_model_group_router_setting,
    )
    for check in checks:
        check(model_groups)
|
|
1136
|
+
|
|
1137
|
+
|
|
848
1138
|
def validate_command_generator_exclusivity(schema: GraphSchema) -> None:
|
|
849
1139
|
"""Validate that multiple command generators are not defined at same time."""
|
|
850
1140
|
from rasa.dialogue_understanding.generator import (
|
rasa/model_manager/config.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
import sys
|
|
2
2
|
import os
|
|
3
3
|
|
|
4
|
+
# Working directory used when RASA_MODEL_SERVER_BASE_DIRECTORY is not set.
DEFAULT_SERVER_BASE_WORKING_DIRECTORY = "working-data"

SERVER_BASE_WORKING_DIRECTORY = os.getenv(
    "RASA_MODEL_SERVER_BASE_DIRECTORY",
    DEFAULT_SERVER_BASE_WORKING_DIRECTORY,
)

# Optional override of the server's base URL; None when the variable is unset.
SERVER_BASE_URL = os.getenv("RASA_MODEL_SERVER_BASE_URL")
|
|
@@ -14,5 +16,17 @@ SERVER_MODEL_REMOTE_STORAGE = os.environ.get("RASA_REMOTE_STORAGE", None)
|
|
|
14
16
|
# Trainings and bots are launched with the interpreter running this process.
RASA_PYTHON_PATH = sys.executable

# Fallback limits when the corresponding environment variables are unset.
DEFAULT_MAX_PARALLEL_TRAININGS = 10
DEFAULT_MAX_PARALLEL_BOT_RUNS = 10

# Upper bounds for parallel training requests and parallel running bots.
# NOTE: the value is a str when the environment variable is set and the int
# default otherwise, so consumers should convert it with `int(...)`.
MAX_PARALLEL_TRAININGS = os.environ.get(
    "MAX_PARALLEL_TRAININGS", DEFAULT_MAX_PARALLEL_TRAININGS
)
MAX_PARALLEL_BOT_RUNS = os.environ.get(
    "MAX_PARALLEL_BOT_RUNS", DEFAULT_MAX_PARALLEL_BOT_RUNS
)

# URL path prefix for the server (appended to the request host when no
# explicit base URL is configured).
DEFAULT_SERVER_PATH_PREFIX = "talk"
|
rasa/model_manager/model_api.py
CHANGED
|
@@ -16,6 +16,7 @@ from rasa.model_manager.config import SERVER_BASE_URL
|
|
|
16
16
|
from rasa.constants import MODEL_ARCHIVE_EXTENSION
|
|
17
17
|
from rasa.model_manager.runner_service import (
|
|
18
18
|
BotSession,
|
|
19
|
+
BotSessionStatus,
|
|
19
20
|
fetch_remote_model_to_dir,
|
|
20
21
|
run_bot,
|
|
21
22
|
terminate_bot,
|
|
@@ -24,11 +25,13 @@ from rasa.model_manager.runner_service import (
|
|
|
24
25
|
from rasa.model_manager.socket_bridge import create_bridge_server
|
|
25
26
|
from rasa.model_manager.trainer_service import (
|
|
26
27
|
TrainingSession,
|
|
28
|
+
TrainingSessionStatus,
|
|
27
29
|
run_training,
|
|
28
30
|
terminate_training,
|
|
29
31
|
update_training_status,
|
|
30
32
|
)
|
|
31
33
|
from rasa.model_manager.utils import (
|
|
34
|
+
InvalidPathException,
|
|
32
35
|
get_logs_content,
|
|
33
36
|
logs_base_path,
|
|
34
37
|
models_base_path,
|
|
@@ -91,7 +94,7 @@ def base_server_url(request: Request) -> str:
|
|
|
91
94
|
if SERVER_BASE_URL:
|
|
92
95
|
return SERVER_BASE_URL.rstrip("/")
|
|
93
96
|
else:
|
|
94
|
-
return f"{request.scheme}://{request.host}"
|
|
97
|
+
return f"{request.scheme}://{request.host}/{config.DEFAULT_SERVER_PATH_PREFIX}"
|
|
95
98
|
|
|
96
99
|
|
|
97
100
|
async def continuously_update_process_status() -> None:
|
|
@@ -134,7 +137,8 @@ def internal_blueprint() -> Blueprint:
|
|
|
134
137
|
[
|
|
135
138
|
training
|
|
136
139
|
for training in trainings.values()
|
|
137
|
-
if training.status ==
|
|
140
|
+
if training.status == TrainingSessionStatus.RUNNING
|
|
141
|
+
and training.process.poll() is None
|
|
138
142
|
]
|
|
139
143
|
)
|
|
140
144
|
|
|
@@ -152,7 +156,7 @@ def internal_blueprint() -> Blueprint:
|
|
|
152
156
|
@bp.on_request # type: ignore[misc]
|
|
153
157
|
async def limit_parallel_bot_runs(request: Request) -> Any:
|
|
154
158
|
"""Limit the number of parallel bot runs."""
|
|
155
|
-
from rasa.model_manager.config import
|
|
159
|
+
from rasa.model_manager.config import MAX_PARALLEL_BOT_RUNS
|
|
156
160
|
|
|
157
161
|
if not request.url.endswith("/bot"):
|
|
158
162
|
return None
|
|
@@ -161,15 +165,15 @@ def internal_blueprint() -> Blueprint:
|
|
|
161
165
|
[
|
|
162
166
|
bot
|
|
163
167
|
for bot in running_bots.values()
|
|
164
|
-
if bot.status in {
|
|
168
|
+
if bot.status in {BotSessionStatus.RUNNING, BotSessionStatus.QUEUED}
|
|
165
169
|
]
|
|
166
170
|
)
|
|
167
171
|
|
|
168
|
-
if running_requests >= int(
|
|
172
|
+
if running_requests >= int(MAX_PARALLEL_BOT_RUNS):
|
|
169
173
|
return response.json(
|
|
170
174
|
{
|
|
171
175
|
"message": f"Too many parallel bot runs, above "
|
|
172
|
-
f"the limit of {
|
|
176
|
+
f"the limit of {MAX_PARALLEL_BOT_RUNS}. "
|
|
173
177
|
f"Retry later or increase your server's "
|
|
174
178
|
f"memory and CPU resources."
|
|
175
179
|
},
|
|
@@ -244,6 +248,8 @@ def internal_blueprint() -> Blueprint:
|
|
|
244
248
|
return json({"message": "Training id is required"}, status=400)
|
|
245
249
|
|
|
246
250
|
try:
|
|
251
|
+
# file deepcode ignore PT: path traversal is prevented
|
|
252
|
+
# by the `subpath` function found in the `rasa.model_manager.utils` module
|
|
247
253
|
training_session = run_training(
|
|
248
254
|
training_id=training_id,
|
|
249
255
|
assistant_id=assistant_id,
|
|
@@ -254,8 +260,10 @@ def internal_blueprint() -> Blueprint:
|
|
|
254
260
|
return json(
|
|
255
261
|
{"training_id": training_id, "model_name": training_session.model_name}
|
|
256
262
|
)
|
|
257
|
-
except
|
|
258
|
-
return json({"message": str(
|
|
263
|
+
except InvalidPathException as exc:
|
|
264
|
+
return json({"message": str(exc)}, status=403)
|
|
265
|
+
except Exception as exc:
|
|
266
|
+
return json({"message": str(exc)}, status=500)
|
|
259
267
|
|
|
260
268
|
@bp.get("/training/<training_id>")
|
|
261
269
|
async def get_training(request: Request, training_id: str) -> response.HTTPResponse:
|
|
@@ -87,7 +87,7 @@ async def is_bot_startup_finished(bot: BotSession) -> bool:
|
|
|
87
87
|
return False
|
|
88
88
|
|
|
89
89
|
|
|
90
|
-
def
|
|
90
|
+
def set_bot_status_to_stopped(bot: BotSession) -> None:
|
|
91
91
|
"""Set a bots state to stopped."""
|
|
92
92
|
structlogger.info(
|
|
93
93
|
"model_runner.bot_stopped",
|
|
@@ -97,7 +97,7 @@ def update_bot_to_stopped(bot: BotSession) -> None:
|
|
|
97
97
|
bot.status = BotSessionStatus.STOPPED
|
|
98
98
|
|
|
99
99
|
|
|
100
|
-
def
|
|
100
|
+
def set_bot_status_to_running(bot: BotSession) -> None:
|
|
101
101
|
"""Set a bots state to running."""
|
|
102
102
|
structlogger.info(
|
|
103
103
|
"model_runner.bot_running",
|
|
@@ -119,7 +119,9 @@ def get_open_port() -> int:
|
|
|
119
119
|
return port
|
|
120
120
|
|
|
121
121
|
|
|
122
|
-
def
|
|
122
|
+
def write_encoded_config_data_to_files(
|
|
123
|
+
encoded_configs: Dict[str, bytes], base_path: str
|
|
124
|
+
) -> None:
|
|
123
125
|
"""Write the encoded config data to files."""
|
|
124
126
|
for key, value in encoded_configs.items():
|
|
125
127
|
write_encoded_data_to_file(value, subpath(base_path, f"{key}.yml"))
|
|
@@ -155,7 +157,7 @@ def prepare_bot_directory(
|
|
|
155
157
|
dst=subpath(bot_base_path, "models"),
|
|
156
158
|
)
|
|
157
159
|
|
|
158
|
-
|
|
160
|
+
write_encoded_config_data_to_files(encoded_configs, bot_base_path)
|
|
159
161
|
|
|
160
162
|
|
|
161
163
|
def fetch_remote_model_to_dir(
|
|
@@ -253,9 +255,9 @@ def run_bot(
|
|
|
253
255
|
async def update_bot_status(bot: BotSession) -> None:
|
|
254
256
|
"""Update the status of a bot based on the process return code."""
|
|
255
257
|
if bot.has_died_recently():
|
|
256
|
-
|
|
258
|
+
set_bot_status_to_stopped(bot)
|
|
257
259
|
elif await bot.completed_startup_recently():
|
|
258
|
-
|
|
260
|
+
set_bot_status_to_running(bot)
|
|
259
261
|
|
|
260
262
|
|
|
261
263
|
def terminate_bot(bot: BotSession) -> None:
|
|
@@ -84,6 +84,7 @@ def create_bridge_server(sio: AsyncServer, running_bots: Dict[str, BotSession])
|
|
|
84
84
|
|
|
85
85
|
@sio.on("disconnect")
|
|
86
86
|
async def disconnect(sid: str) -> None:
|
|
87
|
+
"""Disconnect the bot connection."""
|
|
87
88
|
structlogger.debug("model_runner.bot_disconnect", sid=sid)
|
|
88
89
|
if sid in socket_proxy_clients:
|
|
89
90
|
await socket_proxy_clients[sid].disconnect()
|
|
@@ -91,10 +92,12 @@ def create_bridge_server(sio: AsyncServer, running_bots: Dict[str, BotSession])
|
|
|
91
92
|
|
|
92
93
|
@sio.on("*")
|
|
93
94
|
async def handle_message(event: str, sid: str, data: Dict[str, Any]) -> None:
|
|
94
|
-
|
|
95
|
-
# send the response back to the client. both need to happen
|
|
96
|
-
# in parallel in an async way
|
|
95
|
+
""" "Bridge messages between user and bot.
|
|
97
96
|
|
|
97
|
+
Both incoming user messages to the bot_url and
|
|
98
|
+
bot responses sent back to the client need to
|
|
99
|
+
happen in parallel in an async way.
|
|
100
|
+
"""
|
|
98
101
|
client = socket_proxy_clients.get(sid)
|
|
99
102
|
if client is None:
|
|
100
103
|
structlogger.error("model_runner.bot_not_connected", sid=sid)
|
|
@@ -52,12 +52,12 @@ class TrainingSession(BaseModel):
|
|
|
52
52
|
|
|
53
53
|
def train_path(training_id: str) -> str:
|
|
54
54
|
"""Return the path to the training directory for a given training id."""
|
|
55
|
-
return subpath(config.SERVER_BASE_WORKING_DIRECTORY
|
|
55
|
+
return subpath(config.SERVER_BASE_WORKING_DIRECTORY + "/trainings", training_id)
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
def cache_for_assistant_path(assistant_id: str) -> str:
|
|
59
59
|
"""Return the path to the cache directory for a given assistant id."""
|
|
60
|
-
return subpath(config.SERVER_BASE_WORKING_DIRECTORY
|
|
60
|
+
return subpath(config.SERVER_BASE_WORKING_DIRECTORY + "/caches", assistant_id)
|
|
61
61
|
|
|
62
62
|
|
|
63
63
|
def terminate_training(training: TrainingSession) -> None:
|
|
@@ -132,8 +132,8 @@ def move_model_to_local_storage(training: TrainingSession) -> None:
|
|
|
132
132
|
ensure_base_directory_exists(models_base_path())
|
|
133
133
|
|
|
134
134
|
model_path = subpath(
|
|
135
|
-
train_path(training.training_id),
|
|
136
|
-
f"
|
|
135
|
+
train_path(training.training_id) + "/models",
|
|
136
|
+
f"{training.model_name}.{MODEL_ARCHIVE_EXTENSION}",
|
|
137
137
|
)
|
|
138
138
|
|
|
139
139
|
if os.path.exists(model_path):
|
|
@@ -217,9 +217,11 @@ def write_training_data_to_files(
|
|
|
217
217
|
}
|
|
218
218
|
|
|
219
219
|
for key, file in data_to_be_written_to_files.items():
|
|
220
|
+
parent_path, file_name = os.path.split(file)
|
|
221
|
+
|
|
220
222
|
write_encoded_data_to_file(
|
|
221
223
|
encoded_training_data.get(key, ""),
|
|
222
|
-
subpath(training_base_path,
|
|
224
|
+
subpath(training_base_path + "/" + parent_path, file_name),
|
|
223
225
|
)
|
|
224
226
|
|
|
225
227
|
|
rasa/model_manager/utils.py
CHANGED
|
@@ -2,7 +2,16 @@ import os
|
|
|
2
2
|
import base64
|
|
3
3
|
from typing import Optional
|
|
4
4
|
|
|
5
|
+
import structlog
|
|
6
|
+
|
|
5
7
|
from rasa.model_manager import config
|
|
8
|
+
from rasa.shared.exceptions import RasaException
|
|
9
|
+
|
|
10
|
+
structlogger = structlog.get_logger()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class InvalidPathException(RasaException):
|
|
14
|
+
"""Raised if a path is invalid - e.g. path traversal is detected."""
|
|
6
15
|
|
|
7
16
|
|
|
8
17
|
def write_encoded_data_to_file(encoded_data: bytes, file: str) -> None:
|
|
@@ -17,7 +26,7 @@ def write_encoded_data_to_file(encoded_data: bytes, file: str) -> None:
|
|
|
17
26
|
|
|
18
27
|
|
|
19
28
|
def logs_base_path() -> str:
|
|
20
|
-
"""Return the path to the logs directory."""
|
|
29
|
+
"""Return the path to the logs' directory."""
|
|
21
30
|
return subpath(config.SERVER_BASE_WORKING_DIRECTORY, "logs")
|
|
22
31
|
|
|
23
32
|
|
|
@@ -31,7 +40,7 @@ def ensure_base_directory_exists(directory: str) -> None:
|
|
|
31
40
|
|
|
32
41
|
|
|
33
42
|
def models_base_path() -> str:
|
|
34
|
-
"""Return the path to the models directory."""
|
|
43
|
+
"""Return the path to the models' directory."""
|
|
35
44
|
return subpath(config.SERVER_BASE_WORKING_DIRECTORY, "models")
|
|
36
45
|
|
|
37
46
|
|
|
@@ -48,13 +57,24 @@ def subpath(parent: str, child: str) -> str:
|
|
|
48
57
|
"""Return the path to the child directory of the parent directory.
|
|
49
58
|
|
|
50
59
|
Ensures, that child doesn't navigate to parent directories. Prevents
|
|
51
|
-
path traversal.
|
|
60
|
+
path traversal. Raises an InvalidPathException if the path is invalid.
|
|
61
|
+
|
|
62
|
+
Based on Snyk's directory traversal mitigation:
|
|
63
|
+
https://learn.snyk.io/lesson/directory-traversal/
|
|
52
64
|
"""
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
65
|
+
safe_path = os.path.abspath(os.path.join(parent, child))
|
|
66
|
+
parent = os.path.abspath(parent)
|
|
67
|
+
|
|
68
|
+
common_base = os.path.commonpath([parent, safe_path])
|
|
69
|
+
if common_base != parent:
|
|
70
|
+
raise InvalidPathException(f"Invalid path: {safe_path}")
|
|
71
|
+
|
|
72
|
+
if os.path.basename(safe_path) != child:
|
|
73
|
+
raise InvalidPathException(
|
|
74
|
+
f"Invalid path - path traversal detected: {safe_path}"
|
|
75
|
+
)
|
|
56
76
|
|
|
57
|
-
return
|
|
77
|
+
return safe_path
|
|
58
78
|
|
|
59
79
|
|
|
60
80
|
def get_logs_content(action_id: str) -> Optional[str]:
|
|
@@ -63,4 +83,5 @@ def get_logs_content(action_id: str) -> Optional[str]:
|
|
|
63
83
|
with open(logs_path(action_id), "r") as file:
|
|
64
84
|
return file.read()
|
|
65
85
|
except FileNotFoundError:
|
|
86
|
+
structlogger.debug("model_service.logs.not_found", action_id=action_id)
|
|
66
87
|
return None
|