fastworkflow 2.15.8__py3-none-any.whl → 2.15.10__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fastworkflow might be problematic. Click here for more details.

@@ -1,643 +1,9 @@
1
- import contextlib
2
- from enum import Enum
3
- import sys
4
- from typing import Dict, List, Optional, Type, Union
5
- import json
6
- import os
7
- from collections import Counter
8
- from concurrent.futures import ThreadPoolExecutor, as_completed
9
-
10
- from pydantic import BaseModel
11
- from pydantic_core import PydanticUndefined
12
- from speedict import Rdict
13
-
14
1
  import fastworkflow
15
- from fastworkflow.utils.logging import logger
16
- from fastworkflow import Action, CommandOutput, CommandResponse, ModuleType, NLUPipelineStage
17
- from fastworkflow.cache_matching import cache_match, store_utterance_cache
2
+ from fastworkflow import Action, CommandOutput, CommandResponse, NLUPipelineStage
18
3
  from fastworkflow.command_executor import CommandExecutor
19
- from fastworkflow.command_routing import RoutingDefinition
20
- import fastworkflow.command_routing
21
- from fastworkflow.model_pipeline_training import (
22
- predict_single_sentence,
23
- get_artifact_path,
24
- CommandRouter
25
- )
26
-
27
- from fastworkflow.train.generate_synthetic import generate_diverse_utterances
28
- from fastworkflow.utils.fuzzy_match import find_best_matches
29
- from fastworkflow.utils.signatures import InputForParamExtraction
30
-
31
-
32
- INVALID_INT_VALUE = -sys.maxsize
33
- INVALID_FLOAT_VALUE = -sys.float_info.max
34
-
35
- MISSING_INFORMATION_ERRMSG = fastworkflow.get_env_var("MISSING_INFORMATION_ERRMSG")
36
- INVALID_INFORMATION_ERRMSG = fastworkflow.get_env_var("INVALID_INFORMATION_ERRMSG")
37
-
38
- NOT_FOUND = fastworkflow.get_env_var("NOT_FOUND")
39
- INVALID = fastworkflow.get_env_var("INVALID")
40
- PARAMETER_EXTRACTION_ERROR_MSG = None
41
-
42
-
43
- # TODO - generation is deterministic. They all return the same answer
44
- # TODO - Need 'temperature' for intent detection pipeline
45
- def majority_vote_predictions(command_router, command: str, n_predictions: int = 5) -> list[str]:
46
- """
47
- Generate N prediction sets in parallel and return the set that wins the majority vote.
48
-
49
- This function improves prediction reliability by running multiple parallel predictions
50
- and selecting the most common result through majority voting. This helps reduce
51
- the impact of random variations in model predictions.
52
-
53
- Args:
54
- command_router: The CommandRouter instance to use for predictions
55
- command: The input command string
56
- n_predictions: Number of parallel predictions to generate (default: 5)
57
- Can be configured via N_PARALLEL_PREDICTIONS environment variable
58
-
59
- Returns:
60
- The prediction set that received the majority vote. Falls back to a single
61
- prediction if all parallel predictions fail.
62
-
63
- Note:
64
- Uses ThreadPoolExecutor with max_workers limited to min(n_predictions, 10)
65
- to avoid overwhelming the system with too many concurrent threads.
66
- """
67
- def get_single_prediction():
68
- """Helper function to get a single prediction"""
69
- return command_router.predict(command)
70
-
71
- # Generate N predictions in parallel
72
- prediction_sets = []
73
- with ThreadPoolExecutor(max_workers=min(n_predictions, 10)) as executor:
74
- # Submit all prediction tasks
75
- futures = [executor.submit(get_single_prediction) for _ in range(n_predictions)]
76
-
77
- # Collect results as they complete
78
- for future in as_completed(futures):
79
- try:
80
- prediction_set = future.result()
81
- prediction_sets.append(prediction_set)
82
- except Exception as e:
83
- logger.warning(f"Prediction failed: {e}")
84
- # Continue with other predictions even if one fails
85
-
86
- if not prediction_sets:
87
- # Fallback to single prediction if all parallel predictions failed
88
- logger.warning("All parallel predictions failed, falling back to single prediction")
89
- return command_router.predict(command)
90
-
91
- # Convert lists to tuples so they can be hashed and counted
92
- prediction_tuples = [tuple(sorted(pred_set)) for pred_set in prediction_sets]
93
-
94
- # Count occurrences of each unique prediction set
95
- vote_counts = Counter(prediction_tuples)
96
-
97
- # Get the prediction set with the most votes
98
- winning_tuple = vote_counts.most_common(1)[0][0]
99
-
100
- # Convert back to list and return
101
- return list(winning_tuple)
102
-
103
-
104
- class CommandNamePrediction:
105
- class Output(BaseModel):
106
- command_name: Optional[str] = None
107
- error_msg: Optional[str] = None
108
- is_cme_command: bool = False
109
-
110
- def __init__(self, cme_workflow: fastworkflow.Workflow):
111
- self.cme_workflow = cme_workflow
112
- self.app_workflow = cme_workflow.context["app_workflow"]
113
- self.app_workflow_folderpath = self.app_workflow.folderpath
114
- self.app_workflow_id = self.app_workflow.id
115
-
116
- self.convo_path = os.path.join(self.app_workflow_folderpath, "___convo_info")
117
- self.cache_path = self._get_cache_path(self.app_workflow_id, self.convo_path)
118
- self.path = self._get_cache_path_cache(self.convo_path)
119
-
120
- def predict(self, command_context_name: str, command: str, nlu_pipeline_stage: NLUPipelineStage) -> "CommandNamePrediction.Output":
121
- # sourcery skip: extract-duplicate-method
122
-
123
- model_artifact_path = f"{self.app_workflow_folderpath}/___command_info/{command_context_name}"
124
- command_router = CommandRouter(model_artifact_path)
125
-
126
- # Re-use the already-built ModelPipeline attached to the router
127
- # instead of instantiating a fresh one. This avoids reloading HF
128
- # checkpoints and transferring tensors each time we see a new
129
- # message for the same context.
130
- modelpipeline = command_router.modelpipeline
131
-
132
- crd = fastworkflow.RoutingRegistry.get_definition(
133
- self.cme_workflow.folderpath)
134
- cme_command_names = crd.get_command_names('IntentDetection')
135
-
136
- valid_command_names = set()
137
- if nlu_pipeline_stage == NLUPipelineStage.INTENT_AMBIGUITY_CLARIFICATION:
138
- valid_command_names = self._get_suggested_commands(self.path)
139
- elif nlu_pipeline_stage in (
140
- NLUPipelineStage.INTENT_DETECTION, NLUPipelineStage.INTENT_MISUNDERSTANDING_CLARIFICATION):
141
- app_crd = fastworkflow.RoutingRegistry.get_definition(
142
- self.app_workflow_folderpath)
143
- valid_command_names = (
144
- set(cme_command_names) |
145
- set(app_crd.get_command_names(command_context_name))
146
- )
147
-
148
- command_name_dict = {
149
- fully_qualified_command_name.split('/')[-1]: fully_qualified_command_name
150
- for fully_qualified_command_name in valid_command_names
151
- }
152
-
153
- if nlu_pipeline_stage == NLUPipelineStage.INTENT_AMBIGUITY_CLARIFICATION:
154
- # what_can_i_do is special in INTENT_AMBIGUITY_CLARIFICATION
155
- # We will not predict, just match plain utterances with exact or fuzzy match
156
- command_name_dict |= {
157
- plain_utterance: 'IntentDetection/what_can_i_do'
158
- for plain_utterance in crd.command_directory.map_command_2_utterance_metadata[
159
- 'IntentDetection/what_can_i_do'
160
- ].plain_utterances
161
- }
162
-
163
- if nlu_pipeline_stage != NLUPipelineStage.INTENT_DETECTION:
164
- # abort is special.
165
- # We will not predict, just match plain utterances with exact or fuzzy match
166
- command_name_dict |= {
167
- plain_utterance: 'ErrorCorrection/abort'
168
- for plain_utterance in crd.command_directory.map_command_2_utterance_metadata[
169
- 'ErrorCorrection/abort'
170
- ].plain_utterances
171
- }
172
-
173
- if nlu_pipeline_stage != NLUPipelineStage.INTENT_MISUNDERSTANDING_CLARIFICATION:
174
- # you_misunderstood is special.
175
- # We will not predict, just match plain utterances with exact or fuzzy match
176
- command_name_dict |= {
177
- plain_utterance: 'ErrorCorrection/you_misunderstood'
178
- for plain_utterance in crd.command_directory.map_command_2_utterance_metadata[
179
- 'ErrorCorrection/you_misunderstood'
180
- ].plain_utterances
181
- }
182
-
183
- # See if the command starts with a command name followed by a space
184
- tentative_command_name = command.split(" ", 1)[0]
185
- normalized_command_name = tentative_command_name.lower()
186
- command_name = None
187
- if normalized_command_name in command_name_dict:
188
- command_name = normalized_command_name
189
- command = command.replace(f"{tentative_command_name}", "").strip().replace(" ", " ")
190
- else:
191
- # Use Levenshtein distance for fuzzy matching with the full command part after @
192
- best_matched_commands, _ = find_best_matches(
193
- command.replace(" ", "_"),
194
- command_name_dict.keys(),
195
- threshold=0.3 # Adjust threshold as needed
196
- )
197
- if best_matched_commands:
198
- command_name = best_matched_commands[0]
199
-
200
- if nlu_pipeline_stage == NLUPipelineStage.INTENT_DETECTION:
201
- if not command_name:
202
- if cache_result := cache_match(self.path, command, modelpipeline, 0.85):
203
- command_name = cache_result
204
- else:
205
- predictions=command_router.predict(command)
206
- # predictions = majority_vote_predictions(command_router, command)
207
-
208
- if len(predictions)==1:
209
- command_name = predictions[0].split('/')[-1]
210
- else:
211
- # If confidence is low, treat as ambiguous command (type 1)
212
- error_msg = self._formulate_ambiguous_command_error_message(
213
- predictions, "run_as_agent" in self.app_workflow.context)
214
-
215
- # Store suggested commands
216
- self._store_suggested_commands(self.path, predictions, 1)
217
- return CommandNamePrediction.Output(error_msg=error_msg)
218
-
219
- elif nlu_pipeline_stage in (
220
- NLUPipelineStage.INTENT_AMBIGUITY_CLARIFICATION,
221
- NLUPipelineStage.INTENT_MISUNDERSTANDING_CLARIFICATION
222
- ) and not command_name:
223
- command_name = "what_can_i_do"
224
4
 
225
- if not command_name or command_name == "wildcard":
226
- fully_qualified_command_name=None
227
- is_cme_command=False
228
- else:
229
- fully_qualified_command_name = command_name_dict[command_name]
230
- is_cme_command=(
231
- fully_qualified_command_name in cme_command_names or
232
- fully_qualified_command_name in crd.get_command_names('ErrorCorrection')
233
- )
234
-
235
- if (
236
- nlu_pipeline_stage
237
- in (
238
- NLUPipelineStage.INTENT_AMBIGUITY_CLARIFICATION,
239
- NLUPipelineStage.INTENT_MISUNDERSTANDING_CLARIFICATION,
240
- )
241
- and not fully_qualified_command_name.endswith('abort')
242
- and not fully_qualified_command_name.endswith('what_can_i_do')
243
- and not fully_qualified_command_name.endswith('you_misunderstood')
244
- ):
245
- command = self.cme_workflow.context["command"]
246
- store_utterance_cache(self.path, command, command_name, modelpipeline)
247
-
248
- return CommandNamePrediction.Output(
249
- command_name=fully_qualified_command_name,
250
- is_cme_command=is_cme_command
251
- )
252
-
253
- @staticmethod
254
- def _get_cache_path(workflow_id, convo_path):
255
- """
256
- Generate cache file path based on workflow ID
257
- """
258
- base_dir = convo_path
259
- # Create directory if it doesn't exist
260
- os.makedirs(base_dir, exist_ok=True)
261
- return os.path.join(base_dir, f"{workflow_id}.db")
262
-
263
- @staticmethod
264
- def _get_cache_path_cache(convo_path):
265
- """
266
- Generate cache file path based on workflow ID
267
- """
268
- base_dir = convo_path
269
- # Create directory if it doesn't exist
270
- os.makedirs(base_dir, exist_ok=True)
271
- return os.path.join(base_dir, "cache.db")
272
-
273
- # Store the suggested commands with the flag type
274
- @staticmethod
275
- def _store_suggested_commands(cache_path, command_list, flag_type):
276
- """
277
- Store the list of suggested commands for the constrained selection
278
-
279
- Args:
280
- cache_path: Path to the cache database
281
- command_list: List of suggested commands
282
- flag_type: Type of constraint (1=ambiguous, 2=misclassified)
283
- """
284
- db = Rdict(cache_path)
285
- try:
286
- db["suggested_commands"] = command_list
287
- db["flag_type"] = flag_type
288
- finally:
289
- db.close()
290
-
291
- # Get the suggested commands
292
- @staticmethod
293
- def _get_suggested_commands(cache_path):
294
- """
295
- Get the list of suggested commands for the constrained selection
296
- """
297
- db = Rdict(cache_path)
298
- try:
299
- return db.get("suggested_commands", [])
300
- finally:
301
- db.close()
302
-
303
- @staticmethod
304
- def _get_count(cache_path):
305
- db = Rdict(cache_path)
306
- try:
307
- return db.get("utterance_count", 0) # Default to 0 if key doesn't exist
308
- finally:
309
- db.close()
310
-
311
- @staticmethod
312
- def _print_db_contents(cache_path):
313
- db = Rdict(cache_path)
314
- try:
315
- print("All keys in database:", list(db.keys()))
316
- for key in db.keys():
317
- print(f"Key: {key}, Value: {db[key]}")
318
- finally:
319
- db.close()
320
-
321
- @staticmethod
322
- def _store_utterance(cache_path, utterance, label):
323
- """
324
- Store utterance in existing or new database
325
- Returns: The utterance count used
326
- """
327
- # Open the database (creates if doesn't exist)
328
- db = Rdict(cache_path)
329
-
330
- try:
331
- # Get existing counter or initialize to 0
332
- utterance_count = db.get("utterance_count", 0)
333
-
334
- # Create and store the utterance entry
335
- utterance_data = {
336
- "utterance": utterance,
337
- "label": label
338
- }
339
-
340
- db[utterance_count] = utterance_data
341
-
342
- # Increment and store the counter
343
- utterance_count += 1
344
- db["utterance_count"] = utterance_count
345
-
346
- return utterance_count - 1 # Return the count used for this utterance
347
-
348
- finally:
349
- # Always close the database
350
- db.close()
351
-
352
- # Function to read from database
353
- @staticmethod
354
- def _read_utterance(cache_path, utterance_id):
355
- """
356
- Read a specific utterance from the database
357
- """
358
- db = Rdict(cache_path)
359
- try:
360
- return db.get(utterance_id)['utterance']
361
- finally:
362
- db.close()
363
-
364
- @staticmethod
365
- def _formulate_ambiguous_command_error_message(
366
- route_choice_list: list[str], run_as_agent: bool) -> str:
367
- command_list = (
368
- "\n".join([
369
- f"{route_choice.split('/')[-1].lower()}"
370
- for route_choice in route_choice_list if route_choice != 'wildcard'
371
- ])
372
- )
373
-
374
- return (
375
- "The command is ambiguous. "
376
- + (
377
- "Choose the correct command name from these possible options and update your command:\n"
378
- if run_as_agent
379
- else "Please choose a command name from these possible options:\n"
380
- )
381
- + f"{command_list}\n\nor type 'what can i do' to see all commands\n"
382
- + ("or type 'abort' to cancel" if run_as_agent else '')
383
- )
384
-
385
- class ParameterExtraction:
386
- class Output(BaseModel):
387
- parameters_are_valid: bool
388
- cmd_parameters: Optional[BaseModel] = None
389
- error_msg: Optional[str] = None
390
- suggestions: Optional[Dict[str, List[str]]] = None
391
-
392
- def __init__(self, cme_workflow: fastworkflow.Workflow, app_workflow: fastworkflow.Workflow, command_name: str, command: str):
393
- self.cme_workflow = cme_workflow
394
- self.app_workflow = app_workflow
395
- self.command_name = command_name
396
- self.command = command
397
-
398
- def extract(self) -> "ParameterExtraction.Output":
399
- app_workflow_folderpath = self.app_workflow.folderpath
400
- app_command_routing_definition = fastworkflow.RoutingRegistry.get_definition(app_workflow_folderpath)
401
-
402
- command_parameters_class = (
403
- app_command_routing_definition.get_command_class(
404
- self.command_name, ModuleType.COMMAND_PARAMETERS_CLASS
405
- )
406
- )
407
- if not command_parameters_class:
408
- return self.Output(parameters_are_valid=True)
409
-
410
- stored_params = self._get_stored_parameters(self.cme_workflow)
411
-
412
- self.command = self.command.replace(self.command_name, "").strip()
413
-
414
- input_for_param_extraction = InputForParamExtraction.create(
415
- self.app_workflow, self.command_name,
416
- self.command)
417
-
418
- if stored_params:
419
- _, _, _, stored_missing_fields = self._extract_missing_fields(input_for_param_extraction, self.app_workflow, self.command_name, stored_params)
420
- else:
421
- stored_missing_fields = []
422
-
423
- # If we have missing fields (in parameter extraction error state), try to apply the command directly
424
- if stored_missing_fields:
425
- # Apply the command directly as parameter values
426
- direct_params = self._apply_missing_fields(self.command, stored_params, stored_missing_fields)
427
- new_params = direct_params
428
- else:
429
- # Otherwise use the LLM-based extraction
430
- new_params = input_for_param_extraction.extract_parameters(
431
- command_parameters_class,
432
- self.command_name,
433
- app_workflow_folderpath)
434
-
435
- if stored_params:
436
- merged_params = self._merge_parameters(stored_params, new_params, stored_missing_fields)
437
- else:
438
- merged_params = new_params
439
-
440
- self._store_parameters(self.cme_workflow, merged_params)
441
-
442
- is_valid, error_msg, suggestions = input_for_param_extraction.validate_parameters(
443
- self.app_workflow, self.command_name, merged_params
444
- )
445
-
446
- if not is_valid:
447
- if params_str := self._format_parameters_for_display(merged_params):
448
- error_msg = f"Extracted parameters so far:\n{params_str}\n\n{error_msg}"
449
-
450
- if "run_as_agent" not in self.app_workflow.context:
451
- error_msg += "\nEnter 'abort' to get out of this error state and/or execute a different command."
452
- error_msg += "\nEnter 'you misunderstood' if the wrong command was executed."
453
- else:
454
- error_msg += "\nCheck your command name if the wrong command was executed."
455
- return self.Output(
456
- parameters_are_valid=False,
457
- error_msg=error_msg,
458
- cmd_parameters=merged_params,
459
- suggestions=suggestions)
460
-
461
- self._clear_parameters(self.cme_workflow)
462
- return self.Output(
463
- parameters_are_valid=True,
464
- cmd_parameters=merged_params)
465
-
466
- @staticmethod
467
- def _get_stored_parameters(cme_workflow: fastworkflow.Workflow):
468
- return cme_workflow.context.get("stored_parameters")
469
-
470
- @staticmethod
471
- def _store_parameters(cme_workflow: fastworkflow.Workflow, parameters):
472
- cme_workflow.context["stored_parameters"] = parameters
473
-
474
- @staticmethod
475
- def _clear_parameters(cme_workflow: fastworkflow.Workflow):
476
- if "stored_parameters" in cme_workflow.context:
477
- del cme_workflow.context["stored_parameters"]
478
-
479
- @staticmethod
480
- def _extract_missing_fields(input_for_param_extraction, sws, command_name, stored_params):
481
- stored_missing_fields = []
482
- is_valid, error_msg, suggestions = input_for_param_extraction.validate_parameters(
483
- sws, command_name, stored_params
484
- )
485
-
486
- if not is_valid:
487
- if MISSING_INFORMATION_ERRMSG in error_msg:
488
- missing_fields_str = error_msg.split(f"{MISSING_INFORMATION_ERRMSG}")[1].split("\n")[0]
489
- stored_missing_fields = [f.strip() for f in missing_fields_str.split(",")]
490
- if INVALID_INFORMATION_ERRMSG in error_msg:
491
- invalid_section = error_msg.split(f"{INVALID_INFORMATION_ERRMSG}")[1]
492
- if "\n" in invalid_section:
493
- invalid_fields_str = invalid_section.split("\n")[0]
494
- stored_missing_fields.extend(
495
- invalid_field.split(" '")[0].strip()
496
- for invalid_field in invalid_fields_str.split(", ")
497
- )
498
- return is_valid, error_msg, suggestions, stored_missing_fields
499
-
500
- @staticmethod
501
- def _merge_parameters(old_params, new_params, missing_fields):
502
- """
503
- Merge new parameters with old parameters, prioritizing new values when appropriate.
504
- """
505
- global PARAMETER_EXTRACTION_ERROR_MSG
506
- if not PARAMETER_EXTRACTION_ERROR_MSG:
507
- PARAMETER_EXTRACTION_ERROR_MSG = fastworkflow.get_env_var("PARAMETER_EXTRACTION_ERROR_MSG")
508
-
509
- merged = old_params.model_copy()
510
-
511
- try:
512
- all_fields = list(old_params.model_fields.keys())
513
- missing_fields = missing_fields or []
514
-
515
- for field_name in all_fields:
516
- if hasattr(new_params, field_name):
517
- new_value = getattr(new_params, field_name)
518
- old_value = getattr(merged, field_name)
519
-
520
- if new_value is not None and new_value != NOT_FOUND:
521
- if isinstance(old_value, str) and INVALID in old_value and INVALID not in new_value:
522
- setattr(merged, field_name, new_value)
523
-
524
- elif old_value is None or old_value == NOT_FOUND:
525
- setattr(merged, field_name, new_value)
526
-
527
- elif isinstance(old_value, int) and old_value == INVALID_INT_VALUE:
528
- with contextlib.suppress(ValueError, TypeError):
529
- setattr(merged, field_name, int(new_value))
530
-
531
- elif isinstance(old_value, float) and old_value == INVALID_FLOAT_VALUE:
532
- with contextlib.suppress(ValueError, TypeError):
533
- setattr(merged, field_name, float(new_value))
534
-
535
- elif (field_name in missing_fields and
536
- hasattr(merged.model_fields.get(field_name), "json_schema_extra") and
537
- merged.model_fields.get(field_name).json_schema_extra and
538
- "db_lookup" in merged.model_fields.get(field_name).json_schema_extra):
539
- setattr(merged, field_name, new_value)
540
-
541
- elif field_name in missing_fields:
542
- field_info = merged.model_fields.get(field_name)
543
- has_pattern = hasattr(field_info, "pattern") and field_info.pattern is not None
544
-
545
- if not has_pattern:
546
- for meta in getattr(field_info, "metadata", []):
547
- if hasattr(meta, "pattern"):
548
- has_pattern = True
549
- break
550
-
551
- if not has_pattern and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra:
552
- has_pattern = "pattern" in field_info.json_schema_extra
553
-
554
- if has_pattern:
555
- setattr(merged, field_name, new_value)
556
- except Exception as exc:
557
- logger.warning(PARAMETER_EXTRACTION_ERROR_MSG.format(error=exc))
558
-
559
- return merged
560
-
561
- @staticmethod
562
- def _format_parameters_for_display(params):
563
- """
564
- Format parameters for display in the error message.
565
- """
566
- if not params:
567
- return ""
568
-
569
- lines = []
570
-
571
- all_fields = list(params.model_fields.keys())
572
-
573
- for field_name in all_fields:
574
- value = getattr(params, field_name, None)
575
-
576
- if value in [
577
- NOT_FOUND,
578
- None,
579
- INVALID_INT_VALUE,
580
- INVALID_FLOAT_VALUE
581
- ]:
582
- continue
583
-
584
- display_name = " ".join(word.capitalize() for word in field_name.split('_'))
585
-
586
- # Format fields appropriately based on type
587
- if (
588
- isinstance(value, bool)
589
- or not hasattr(value, 'value')
590
- and isinstance(value, (int, float))
591
- or not hasattr(value, 'value')
592
- and isinstance(value, str)
593
- or not hasattr(value, 'value')
594
- ):
595
- lines.append(f"{display_name}: {value}")
596
- else: # Handle enum types
597
- lines.append(f"{display_name}: {value.value}")
598
- return "\n".join(lines)
599
-
600
- @staticmethod
601
- def _apply_missing_fields(command: str, default_params: BaseModel, missing_fields: list):
602
- global PARAMETER_EXTRACTION_ERROR_MSG
603
- if not PARAMETER_EXTRACTION_ERROR_MSG:
604
- PARAMETER_EXTRACTION_ERROR_MSG = fastworkflow.get_env_var("PARAMETER_EXTRACTION_ERROR_MSG")
605
-
606
- params = default_params.model_copy()
607
-
608
- try:
609
- if "," in command:
610
- parts = [part.strip() for part in command.split(",")]
611
-
612
- if len(parts) == len(missing_fields):
613
- if len(missing_fields) == 1:
614
- field = missing_fields[0]
615
- if hasattr(params, field):
616
- setattr(params, field, parts[0])
617
- return params
618
- elif len(missing_fields) > 1:
619
- for i, field in enumerate(missing_fields):
620
- if i < len(parts) and hasattr(params, field):
621
- setattr(params, field, parts[i])
622
- return params
623
- else:
624
- if parts and missing_fields:
625
- field = missing_fields[0]
626
- if hasattr(params, field):
627
- setattr(params, field, parts[0])
628
- return params
629
-
630
- elif missing_fields:
631
- field = missing_fields[0]
632
- if hasattr(params, field):
633
- setattr(params, field, command.strip())
634
- return params
635
-
636
- except Exception as exc:
637
- # logger.warning(PARAMETER_EXTRACTION_ERROR_MSG.format(error=exc))
638
- pass
639
-
640
- return params
5
+ from ..intent_detection import CommandNamePrediction
6
+ from ..parameter_extraction import ParameterExtraction
641
7
 
642
8
 
643
9
  class Signature:
@@ -764,6 +130,13 @@ class ResponseGenerator:
764
130
  workflow_context["NLU_Pipeline_Stage"] = NLUPipelineStage.PARAMETER_EXTRACTION
765
131
  workflow.context = workflow_context
766
132
 
133
+ if nlu_pipeline_stage == NLUPipelineStage.PARAMETER_EXTRACTION:
134
+ cnp_output.command_name = workflow.context["command_name"]
135
+ else:
136
+ workflow_context = workflow.context
137
+ workflow_context["command_name"] = cnp_output.command_name
138
+ workflow.context = workflow_context
139
+
767
140
  command_name = cnp_output.command_name
768
141
  extractor = ParameterExtraction(workflow, app_workflow, command_name, command)
769
142
  pe_output = extractor.extract()