rasa-pro 3.10.13a1__py3-none-any.whl → 3.10.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/api.py +1 -1
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +1 -1
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/run.py +1 -1
- rasa/cli/studio/studio.py +18 -8
- rasa/cli/train.py +9 -0
- rasa/cli/x.py +1 -1
- rasa/core/policies/enterprise_search_policy.py +13 -1
- rasa/core/policies/flows/flow_executor.py +18 -8
- rasa/core/policies/intentless_policy.py +13 -1
- rasa/core/processor.py +7 -5
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +11 -1
- rasa/dialogue_understanding/coexistence/llm_based_router.py +8 -1
- rasa/dialogue_understanding/generator/flow_retrieval.py +7 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +8 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +8 -0
- rasa/e2e_test/aggregate_test_stats_calculator.py +11 -1
- rasa/e2e_test/assertions.py +48 -6
- rasa/e2e_test/e2e_test_runner.py +4 -3
- rasa/engine/validation.py +78 -1
- rasa/model_training.py +1 -0
- rasa/shared/constants.py +5 -0
- rasa/shared/core/flows/flows_list.py +5 -1
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +6 -1
- rasa/shared/providers/llm/_base_litellm_client.py +5 -1
- rasa/shared/utils/llm.py +28 -7
- rasa/studio/auth.py +3 -5
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/upload.py +17 -8
- rasa/version.py +1 -1
- {rasa_pro-3.10.13a1.dist-info → rasa_pro-3.10.15.dist-info}/METADATA +2 -2
- {rasa_pro-3.10.13a1.dist-info → rasa_pro-3.10.15.dist-info}/RECORD +41 -41
- {rasa_pro-3.10.13a1.dist-info → rasa_pro-3.10.15.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.13a1.dist-info → rasa_pro-3.10.15.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.13a1.dist-info → rasa_pro-3.10.15.dist-info}/entry_points.txt +0 -0
rasa/e2e_test/assertions.py
CHANGED
|
@@ -452,6 +452,11 @@ class ActionExecutedAssertion(Assertion):
|
|
|
452
452
|
**kwargs: Any,
|
|
453
453
|
) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
|
|
454
454
|
"""Run the action executed assertion on the given events for that user turn."""
|
|
455
|
+
step_index = kwargs.get("step_index")
|
|
456
|
+
original_turn_events, turn_events = _get_turn_events_based_on_step_index(
|
|
457
|
+
step_index, turn_events, prior_events
|
|
458
|
+
)
|
|
459
|
+
|
|
455
460
|
try:
|
|
456
461
|
matching_event = next(
|
|
457
462
|
event
|
|
@@ -464,7 +469,7 @@ class ActionExecutedAssertion(Assertion):
|
|
|
464
469
|
error_message += assertion_order_error_message
|
|
465
470
|
|
|
466
471
|
return self._generate_assertion_failure(
|
|
467
|
-
error_message, prior_events,
|
|
472
|
+
error_message, prior_events, original_turn_events, self.line
|
|
468
473
|
)
|
|
469
474
|
|
|
470
475
|
return None, matching_event
|
|
@@ -519,6 +524,11 @@ class SlotWasSetAssertion(Assertion):
|
|
|
519
524
|
"""Run the slot_was_set assertion on the given events for that user turn."""
|
|
520
525
|
matching_event = None
|
|
521
526
|
|
|
527
|
+
step_index = kwargs.get("step_index")
|
|
528
|
+
original_turn_events, turn_events = _get_turn_events_based_on_step_index(
|
|
529
|
+
step_index, turn_events, prior_events
|
|
530
|
+
)
|
|
531
|
+
|
|
522
532
|
for slot in self.slots:
|
|
523
533
|
matching_events = [
|
|
524
534
|
event
|
|
@@ -557,7 +567,7 @@ class SlotWasSetAssertion(Assertion):
|
|
|
557
567
|
error_message += assertion_order_error_message
|
|
558
568
|
|
|
559
569
|
return self._generate_assertion_failure(
|
|
560
|
-
error_message, prior_events,
|
|
570
|
+
error_message, prior_events, original_turn_events, slot.line
|
|
561
571
|
)
|
|
562
572
|
|
|
563
573
|
return None, matching_event
|
|
@@ -595,6 +605,11 @@ class SlotWasNotSetAssertion(Assertion):
|
|
|
595
605
|
"""Run the slot_was_not_set assertion on the given events for that user turn."""
|
|
596
606
|
matching_event = None
|
|
597
607
|
|
|
608
|
+
step_index = kwargs.get("step_index")
|
|
609
|
+
original_turn_events, turn_events = _get_turn_events_based_on_step_index(
|
|
610
|
+
step_index, turn_events, prior_events
|
|
611
|
+
)
|
|
612
|
+
|
|
598
613
|
for slot in self.slots:
|
|
599
614
|
matching_events = [
|
|
600
615
|
event
|
|
@@ -630,7 +645,7 @@ class SlotWasNotSetAssertion(Assertion):
|
|
|
630
645
|
error_message += assertion_order_error_message
|
|
631
646
|
|
|
632
647
|
return self._generate_assertion_failure(
|
|
633
|
-
error_message, prior_events,
|
|
648
|
+
error_message, prior_events, original_turn_events, slot.line
|
|
634
649
|
)
|
|
635
650
|
|
|
636
651
|
return None, matching_event
|
|
@@ -723,6 +738,11 @@ class BotUtteredAssertion(Assertion):
|
|
|
723
738
|
"""Run the bot_uttered assertion on the given events for that user turn."""
|
|
724
739
|
matching_event = None
|
|
725
740
|
|
|
741
|
+
step_index = kwargs.get("step_index")
|
|
742
|
+
original_turn_events, turn_events = _get_turn_events_based_on_step_index(
|
|
743
|
+
step_index, turn_events, prior_events
|
|
744
|
+
)
|
|
745
|
+
|
|
726
746
|
if self.utter_name is not None:
|
|
727
747
|
try:
|
|
728
748
|
matching_event = next(
|
|
@@ -736,7 +756,7 @@ class BotUtteredAssertion(Assertion):
|
|
|
736
756
|
error_message += assertion_order_error_message
|
|
737
757
|
|
|
738
758
|
return self._generate_assertion_failure(
|
|
739
|
-
error_message, prior_events,
|
|
759
|
+
error_message, prior_events, original_turn_events, self.line
|
|
740
760
|
)
|
|
741
761
|
|
|
742
762
|
if self.text_matches is not None:
|
|
@@ -756,7 +776,7 @@ class BotUtteredAssertion(Assertion):
|
|
|
756
776
|
error_message += assertion_order_error_message
|
|
757
777
|
|
|
758
778
|
return self._generate_assertion_failure(
|
|
759
|
-
error_message, prior_events,
|
|
779
|
+
error_message, prior_events, original_turn_events, self.line
|
|
760
780
|
)
|
|
761
781
|
|
|
762
782
|
if self.buttons:
|
|
@@ -772,7 +792,7 @@ class BotUtteredAssertion(Assertion):
|
|
|
772
792
|
)
|
|
773
793
|
error_message += assertion_order_error_message
|
|
774
794
|
return self._generate_assertion_failure(
|
|
775
|
-
error_message, prior_events,
|
|
795
|
+
error_message, prior_events, original_turn_events, self.line
|
|
776
796
|
)
|
|
777
797
|
|
|
778
798
|
return None, matching_event
|
|
@@ -1179,3 +1199,25 @@ def _find_matching_generative_events(turn_events: List[Event]) -> List[BotUttere
|
|
|
1179
1199
|
and event.metadata.get(UTTER_SOURCE_METADATA_KEY)
|
|
1180
1200
|
in ELIGIBLE_UTTER_SOURCE_METADATA
|
|
1181
1201
|
]
|
|
1202
|
+
|
|
1203
|
+
|
|
1204
|
+
def _get_turn_events_based_on_step_index(
|
|
1205
|
+
step_index: int, turn_events: List[Event], prior_events: List[Event]
|
|
1206
|
+
) -> Tuple[List[Event], List[Event]]:
|
|
1207
|
+
"""Get the turn events based on the step index.
|
|
1208
|
+
|
|
1209
|
+
For the first step, we need to include the prior events as well
|
|
1210
|
+
in the same user turn. For the subsequent steps, we only need the
|
|
1211
|
+
events that follow the user uttered event on which the tracker
|
|
1212
|
+
was originally sliced by.
|
|
1213
|
+
|
|
1214
|
+
Returns:
|
|
1215
|
+
List[Event]: The copy of turn_events
|
|
1216
|
+
List[Event]: The turn events based on the step index
|
|
1217
|
+
|
|
1218
|
+
"""
|
|
1219
|
+
original_turn_events = turn_events[:]
|
|
1220
|
+
if step_index == 0:
|
|
1221
|
+
return original_turn_events, prior_events + turn_events
|
|
1222
|
+
|
|
1223
|
+
return original_turn_events, turn_events
|
rasa/e2e_test/e2e_test_runner.py
CHANGED
|
@@ -136,7 +136,7 @@ class E2ETestRunner:
|
|
|
136
136
|
return turns
|
|
137
137
|
|
|
138
138
|
tracker = await self.agent.processor.fetch_tracker_with_initial_session(
|
|
139
|
-
sender_id
|
|
139
|
+
sender_id, output_channel=collector
|
|
140
140
|
)
|
|
141
141
|
# turn -1 i used to contain events that happen during
|
|
142
142
|
# the start of the session and before the first user message
|
|
@@ -442,7 +442,7 @@ class E2ETestRunner:
|
|
|
442
442
|
assertion_failure_found = False
|
|
443
443
|
input_metadata = input_metadata if input_metadata else []
|
|
444
444
|
|
|
445
|
-
for step in test_case.steps:
|
|
445
|
+
for index, step in enumerate(test_case.steps):
|
|
446
446
|
if not step.assertions:
|
|
447
447
|
structlogger.debug(
|
|
448
448
|
"e2e_test_runner.run_assertions.no_assertions.skipping_step",
|
|
@@ -490,6 +490,7 @@ class E2ETestRunner:
|
|
|
490
490
|
assertion_order_error_message=assertion_order_error_msg,
|
|
491
491
|
llm_judge_config=self.llm_judge_config,
|
|
492
492
|
step_text=step.text,
|
|
493
|
+
step_index=index,
|
|
493
494
|
)
|
|
494
495
|
|
|
495
496
|
if assertion_failure:
|
|
@@ -826,7 +827,7 @@ class E2ETestRunner:
|
|
|
826
827
|
return
|
|
827
828
|
|
|
828
829
|
tracker = await self.agent.processor.fetch_tracker_with_initial_session(
|
|
829
|
-
sender_id
|
|
830
|
+
sender_id, output_channel=CollectingOutputChannel()
|
|
830
831
|
)
|
|
831
832
|
|
|
832
833
|
for fixture in fixtures:
|
rasa/engine/validation.py
CHANGED
|
@@ -16,6 +16,7 @@ from typing import (
|
|
|
16
16
|
Union,
|
|
17
17
|
TypeVar,
|
|
18
18
|
List,
|
|
19
|
+
Literal,
|
|
19
20
|
)
|
|
20
21
|
|
|
21
22
|
import structlog
|
|
@@ -34,6 +35,7 @@ from rasa.dialogue_understanding.coexistence.constants import (
|
|
|
34
35
|
from rasa.dialogue_understanding.generator import (
|
|
35
36
|
LLMBasedCommandGenerator,
|
|
36
37
|
)
|
|
38
|
+
from rasa.dialogue_understanding.generator.constants import FLOW_RETRIEVAL_KEY
|
|
37
39
|
from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
|
|
38
40
|
from rasa.engine.constants import RESERVED_PLACEHOLDERS
|
|
39
41
|
from rasa.engine.exceptions import GraphSchemaValidationException
|
|
@@ -47,7 +49,15 @@ from rasa.engine.graph import (
|
|
|
47
49
|
from rasa.engine.storage.resource import Resource
|
|
48
50
|
from rasa.engine.storage.storage import ModelStorage
|
|
49
51
|
from rasa.engine.training.fingerprinting import Fingerprintable
|
|
50
|
-
from rasa.shared.constants import
|
|
52
|
+
from rasa.shared.constants import (
|
|
53
|
+
DOCS_URL_GRAPH_COMPONENTS,
|
|
54
|
+
ROUTE_TO_CALM_SLOT,
|
|
55
|
+
API_TYPE_CONFIG_KEY,
|
|
56
|
+
VALID_PROVIDERS_FOR_API_TYPE_CONFIG_KEY,
|
|
57
|
+
PROVIDER_CONFIG_KEY,
|
|
58
|
+
LLM_CONFIG_KEY,
|
|
59
|
+
EMBEDDINGS_CONFIG_KEY,
|
|
60
|
+
)
|
|
51
61
|
from rasa.shared.core.constants import ACTION_RESET_ROUTING, ACTION_TRIGGER_CHITCHAT
|
|
52
62
|
from rasa.shared.core.domain import Domain
|
|
53
63
|
from rasa.shared.core.flows import FlowsList, Flow
|
|
@@ -871,3 +881,70 @@ def validate_command_generator_setup(
|
|
|
871
881
|
) -> None:
|
|
872
882
|
schema = model_configuration.predict_schema
|
|
873
883
|
validate_command_generator_exclusivity(schema)
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
def validate_model_client_configuration_setup(config: Dict[str, Any]) -> None:
|
|
887
|
+
"""Validates the model client configuration setup.
|
|
888
|
+
|
|
889
|
+
Validation fails, if
|
|
890
|
+
- the LLM/embeddings provider is defined using 'api_type' key for providers other
|
|
891
|
+
than 'openai' or 'azure'
|
|
892
|
+
|
|
893
|
+
Args:
|
|
894
|
+
config: The config dictionary
|
|
895
|
+
"""
|
|
896
|
+
for outer_key in ["pipeline", "policies"]:
|
|
897
|
+
if outer_key not in config or config[outer_key] is None:
|
|
898
|
+
continue
|
|
899
|
+
|
|
900
|
+
for component_config in config[outer_key]:
|
|
901
|
+
for key in [LLM_CONFIG_KEY, EMBEDDINGS_CONFIG_KEY]:
|
|
902
|
+
validate_api_type_config_key_usage(component_config, key)
|
|
903
|
+
|
|
904
|
+
# as flow retrieval is not a component itself, we need to
|
|
905
|
+
# check it separately
|
|
906
|
+
if (
|
|
907
|
+
FLOW_RETRIEVAL_KEY in component_config
|
|
908
|
+
and EMBEDDINGS_CONFIG_KEY in component_config[FLOW_RETRIEVAL_KEY]
|
|
909
|
+
):
|
|
910
|
+
validate_api_type_config_key_usage(
|
|
911
|
+
component_config[FLOW_RETRIEVAL_KEY],
|
|
912
|
+
EMBEDDINGS_CONFIG_KEY,
|
|
913
|
+
component_config["name"] + "." + FLOW_RETRIEVAL_KEY,
|
|
914
|
+
)
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
def validate_api_type_config_key_usage(
|
|
918
|
+
component_config: Dict[str, Any],
|
|
919
|
+
key: Literal["llm", "embeddings"],
|
|
920
|
+
component_name: Optional[str] = None,
|
|
921
|
+
) -> None:
|
|
922
|
+
"""Validate the LLM/embeddings configuration of a component.
|
|
923
|
+
|
|
924
|
+
Validation fails, if
|
|
925
|
+
- the LLM/embeddings provider is defined using 'api_type' key for providers other
|
|
926
|
+
than 'openai' or 'azure'
|
|
927
|
+
|
|
928
|
+
Args:
|
|
929
|
+
component_config: The config of the component
|
|
930
|
+
key: either 'llm' or 'embeddings'
|
|
931
|
+
component_name: the name of the component
|
|
932
|
+
"""
|
|
933
|
+
if component_config is None or key not in component_config:
|
|
934
|
+
return
|
|
935
|
+
|
|
936
|
+
if API_TYPE_CONFIG_KEY in component_config[key]:
|
|
937
|
+
api_type = component_config[key][API_TYPE_CONFIG_KEY]
|
|
938
|
+
if api_type not in VALID_PROVIDERS_FOR_API_TYPE_CONFIG_KEY:
|
|
939
|
+
structlogger.error(
|
|
940
|
+
"validation.component.api_type_config_key_invalid",
|
|
941
|
+
event_info=(
|
|
942
|
+
f"You specified '{API_TYPE_CONFIG_KEY}: {api_type}' for "
|
|
943
|
+
f"'{component_name or component_config['name']}', which is not "
|
|
944
|
+
f"allowed. "
|
|
945
|
+
f"The '{API_TYPE_CONFIG_KEY}' key can only be used for the "
|
|
946
|
+
f"following providers: {VALID_PROVIDERS_FOR_API_TYPE_CONFIG_KEY}. "
|
|
947
|
+
f"For other providers, please use the '{PROVIDER_CONFIG_KEY}' key."
|
|
948
|
+
),
|
|
949
|
+
)
|
|
950
|
+
sys.exit(1)
|
rasa/model_training.py
CHANGED
|
@@ -312,6 +312,7 @@ async def _train_graph(
|
|
|
312
312
|
rasa.engine.validation.validate_coexistance_routing_setup(
|
|
313
313
|
domain, model_configuration, flows
|
|
314
314
|
)
|
|
315
|
+
rasa.engine.validation.validate_model_client_configuration_setup(config)
|
|
315
316
|
rasa.engine.validation.validate_flow_component_dependencies(
|
|
316
317
|
flows, model_configuration
|
|
317
318
|
)
|
rasa/shared/constants.py
CHANGED
|
@@ -213,6 +213,11 @@ AZURE_OPENAI_PROVIDER = "azure"
|
|
|
213
213
|
SELF_HOSTED_PROVIDER = "self-hosted"
|
|
214
214
|
HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER = "huggingface_local"
|
|
215
215
|
|
|
216
|
+
VALID_PROVIDERS_FOR_API_TYPE_CONFIG_KEY = [
|
|
217
|
+
OPENAI_PROVIDER,
|
|
218
|
+
AZURE_OPENAI_PROVIDER,
|
|
219
|
+
]
|
|
220
|
+
|
|
216
221
|
SELF_HOSTED_VLLM_PREFIX = "hosted_vllm"
|
|
217
222
|
SELF_HOSTED_VLLM_API_KEY_ENV_VAR = "HOSTED_VLLM_API_KEY"
|
|
218
223
|
|
|
@@ -221,12 +221,16 @@ class FlowsList:
|
|
|
221
221
|
[f for f in self.underlying_flows if not f.is_startable_only_via_link()]
|
|
222
222
|
)
|
|
223
223
|
|
|
224
|
-
def available_slot_names(
|
|
224
|
+
def available_slot_names(
|
|
225
|
+
self, ask_before_filling: Optional[bool] = None
|
|
226
|
+
) -> Set[str]:
|
|
225
227
|
"""Get all slot names collected by flows."""
|
|
226
228
|
return {
|
|
227
229
|
step.collect
|
|
228
230
|
for flow in self.underlying_flows
|
|
229
231
|
for step in flow.get_collect_steps()
|
|
232
|
+
if ask_before_filling is None
|
|
233
|
+
or step.ask_before_filling == ask_before_filling
|
|
230
234
|
}
|
|
231
235
|
|
|
232
236
|
def available_custom_actions(self) -> Set[str]:
|
|
@@ -5,6 +5,8 @@ import litellm
|
|
|
5
5
|
import logging
|
|
6
6
|
import structlog
|
|
7
7
|
from litellm import aembedding, embedding, validate_environment
|
|
8
|
+
|
|
9
|
+
from rasa.shared.constants import API_BASE_CONFIG_KEY
|
|
8
10
|
from rasa.shared.exceptions import (
|
|
9
11
|
ProviderClientAPIException,
|
|
10
12
|
ProviderClientValidationError,
|
|
@@ -85,7 +87,10 @@ class _BaseLiteLLMEmbeddingClient:
|
|
|
85
87
|
|
|
86
88
|
def _validate_environment_variables(self) -> None:
|
|
87
89
|
"""Validate that the required environment variables are set."""
|
|
88
|
-
validation_info = validate_environment(
|
|
90
|
+
validation_info = validate_environment(
|
|
91
|
+
self._litellm_model_name,
|
|
92
|
+
api_base=self._litellm_extra_parameters.get(API_BASE_CONFIG_KEY),
|
|
93
|
+
)
|
|
89
94
|
if missing_environment_variables := validation_info.get(
|
|
90
95
|
_VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY
|
|
91
96
|
):
|
|
@@ -9,6 +9,7 @@ from litellm import (
|
|
|
9
9
|
validate_environment,
|
|
10
10
|
)
|
|
11
11
|
|
|
12
|
+
from rasa.shared.constants import API_BASE_CONFIG_KEY
|
|
12
13
|
from rasa.shared.exceptions import (
|
|
13
14
|
ProviderClientAPIException,
|
|
14
15
|
ProviderClientValidationError,
|
|
@@ -102,7 +103,10 @@ class _BaseLiteLLMClient:
|
|
|
102
103
|
|
|
103
104
|
def _validate_environment_variables(self) -> None:
|
|
104
105
|
"""Validate that the required environment variables are set."""
|
|
105
|
-
validation_info = validate_environment(
|
|
106
|
+
validation_info = validate_environment(
|
|
107
|
+
self._litellm_model_name,
|
|
108
|
+
api_base=self._litellm_extra_parameters.get(API_BASE_CONFIG_KEY),
|
|
109
|
+
)
|
|
106
110
|
if missing_environment_variables := validation_info.get(
|
|
107
111
|
_VALIDATE_ENVIRONMENT_MISSING_KEYS_KEY
|
|
108
112
|
):
|
rasa/shared/utils/llm.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import sys
|
|
1
2
|
from functools import wraps
|
|
2
3
|
from typing import (
|
|
3
4
|
Any,
|
|
@@ -52,7 +53,6 @@ from rasa.shared.providers.mappings import (
|
|
|
52
53
|
HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
|
|
53
54
|
get_client_config_class_from_provider,
|
|
54
55
|
)
|
|
55
|
-
from rasa.shared.utils.cli import print_error_and_exit
|
|
56
56
|
|
|
57
57
|
if TYPE_CHECKING:
|
|
58
58
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
@@ -418,12 +418,33 @@ def try_instantiate_llm_client(
|
|
|
418
418
|
except (ProviderClientValidationError, ValueError) as e:
|
|
419
419
|
structlogger.error(
|
|
420
420
|
f"{log_source_function}.llm_instantiation_failed",
|
|
421
|
-
|
|
421
|
+
event_info=(
|
|
422
|
+
f"Unable to create the LLM client for component - "
|
|
423
|
+
f"{log_source_component}. Please make sure you specified the required "
|
|
424
|
+
f"environment variables and configuration keys."
|
|
425
|
+
),
|
|
422
426
|
error=e,
|
|
423
427
|
)
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
428
|
+
sys.exit(1)
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def try_instantiate_embedder(
|
|
432
|
+
custom_embeddings_config: Optional[Dict],
|
|
433
|
+
default_embeddings_config: Optional[Dict],
|
|
434
|
+
log_source_function: str,
|
|
435
|
+
log_source_component: str,
|
|
436
|
+
) -> EmbeddingClient:
|
|
437
|
+
"""Validate embeddings configuration."""
|
|
438
|
+
try:
|
|
439
|
+
return embedder_factory(custom_embeddings_config, default_embeddings_config)
|
|
440
|
+
except (ProviderClientValidationError, ValueError) as e:
|
|
441
|
+
structlogger.error(
|
|
442
|
+
f"{log_source_function}.embedder_instantiation_failed",
|
|
443
|
+
event_info=(
|
|
444
|
+
f"Unable to create the Embedding client for component - "
|
|
445
|
+
f"{log_source_component}. Please make sure you specified the required "
|
|
446
|
+
f"environment variables and configuration keys."
|
|
447
|
+
),
|
|
448
|
+
error=e,
|
|
429
449
|
)
|
|
450
|
+
sys.exit(1)
|
rasa/studio/auth.py
CHANGED
|
@@ -23,12 +23,10 @@ from rasa.studio.results_logger import with_studio_error_handler, StudioResult
|
|
|
23
23
|
class StudioAuth:
|
|
24
24
|
"""Handles the authentication with the Rasa Studio authentication server."""
|
|
25
25
|
|
|
26
|
-
def __init__(
|
|
27
|
-
self,
|
|
28
|
-
studio_config: StudioConfig,
|
|
29
|
-
verify: bool = True,
|
|
30
|
-
) -> None:
|
|
26
|
+
def __init__(self, studio_config: StudioConfig) -> None:
|
|
31
27
|
self.config = studio_config
|
|
28
|
+
verify = not studio_config.disable_verify
|
|
29
|
+
|
|
32
30
|
self.keycloak_openid = KeycloakOpenID(
|
|
33
31
|
server_url=studio_config.authentication_server_url,
|
|
34
32
|
client_id=studio_config.client_id,
|
rasa/studio/config.py
CHANGED
|
@@ -2,13 +2,14 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
from dataclasses import dataclass
|
|
5
|
-
from typing import Dict, Optional, Text
|
|
5
|
+
from typing import Any, Dict, Optional, Text
|
|
6
6
|
|
|
7
7
|
from rasa.utils.common import read_global_config_value, write_global_config_value
|
|
8
8
|
|
|
9
9
|
from rasa.studio.constants import (
|
|
10
10
|
RASA_STUDIO_AUTH_SERVER_URL_ENV,
|
|
11
11
|
RASA_STUDIO_CLI_CLIENT_ID_KEY_ENV,
|
|
12
|
+
RASA_STUDIO_CLI_DISABLE_VERIFY_KEY_ENV,
|
|
12
13
|
RASA_STUDIO_CLI_REALM_NAME_KEY_ENV,
|
|
13
14
|
RASA_STUDIO_CLI_STUDIO_URL_ENV,
|
|
14
15
|
STUDIO_CONFIG_KEY,
|
|
@@ -19,6 +20,7 @@ STUDIO_URL_KEY = "studio_url"
|
|
|
19
20
|
CLIENT_ID_KEY = "client_id"
|
|
20
21
|
REALM_NAME_KEY = "realm_name"
|
|
21
22
|
CLIENT_SECRET_KEY = "client_secret"
|
|
23
|
+
DISABLE_VERIFY = "disable_verify"
|
|
22
24
|
|
|
23
25
|
|
|
24
26
|
@dataclass
|
|
@@ -27,13 +29,15 @@ class StudioConfig:
|
|
|
27
29
|
studio_url: Optional[Text]
|
|
28
30
|
client_id: Optional[Text]
|
|
29
31
|
realm_name: Optional[Text]
|
|
32
|
+
disable_verify: bool = False
|
|
30
33
|
|
|
31
|
-
def to_dict(self) -> Dict[Text, Optional[
|
|
34
|
+
def to_dict(self) -> Dict[Text, Optional[Any]]:
|
|
32
35
|
return {
|
|
33
36
|
AUTH_SERVER_URL_KEY: self.authentication_server_url,
|
|
34
37
|
STUDIO_URL_KEY: self.studio_url,
|
|
35
38
|
CLIENT_ID_KEY: self.client_id,
|
|
36
39
|
REALM_NAME_KEY: self.realm_name,
|
|
40
|
+
DISABLE_VERIFY: self.disable_verify,
|
|
37
41
|
}
|
|
38
42
|
|
|
39
43
|
@classmethod
|
|
@@ -43,6 +47,7 @@ class StudioConfig:
|
|
|
43
47
|
studio_url=data[STUDIO_URL_KEY],
|
|
44
48
|
client_id=data[CLIENT_ID_KEY],
|
|
45
49
|
realm_name=data[REALM_NAME_KEY],
|
|
50
|
+
disable_verify=data.get(DISABLE_VERIFY, False),
|
|
46
51
|
)
|
|
47
52
|
|
|
48
53
|
def write_config(self) -> None:
|
|
@@ -73,7 +78,7 @@ class StudioConfig:
|
|
|
73
78
|
config = read_global_config_value(STUDIO_CONFIG_KEY, unavailable_ok=True)
|
|
74
79
|
|
|
75
80
|
if config is None:
|
|
76
|
-
return StudioConfig(None, None, None, None)
|
|
81
|
+
return StudioConfig(None, None, None, None, False)
|
|
77
82
|
|
|
78
83
|
if not isinstance(config, dict):
|
|
79
84
|
raise ValueError(
|
|
@@ -83,7 +88,7 @@ class StudioConfig:
|
|
|
83
88
|
)
|
|
84
89
|
|
|
85
90
|
for key in config:
|
|
86
|
-
if not isinstance(config[key], str):
|
|
91
|
+
if not isinstance(config[key], str) and key != DISABLE_VERIFY:
|
|
87
92
|
raise ValueError(
|
|
88
93
|
"Invalid config file format. "
|
|
89
94
|
f"Key '{key}' is not a text value."
|
|
@@ -102,6 +107,9 @@ class StudioConfig:
|
|
|
102
107
|
studio_url=StudioConfig._read_env_value(RASA_STUDIO_CLI_STUDIO_URL_ENV),
|
|
103
108
|
client_id=StudioConfig._read_env_value(RASA_STUDIO_CLI_CLIENT_ID_KEY_ENV),
|
|
104
109
|
realm_name=StudioConfig._read_env_value(RASA_STUDIO_CLI_REALM_NAME_KEY_ENV),
|
|
110
|
+
disable_verify=bool(
|
|
111
|
+
os.getenv(RASA_STUDIO_CLI_DISABLE_VERIFY_KEY_ENV, False)
|
|
112
|
+
),
|
|
105
113
|
)
|
|
106
114
|
|
|
107
115
|
@staticmethod
|
|
@@ -124,4 +132,5 @@ class StudioConfig:
|
|
|
124
132
|
studio_url=self.studio_url or other.studio_url,
|
|
125
133
|
client_id=self.client_id or other.client_id,
|
|
126
134
|
realm_name=self.realm_name or other.realm_name,
|
|
135
|
+
disable_verify=self.disable_verify or other.disable_verify,
|
|
127
136
|
)
|
rasa/studio/constants.py
CHANGED
|
@@ -10,6 +10,7 @@ RASA_STUDIO_AUTH_SERVER_URL_ENV = "RASA_STUDIO_AUTH_SERVER_URL"
|
|
|
10
10
|
RASA_STUDIO_CLI_STUDIO_URL_ENV = "RASA_STUDIO_CLI_STUDIO_URL"
|
|
11
11
|
RASA_STUDIO_CLI_REALM_NAME_KEY_ENV = "RASA_STUDIO_CLI_REALM_NAME_KEY"
|
|
12
12
|
RASA_STUDIO_CLI_CLIENT_ID_KEY_ENV = "RASA_STUDIO_CLI_CLIENT_ID_KEY"
|
|
13
|
+
RASA_STUDIO_CLI_DISABLE_VERIFY_KEY_ENV = "RASA_STUDIO_CLI_DISABLE_VERIFY_KEY"
|
|
13
14
|
|
|
14
15
|
STUDIO_NLU_FILENAME = "studio_nlu.yml"
|
|
15
16
|
STUDIO_DOMAIN_FILENAME = "studio_domain.yml"
|
rasa/studio/data_handler.py
CHANGED
|
@@ -76,7 +76,9 @@ class StudioDataHandler:
|
|
|
76
76
|
|
|
77
77
|
return request
|
|
78
78
|
|
|
79
|
-
def _make_request(
|
|
79
|
+
def _make_request(
|
|
80
|
+
self, GQL_req: Dict[Any, Any], verify: bool = True
|
|
81
|
+
) -> Dict[Any, Any]:
|
|
80
82
|
token = KeycloakTokenReader().get_token()
|
|
81
83
|
if token.is_expired():
|
|
82
84
|
token = self.refresh_token(token)
|
|
@@ -93,6 +95,7 @@ class StudioDataHandler:
|
|
|
93
95
|
"Authorization": f"{token.token_type} {token.access_token}",
|
|
94
96
|
"Content-Type": "application/json",
|
|
95
97
|
},
|
|
98
|
+
verify=verify,
|
|
96
99
|
)
|
|
97
100
|
|
|
98
101
|
if res.status_code != 200:
|
|
@@ -128,7 +131,9 @@ class StudioDataHandler:
|
|
|
128
131
|
The data from Rasa Studio.
|
|
129
132
|
"""
|
|
130
133
|
GQL_req = self._build_request()
|
|
131
|
-
|
|
134
|
+
verify = not self.studio_config.disable_verify
|
|
135
|
+
|
|
136
|
+
response = self._make_request(GQL_req, verify=verify)
|
|
132
137
|
self._extract_data(response)
|
|
133
138
|
|
|
134
139
|
def request_data(
|
|
@@ -145,7 +150,9 @@ class StudioDataHandler:
|
|
|
145
150
|
The data from Rasa Studio.
|
|
146
151
|
"""
|
|
147
152
|
GQL_req = self._build_request(intent_names, entity_names)
|
|
148
|
-
|
|
153
|
+
verify = not self.studio_config.disable_verify
|
|
154
|
+
|
|
155
|
+
response = self._make_request(GQL_req, verify=verify)
|
|
149
156
|
self._extract_data(response)
|
|
150
157
|
|
|
151
158
|
def get_config(self) -> Optional[str]:
|
rasa/studio/upload.py
CHANGED
|
@@ -56,7 +56,10 @@ def _get_selected_entities_and_intents(
|
|
|
56
56
|
|
|
57
57
|
def handle_upload(args: argparse.Namespace) -> None:
|
|
58
58
|
"""Uploads primitives to rasa studio."""
|
|
59
|
-
|
|
59
|
+
studio_config = StudioConfig.read_config()
|
|
60
|
+
endpoint = studio_config.studio_url
|
|
61
|
+
verify = not studio_config.disable_verify
|
|
62
|
+
|
|
60
63
|
if not endpoint:
|
|
61
64
|
rasa.shared.utils.cli.print_error_and_exit(
|
|
62
65
|
"No GraphQL endpoint found in config. Please run `rasa studio config`."
|
|
@@ -76,9 +79,9 @@ def handle_upload(args: argparse.Namespace) -> None:
|
|
|
76
79
|
|
|
77
80
|
# check safely if args.calm is set and not fail if not
|
|
78
81
|
if hasattr(args, "calm") and args.calm:
|
|
79
|
-
upload_calm_assistant(args, endpoint)
|
|
82
|
+
upload_calm_assistant(args, endpoint, verify=verify)
|
|
80
83
|
else:
|
|
81
|
-
upload_nlu_assistant(args, endpoint)
|
|
84
|
+
upload_nlu_assistant(args, endpoint, verify=verify)
|
|
82
85
|
|
|
83
86
|
|
|
84
87
|
config_keys = [
|
|
@@ -126,7 +129,9 @@ def _get_assistant_name(config: Dict[Text, Any]) -> str:
|
|
|
126
129
|
|
|
127
130
|
|
|
128
131
|
@with_studio_error_handler
|
|
129
|
-
def upload_calm_assistant(
|
|
132
|
+
def upload_calm_assistant(
|
|
133
|
+
args: argparse.Namespace, endpoint: str, verify: bool = True
|
|
134
|
+
) -> StudioResult:
|
|
130
135
|
"""Uploads the CALM assistant data to Rasa Studio.
|
|
131
136
|
|
|
132
137
|
Args:
|
|
@@ -216,11 +221,13 @@ def upload_calm_assistant(args: argparse.Namespace, endpoint: str) -> StudioResu
|
|
|
216
221
|
)
|
|
217
222
|
|
|
218
223
|
structlogger.info("Uploading to Rasa Studio...")
|
|
219
|
-
return make_request(endpoint, graphql_req)
|
|
224
|
+
return make_request(endpoint, graphql_req, verify)
|
|
220
225
|
|
|
221
226
|
|
|
222
227
|
@with_studio_error_handler
|
|
223
|
-
def upload_nlu_assistant(
|
|
228
|
+
def upload_nlu_assistant(
|
|
229
|
+
args: argparse.Namespace, endpoint: str, verify: bool = True
|
|
230
|
+
) -> StudioResult:
|
|
224
231
|
"""Uploads the classic (dm1) assistant data to Rasa Studio.
|
|
225
232
|
|
|
226
233
|
Args:
|
|
@@ -268,15 +275,16 @@ def upload_nlu_assistant(args: argparse.Namespace, endpoint: str) -> StudioResul
|
|
|
268
275
|
graphql_req = build_request(assistant_name, nlu_examples_yaml, domain_yaml)
|
|
269
276
|
|
|
270
277
|
structlogger.info("Uploading to Rasa Studio...")
|
|
271
|
-
return make_request(endpoint, graphql_req)
|
|
278
|
+
return make_request(endpoint, graphql_req, verify)
|
|
272
279
|
|
|
273
280
|
|
|
274
|
-
def make_request(endpoint: str, graphql_req: Dict) -> StudioResult:
|
|
281
|
+
def make_request(endpoint: str, graphql_req: Dict, verify: bool = True) -> StudioResult:
|
|
275
282
|
"""Makes a request to the studio endpoint to upload data.
|
|
276
283
|
|
|
277
284
|
Args:
|
|
278
285
|
endpoint: The studio endpoint
|
|
279
286
|
graphql_req: The graphql request
|
|
287
|
+
verify: Whether to verify SSL
|
|
280
288
|
"""
|
|
281
289
|
token = KeycloakTokenReader().get_token()
|
|
282
290
|
res = requests.post(
|
|
@@ -286,6 +294,7 @@ def make_request(endpoint: str, graphql_req: Dict) -> StudioResult:
|
|
|
286
294
|
"Authorization": f"{token.token_type} {token.access_token}",
|
|
287
295
|
"Content-Type": "application/json",
|
|
288
296
|
},
|
|
297
|
+
verify=verify,
|
|
289
298
|
)
|
|
290
299
|
|
|
291
300
|
if results_logger.response_has_errors(res.json()):
|
rasa/version.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: rasa-pro
|
|
3
|
-
Version: 3.10.
|
|
3
|
+
Version: 3.10.15
|
|
4
4
|
Summary: State-of-the-art open-core Conversational AI framework for Enterprises that natively leverages generative AI for effortless assistant development.
|
|
5
5
|
Home-page: https://rasa.com
|
|
6
6
|
Keywords: nlp,machine-learning,machine-learning-library,bot,bots,botkit,rasa conversational-agents,conversational-ai,chatbot,chatbot-framework,bot-framework
|
|
@@ -68,7 +68,7 @@ Requires-Dist: mattermostwrapper (>=2.2,<2.3)
|
|
|
68
68
|
Requires-Dist: mlflow (>=2.15.1,<3.0.0) ; extra == "mlflow"
|
|
69
69
|
Requires-Dist: networkx (>=3.1,<3.2)
|
|
70
70
|
Requires-Dist: numpy (>=1.23.5,<1.25.0) ; python_version >= "3.9" and python_version < "3.11"
|
|
71
|
-
Requires-Dist: openai (>=1.
|
|
71
|
+
Requires-Dist: openai (>=1.55.3,<1.56.0)
|
|
72
72
|
Requires-Dist: openpyxl (>=3.1.5,<4.0.0)
|
|
73
73
|
Requires-Dist: opentelemetry-api (>=1.16.0,<1.17.0)
|
|
74
74
|
Requires-Dist: opentelemetry-exporter-jaeger (>=1.16.0,<1.17.0)
|