ripple-down-rules 0.6.51__py3-none-any.whl → 0.6.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,71 +1,143 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
3
  import re
5
4
  from abc import ABC, abstractmethod
5
+ from dataclasses import dataclass, field
6
6
  from types import NoneType
7
7
  from uuid import uuid4
8
8
 
9
- from anytree import NodeMixin
10
- from sqlalchemy.orm import DeclarativeBase as SQLTable
9
+ from anytree import Node
11
10
  from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple, Type, Set
12
11
 
13
12
  from .datastructures.callable_expression import CallableExpression
14
13
  from .datastructures.case import Case
15
14
  from .datastructures.dataclasses import CaseFactoryMetaData, CaseQuery
16
15
  from .datastructures.enums import RDREdge, Stop
17
- from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_type_from_string
18
16
  from .helpers import get_an_updated_case_copy
17
+ from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_type_from_string
19
18
 
20
19
 
21
- class Rule(NodeMixin, SubclassJSONSerializer, ABC):
22
- fired: Optional[bool] = None
20
+ @dataclass
21
+ class Rule(SubclassJSONSerializer, ABC):
22
+ conditions: Optional[CallableExpression] = field(default=None)
23
+ """
24
+ The conditions of the rule, which is a callable expression that takes a case and returns a boolean.
25
+ """
26
+ conclusion: Optional[CallableExpression] = field(default=None)
27
+ """
28
+ The conclusion of the rule, which is a callable expression that takes a case and returns a value.
29
+ """
30
+ _parent: Optional[Rule] = field(default=None)
31
+ """
32
+ The parent rule of this rule in the ripple down rules tree.
33
+ """
34
+ corner_case: Optional[Any] = field(default=None)
35
+ """
36
+ The corner case for which this rule was created.
37
+ """
38
+ _weight: Optional[RDREdge] = field(default=None)
39
+ """
40
+ The weight of the rule, which is the type of edge connecting the rule to its parent.
41
+ """
42
+ conclusion_name: Optional[str] = field(default=None)
43
+ """
44
+ The name of the conclusion of the rule, which is used to identify the conclusion
45
+ """
46
+ uid: str = field(default_factory=lambda: str(uuid4().int))
47
+ """
48
+ A unique id for the rule using uuid4
49
+ """
50
+ corner_case_metadata: Optional[CaseFactoryMetaData] = field(default=None)
51
+ """
52
+ Metadata about the corner case, such as the factory that created it or the scenario it is based on.
53
+ """
54
+ json_serialization: Optional[Dict[str, Any]] = field(init=False, default=None)
55
+ """
56
+ The JSON serialization of the rule, which is used to serialize the rule to JSON.
57
+ """
58
+ _name: Optional[str] = field(init=False, default=None)
59
+ """
60
+ The name of the rule, which is the names of the conditions and the conclusion
61
+ """
62
+ evaluated: bool = field(init=False, default=False)
63
+ """
64
+ Whether the rule has been evaluated or not (i.e. whether the rule has been reached during evaluation of the ripple
65
+ down rules tree and the conditions have been checked).
66
+ """
67
+ last_conclusion: Optional[Any] = field(init=False, default=None)
68
+ """
69
+ The last conclusion of the rule, which is the conclusion of the rule when it was last evaluated.
70
+ """
71
+ contributed: bool = field(init=False, default=False)
72
+ """
73
+ Whether the rule has contributed by a value, meaning that it has fired and the conclusion has been added to the case.
74
+ """
75
+ contributed_to_case_query: bool = field(init=False, default=False)
76
+ """
77
+ Whether the rule has contributed to the case query, meaning that it has fired and the conclusion is relevant to the
78
+ case query.
79
+ """
80
+ fired: Optional[bool] = field(init=False, default=None)
23
81
  """
24
82
  Whether the rule has fired or not.
25
83
  """
26
- mutually_exclusive: bool
84
+ mutually_exclusive: bool = field(init=False, default=True)
27
85
  """
28
86
  Whether the rule is mutually exclusive with other rules.
29
87
  """
88
+ node: Optional[Node] = field(init=False, default=None)
89
+
90
def __post_init__(self):
    """
    Build the tree node that mirrors this rule once the dataclass fields have been initialised.

    The node is named after the rule, attached below the node of the parent rule (when there is a parent),
    labelled with the value of the rule's weight (when there is a weight), and given a back reference to this rule.
    """
    node_name = self.name
    parent_node = self.parent.node if self.parent else None
    self.node = Node(node_name, parent=parent_node)
    edge = self.weight
    if edge:
        self.node.weight = edge.value
    else:
        self.node.weight = None
    # Back reference used to map tree nodes to their rules again.
    self.node._rdr_rule = self
94
+
95
@property
def descendants(self) -> List[Rule]:
    """
    :return: the rules found below this rule in the ripple down rules tree, collected by reading the rule
     stored on each descendant of this rule's tree node.
    """
    rules_below = []
    for tree_node in self.node.descendants:
        rules_below.append(tree_node._rdr_rule)
    return rules_below
30
102
 
31
- def __init__(self, conditions: Optional[CallableExpression] = None,
32
- conclusion: Optional[CallableExpression] = None,
33
- parent: Optional[Rule] = None,
34
- corner_case: Optional[Union[Case, SQLTable]] = None,
35
- weight: Optional[str] = None,
36
- conclusion_name: Optional[str] = None,
37
- uid: Optional[str] = None,
38
- corner_case_metadata: Optional[CaseFactoryMetaData] = None):
103
+ @property
104
+ def children(self) -> List[Rule]:
39
105
  """
40
- A rule in the ripple down rules classifier.
106
+ :return: the children of this rule, which are the rules that are direct children of this rule in the ripple down
107
+ rules tree.
108
+ """
109
+ return [child._rdr_rule for child in self.node.children]
41
110
 
42
- :param conditions: The conditions of the rule.
43
- :param conclusion: The conclusion of the rule when the conditions are met.
44
- :param parent: The parent rule of this rule.
45
- :param corner_case: The corner case that this rule is based on/created from.
46
- :param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
47
- :param conclusion_name: The name of the conclusion of the rule.
48
- :param uid: The unique id of the rule.
49
- :param corner_case_metadata: Metadata about the corner case, such as the factory that created it or the
50
- scenario it is based on.
51
- """
52
- super(Rule, self).__init__()
53
- self.conclusion = conclusion
54
- self.corner_case = corner_case
55
- self.corner_case_metadata: Optional[CaseFactoryMetaData] = corner_case_metadata
56
- self.parent = parent
57
- self.weight: Optional[str] = weight
58
- self.conditions = conditions if conditions else None
59
- self.conclusion_name: Optional[str] = conclusion_name
60
- self.json_serialization: Optional[Dict[str, Any]] = None
61
- self._name: Optional[str] = None
62
- # generate a unique id for the rule using uuid4
63
- self.uid: str = uid if uid else str(uuid4().int)
64
- self.evaluated: bool = False
65
- self._user_defined_name: Optional[str] = None
66
- self.last_conclusion: Optional[Any] = None
67
- self.contributed: bool = False
68
- self.contributed_to_case_query: bool = False
111
@property
def parent(self) -> Optional[Rule]:
    """
    :return: The parent rule of this rule, or None when no parent has been set.
    """
    return self._parent

@parent.setter
def parent(self, new_parent: Optional[Rule]):
    """
    Set the parent rule of this rule and keep the tree node of the rule in sync with it.

    :param new_parent: The new parent rule to set, or None to leave the rule without a parent.
    """
    self._parent = new_parent
    if self.node is not None:
        # The parameter is optional, so a missing parent must not be dereferenced; __post_init__ likewise
        # hands parent=None to the node when the rule has no parent.
        self.node.parent = new_parent.node if new_parent is not None else None
127
+
128
@property
def weight(self) -> Optional[RDREdge]:
    """
    :return: The weight of the rule, which is the type of edge connecting the rule to its parent, or None when
     no weight has been set.
    """
    return self._weight

@weight.setter
def weight(self, new_weight: Optional[RDREdge]):
    """
    Set the weight of the rule, which is the type of edge connecting the rule to its parent, and mirror its
    value on the tree node of the rule.

    :param new_weight: The new weight to set, or None to clear the weight.
    """
    self._weight = new_weight
    if self.node is not None:
        # The stored weight is optional (it defaults to None), so only read `.value` from an actual edge;
        # __post_init__ applies the same guard when it first labels the node.
        self.node.weight = new_weight.value if new_weight is not None else None
69
141
 
70
142
  def get_an_updated_case_copy(self, case: Case) -> Case:
71
143
  """
@@ -96,23 +168,6 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
96
168
  else:
97
169
  return "white"
98
170
 
99
- @property
100
- def user_defined_name(self) -> Optional[str]:
101
- """
102
- Get the user defined name of the rule, if it exists.
103
- """
104
- if self._user_defined_name is None:
105
- if self.conditions and self.conditions.user_input and "def " in self.conditions.user_input:
106
- # If the conditions have a user input, use it as the name
107
- func_name = self.conditions.user_input.split('(')[0].replace('def ', '').strip()
108
- if func_name == self.conditions.encapsulating_function_name:
109
- self._user_defined_name = str(self.conditions)
110
- else:
111
- self._user_defined_name = func_name
112
- else:
113
- self._user_defined_name = f"Rule_{self.uid}"
114
- return self._user_defined_name
115
-
116
171
  @classmethod
117
172
  def from_case_query(cls, case_query: CaseQuery, parent: Optional[Rule] = None) -> Rule:
118
173
  """
@@ -124,7 +179,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
124
179
  """
125
180
  corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
126
181
  return cls(conditions=case_query.conditions, conclusion=case_query.target,
127
- corner_case=case_query.case, parent=parent,
182
+ corner_case=case_query.case, _parent=parent,
128
183
  corner_case_metadata=corner_case_metadata,
129
184
  conclusion_name=case_query.attribute_name)
130
185
 
@@ -291,7 +346,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
291
346
  "conclusion": conclusion_to_json(self.conclusion),
292
347
  "parent": self.parent.json_serialization if self.parent else None,
293
348
  "conclusion_name": self.conclusion_name,
294
- "weight": self.weight,
349
+ "weight": self.weight.value,
295
350
  "uid": self.uid}
296
351
  return json_serialization
297
352
 
@@ -299,9 +354,9 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
299
354
  def _from_json(cls, data: Dict[str, Any]) -> Rule:
300
355
  loaded_rule = cls(conditions=CallableExpression.from_json(data["conditions"]),
301
356
  conclusion=CallableExpression.from_json(data["conclusion"]),
302
- parent=cls.from_json(data["parent"]),
357
+ _parent=cls.from_json(data["parent"]),
303
358
  conclusion_name=data["conclusion_name"],
304
- weight=data["weight"],
359
+ _weight=RDREdge.from_value(data["weight"]),
305
360
  uid=data["uid"])
306
361
  return loaded_rule
307
362
 
@@ -318,20 +373,25 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
318
373
  Set the name of the rule.
319
374
  """
320
375
  self._name = new_name
376
+ self.node.name = new_name
321
377
 
322
378
  @property
323
379
  def semantic_condition_name(self) -> Optional[str]:
324
380
  """
325
381
  Get the name of the conditions of the rule, which is the user input of the conditions.
326
382
  """
327
- return self.expression_name(self.conditions)
383
+ if isinstance(self.conditions, CallableExpression):
384
+ return self.expression_name(self.conditions)
385
+ return None
328
386
 
329
387
  @property
330
388
  def semantic_conclusion_name(self) -> Optional[str]:
331
389
  """
332
390
  Get the name of the conclusion of the rule, which is the user input of the conclusion.
333
391
  """
334
- return self.expression_name(self.conclusion)
392
+ if isinstance(self.conclusion, CallableExpression):
393
+ return self.expression_name(self.conclusion)
394
+ return None
335
395
 
336
396
  @staticmethod
337
397
  def expression_name(expression: CallableExpression) -> str:
@@ -341,7 +401,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
341
401
  """
342
402
  if expression.user_defined_name is not None and expression.user_defined_name != expression.encapsulating_function_name:
343
403
  return expression.user_defined_name.strip()
344
- func_name = expression.user_input.split('(')[0].replace('def ', '').strip() if "def " in expression.user_input else None
404
+ func_name = expression.user_input.split('(')[0].replace('def ',
405
+ '').strip() if "def " in expression.user_input else None
345
406
  if func_name is not None and func_name != expression.encapsulating_function_name:
346
407
  return func_name
347
408
  elif expression.user_input:
@@ -358,20 +419,29 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
358
419
  def __repr__(self):
359
420
  return self.__str__()
360
421
 
422
def __eq__(self, other):
    """
    Compare rules by identity of their uid; an object that is not a Rule is never equal to a rule.
    """
    return isinstance(other, Rule) and other.uid == self.uid

def __hash__(self):
    """
    Hash the rule by its uid, the same attribute that __eq__ compares.
    """
    rule_uid = self.uid
    return hash(rule_uid)
361
429
 
430
+
431
+ @dataclass
362
432
  class HasAlternativeRule:
363
433
  """
364
434
  A mixin class for rules that have an alternative rule.
365
435
  """
366
- _alternative: Optional[Rule] = None
436
+ _alternative: Optional[Rule] = field(init=False, default=None)
367
437
  """
368
438
  The alternative rule of the rule, which is evaluated when the rule doesn't fire.
369
439
  """
370
- furthest_alternative: Optional[List[Rule]] = None
440
+ furthest_alternative: Optional[List[Rule]] = field(init=False, default=None)
371
441
  """
372
442
  The furthest alternative rule of the rule, which is the last alternative rule in the chain of alternative rules.
373
443
  """
374
- all_alternatives: Optional[List[Rule]] = None
444
+ all_alternatives: Optional[List[Rule]] = field(init=False, default=None)
375
445
  """
376
446
  All alternative rules of the rule, which is all the alternative rules in the chain of alternative rules.
377
447
  """
@@ -380,9 +450,6 @@ class HasAlternativeRule:
380
450
  def alternative(self) -> Optional[Rule]:
381
451
  return self._alternative
382
452
 
383
- def set_immediate_alternative(self, alternative: Optional[Rule]):
384
- self._alternative = alternative
385
-
386
453
  @alternative.setter
387
454
  def alternative(self, new_rule: Rule):
388
455
  """
@@ -395,13 +462,14 @@ class HasAlternativeRule:
395
462
  self.furthest_alternative[-1].alternative = new_rule
396
463
  else:
397
464
  new_rule.parent = self
398
- new_rule.weight = RDREdge.Alternative.value if not new_rule.weight else new_rule.weight
465
+ new_rule.weight = RDREdge.Alternative if not new_rule.weight else new_rule.weight
399
466
  self._alternative = new_rule
400
467
  self.furthest_alternative = [new_rule]
401
468
 
402
469
 
470
+ @dataclass
403
471
  class HasRefinementRule:
404
- _refinement: Optional[HasAlternativeRule] = None
472
+ _refinement: Optional[HasAlternativeRule] = field(init=False, default=None)
405
473
  """
406
474
  The refinement rule of the rule, which is evaluated when the rule fires.
407
475
  """
@@ -422,16 +490,18 @@ class HasRefinementRule:
422
490
  self.refinement.alternative = new_rule
423
491
  else:
424
492
  new_rule.parent = self
425
- new_rule.weight = RDREdge.Refinement.value if not isinstance(new_rule, MultiClassFilterRule) else new_rule.weight
493
+ new_rule.weight = RDREdge.Refinement if not isinstance(new_rule,
494
+ MultiClassFilterRule) else new_rule.weight
426
495
  self._refinement = new_rule
427
496
 
428
497
 
498
+ @dataclass(eq=False)
429
499
  class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
430
500
  """
431
501
  A rule in the SingleClassRDR classifier, it can have a refinement or an alternative rule or both.
432
502
  """
433
503
 
434
- mutually_exclusive: bool = True
504
+ mutually_exclusive: bool = field(init=False, default=True)
435
505
 
436
506
  def evaluate_next_rule(self, x: Case) -> SingleClassRule:
437
507
  if self.fired:
@@ -443,7 +513,7 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
443
513
  def fit_rule(self, case_query: CaseQuery):
444
514
  corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
445
515
  new_rule = SingleClassRule(case_query.conditions, case_query.target,
446
- corner_case=case_query.case, parent=self,
516
+ corner_case=case_query.case, _parent=self,
447
517
  corner_case_metadata=corner_case_metadata,
448
518
  )
449
519
  if self.fired:
@@ -465,18 +535,19 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
465
535
  return loaded_rule
466
536
 
467
537
  def _if_statement_source_code_clause(self) -> str:
468
- return "elif" if self.weight == RDREdge.Alternative.value else "if"
538
+ return "elif" if self.weight == RDREdge.Alternative else "if"
469
539
 
470
540
 
541
+ @dataclass(eq=False)
471
542
  class MultiClassRefinementRule(Rule, HasAlternativeRule, ABC):
472
543
  """
473
544
  A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule.
474
545
  """
475
- top_rule: Optional[MultiClassTopRule] = None
546
+ top_rule: Optional[MultiClassTopRule] = field(init=False, default=None)
476
547
  """
477
548
  The top rule of the rule, which is the nearest ancestor that fired and this rule is a refinement of.
478
549
  """
479
- mutually_exclusive: bool = False
550
+ mutually_exclusive: bool = field(init=False, default=False)
480
551
 
481
552
  def _to_json(self) -> Dict[str, Any]:
482
553
  self.json_serialization = {**Rule._to_json(self),
@@ -495,18 +566,20 @@ class MultiClassRefinementRule(Rule, HasAlternativeRule, ABC):
495
566
  return loaded_rule
496
567
 
497
568
  def _if_statement_source_code_clause(self) -> str:
498
- return "elif" if self.weight == RDREdge.Alternative.value else "if"
569
+ return "elif" if self.weight == RDREdge.Alternative else "if"
499
570
 
500
571
 
572
+ @dataclass(eq=False)
501
573
  class MultiClassStopRule(MultiClassRefinementRule):
502
574
  """
503
575
  A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
504
576
  the conclusion of the rule is a Stop category meant to stop the parent conclusion from being made.
505
577
  """
506
-
507
- def __init__(self, *args, **kwargs):
508
- super(MultiClassRefinementRule, self).__init__(*args, **kwargs)
509
- self.conclusion = CallableExpression(conclusion_type=(Stop,), conclusion=Stop.stop)
578
+ conclusion: CallableExpression = field(default_factory=lambda: CallableExpression(conclusion_type=(Stop,),
579
+ conclusion=Stop.stop))
580
+ """
581
+ The conclusion of the rule, which is a CallableExpression that returns a Stop category.
582
+ """
510
583
 
511
584
  def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassRefinementRule, MultiClassTopRule]]:
512
585
  if self.fired:
@@ -521,15 +594,13 @@ class MultiClassStopRule(MultiClassRefinementRule):
521
594
  return None, f"{parent_indent}{' ' * 4}pass\n"
522
595
 
523
596
 
597
+ @dataclass(eq=False)
524
598
  class MultiClassFilterRule(MultiClassRefinementRule, HasRefinementRule):
525
599
  """
526
600
  A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
527
601
  the conclusion of the rule is a Filter category meant to filter the parent conclusion.
528
602
  """
529
-
530
- def __init__(self, *args, **kwargs):
531
- super(MultiClassRefinementRule, self).__init__(*args, **kwargs)
532
- self.weight = RDREdge.Filter.value
603
+ weight: RDREdge = field(init=False, default_factory=lambda: RDREdge.Filter)
533
604
 
534
605
  def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassRefinementRule, MultiClassTopRule]]:
535
606
  if self.fired:
@@ -579,15 +650,13 @@ class MultiClassFilterRule(MultiClassRefinementRule, HasRefinementRule):
579
650
  return loaded_rule
580
651
 
581
652
 
653
+ @dataclass(eq=False)
582
654
  class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
583
655
  """
584
656
  A rule in the MultiClassRDR classifier, it can have a refinement and a next rule.
585
657
  """
586
- mutually_exclusive: bool = False
587
-
588
- def __init__(self, *args, **kwargs):
589
- super(MultiClassTopRule, self).__init__(*args, **kwargs)
590
- self.weight = RDREdge.Next.value
658
+ mutually_exclusive: bool = field(init=False, default=False)
659
+ weight: RDREdge = field(init=False, default_factory=lambda: RDREdge.Next)
591
660
 
592
661
  def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassStopRule, MultiClassTopRule]]:
593
662
  if self.fired and self.refinement:
@@ -603,7 +672,7 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
603
672
  if self.fired and case_query.target != self.conclusion:
604
673
  if refinement_type in [None, MultiClassStopRule]:
605
674
  new_rule = MultiClassStopRule(case_query.conditions, corner_case=case_query.case,
606
- parent=self)
675
+ _parent=self)
607
676
  elif refinement_type is MultiClassFilterRule:
608
677
  new_rule = MultiClassFilterRule.from_case_query(case_query, parent=self)
609
678
  else:
@@ -614,7 +683,7 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
614
683
  self.alternative = MultiClassTopRule.from_case_query(case_query, parent=self)
615
684
 
616
685
  def _to_json(self) -> Dict[str, Any]:
617
- self.json_serialization = {**Rule._to_json(self),
686
+ self.json_serialization = {**super()._to_json(),
618
687
  "refinement": self.refinement.to_json() if self.refinement is not None else None,
619
688
  "alternative": self.alternative.to_json() if self.alternative is not None else None}
620
689
  return self.json_serialization
@@ -24,10 +24,16 @@ class MyMagics(Magics):
24
24
  self.case_query: Optional[CaseQuery] = case_query
25
25
  self.rule_editor = TemplateFileCreator(case_query, prompt_for=prompt_for, code_to_modify=code_to_modify)
26
26
  self.all_code_lines: Optional[List[str]] = None
27
+ self.edited: bool = False
28
+ self.loaded: bool = False
27
29
 
28
30
  @line_magic
29
31
  def edit(self, line):
30
- self.rule_editor.edit()
32
+ if self.edited:
33
+ self.rule_editor.open_file_in_editor()
34
+ else:
35
+ self.rule_editor.edit()
36
+ self.edited = True
31
37
 
32
38
  @line_magic
33
39
  def load(self, line):
@@ -35,6 +41,8 @@ class MyMagics(Magics):
35
41
  self.rule_editor.func_name,
36
42
  self.rule_editor.print_func)
37
43
  self.shell.user_ns.update(updates)
44
+ self.case_query.scope.update(updates)
45
+ self.loaded = True
38
46
 
39
47
  @line_magic
40
48
  def current_value(self, line):
@@ -23,14 +23,14 @@ from ..datastructures.dataclasses import CaseQuery
23
23
  from ..datastructures.enums import PromptFor, ExitStatus
24
24
  from .ipython_custom_shell import IPythonShell
25
25
  from ..utils import make_list
26
- from threading import RLock
26
+ from threading import Lock
27
27
 
28
28
 
29
29
  class UserPrompt:
30
30
  """
31
31
  A class to handle user prompts for the RDR.
32
32
  """
33
- shell_lock: RLock = RLock() # To ensure that only one thread can access the shell at a time
33
+ shell_lock: Lock = Lock() # To ensure that only one thread can access the shell at a time
34
34
 
35
35
  def __init__(self, prompt_user: bool = True):
36
36
  """
@@ -49,43 +49,43 @@ class UserPrompt:
49
49
  :param prompt_str: The prompt string to display to the user.
50
50
  :return: A callable expression that takes a case and executes user expression on it.
51
51
  """
52
- prev_user_input: Optional[str] = None
53
- user_input_to_modify: Optional[str] = None
54
- callable_expression: Optional[CallableExpression] = None
55
- while True:
56
- with self.shell_lock:
52
+ with self.shell_lock:
53
+ prev_user_input: Optional[str] = None
54
+ user_input_to_modify: Optional[str] = None
55
+ callable_expression: Optional[CallableExpression] = None
56
+ while True:
57
57
  user_input, expression_tree = self.prompt_user_about_case(case_query, prompt_for, prompt_str,
58
58
  code_to_modify=prev_user_input)
59
- if user_input is None:
60
- if prompt_for == PromptFor.Conclusion:
61
- self.print_func(f"\n{Fore.YELLOW}No conclusion provided. Exiting.{Style.RESET_ALL}")
62
- return None, None
63
- else:
64
- self.print_func(f"\n{Fore.RED}Conditions must be provided. Please try again.{Style.RESET_ALL}")
65
- continue
66
- elif user_input in ["exit", 'quit']:
67
- self.print_func(f"\n{Fore.YELLOW}Exiting.{Style.RESET_ALL}")
68
- return user_input, None
69
-
70
- prev_user_input = '\n'.join(user_input.split('\n')[2:-1])
71
- conclusion_type = bool if prompt_for == PromptFor.Conditions else case_query.attribute_type
72
- callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
73
- scope=case_query.scope,
74
- mutually_exclusive=case_query.mutually_exclusive)
75
- try:
76
- result = callable_expression(case_query.case)
77
- if len(make_list(result)) == 0 and (user_input_to_modify is not None
78
- and (prev_user_input != user_input_to_modify)):
79
- user_input_to_modify = prev_user_input
80
- self.print_func(
81
- f"{Fore.YELLOW}The given expression gave an empty result for case {case_query.name}."
82
- f" Please accept or modify!{Style.RESET_ALL}")
83
- continue
84
- break
85
- except Exception as e:
86
- logging.error(e)
87
- self.print_func(f"{Fore.RED}{e}{Style.RESET_ALL}")
88
- return user_input, callable_expression
59
+ if user_input is None:
60
+ if prompt_for == PromptFor.Conclusion:
61
+ self.print_func(f"\n{Fore.YELLOW}No conclusion provided. Exiting.{Style.RESET_ALL}")
62
+ return None, None
63
+ else:
64
+ self.print_func(f"\n{Fore.RED}Conditions must be provided. Please try again.{Style.RESET_ALL}")
65
+ continue
66
+ elif user_input in ["exit", 'quit']:
67
+ self.print_func(f"\n{Fore.YELLOW}Exiting.{Style.RESET_ALL}")
68
+ return user_input, None
69
+
70
+ prev_user_input = '\n'.join(user_input.split('\n')[2:-1])
71
+ conclusion_type = bool if prompt_for == PromptFor.Conditions else case_query.attribute_type
72
+ callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
73
+ scope=case_query.scope,
74
+ mutually_exclusive=case_query.mutually_exclusive)
75
+ try:
76
+ result = callable_expression(case_query.case)
77
+ if len(make_list(result)) == 0 and (user_input_to_modify is not None
78
+ and (prev_user_input != user_input_to_modify)):
79
+ user_input_to_modify = prev_user_input
80
+ self.print_func(
81
+ f"{Fore.YELLOW}The given expression gave an empty result for case {case_query.name}."
82
+ f" Please accept or modify!{Style.RESET_ALL}")
83
+ continue
84
+ break
85
+ except Exception as e:
86
+ logging.error(e)
87
+ self.print_func(f"{Fore.RED}{e}{Style.RESET_ALL}")
88
+ return user_input, callable_expression
89
89
 
90
90
  def prompt_user_about_case(self, case_query: CaseQuery, prompt_for: PromptFor,
91
91
  prompt_str: Optional[str] = None,
@@ -177,6 +177,8 @@ class TemplateFileCreator:
177
177
  if self.case_query.is_function:
178
178
  func_args = {}
179
179
  for k, v in self.case_query.case.items():
180
+ if k == self.case_query.attribute_name:
181
+ continue
180
182
  if (self.case_query.function_args_type_hints is not None
181
183
  and k in self.case_query.function_args_type_hints):
182
184
  func_args[k] = stringify_hint(self.case_query.function_args_type_hints[k])
@@ -184,6 +186,7 @@ class TemplateFileCreator:
184
186
  func_args[k] = type(v).__name__ if not isinstance(v, type) else f"Type[{v.__name__}]"
185
187
  func_args = ', '.join([f"{k}: {v}" if str(v) not in ["NoneType", "None"] else str(k)
186
188
  for k, v in func_args.items()])
189
+ func_args += ", **kwargs"
187
190
  else:
188
191
  func_args = f"case: {self.case_query.case_type.__name__}"
189
192
  return func_args