ripple-down-rules 0.6.51__py3-none-any.whl → 0.6.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ripple_down_rules/__init__.py +12 -4
- ripple_down_rules/datastructures/dataclasses.py +52 -9
- ripple_down_rules/datastructures/enums.py +14 -87
- ripple_down_rules/datastructures/field_info.py +177 -0
- ripple_down_rules/datastructures/tracked_object.py +208 -0
- ripple_down_rules/helpers.py +37 -2
- ripple_down_rules/predicates.py +97 -0
- ripple_down_rules/rdr.py +10 -6
- ripple_down_rules/rdr_decorators.py +44 -34
- ripple_down_rules/rules.py +166 -97
- ripple_down_rules/user_interface/ipython_custom_shell.py +9 -1
- ripple_down_rules/user_interface/prompt.py +37 -37
- ripple_down_rules/user_interface/template_file_creator.py +3 -0
- ripple_down_rules/utils.py +32 -5
- {ripple_down_rules-0.6.51.dist-info → ripple_down_rules-0.6.60.dist-info}/METADATA +3 -1
- ripple_down_rules-0.6.60.dist-info/RECORD +28 -0
- ripple_down_rules-0.6.51.dist-info/RECORD +0 -25
- {ripple_down_rules-0.6.51.dist-info → ripple_down_rules-0.6.60.dist-info}/WHEEL +0 -0
- {ripple_down_rules-0.6.51.dist-info → ripple_down_rules-0.6.60.dist-info}/licenses/LICENSE +0 -0
- {ripple_down_rules-0.6.51.dist-info → ripple_down_rules-0.6.60.dist-info}/top_level.txt +0 -0
ripple_down_rules/rules.py
CHANGED
@@ -1,71 +1,143 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import logging
|
4
3
|
import re
|
5
4
|
from abc import ABC, abstractmethod
|
5
|
+
from dataclasses import dataclass, field
|
6
6
|
from types import NoneType
|
7
7
|
from uuid import uuid4
|
8
8
|
|
9
|
-
from anytree import
|
10
|
-
from sqlalchemy.orm import DeclarativeBase as SQLTable
|
9
|
+
from anytree import Node
|
11
10
|
from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple, Type, Set
|
12
11
|
|
13
12
|
from .datastructures.callable_expression import CallableExpression
|
14
13
|
from .datastructures.case import Case
|
15
14
|
from .datastructures.dataclasses import CaseFactoryMetaData, CaseQuery
|
16
15
|
from .datastructures.enums import RDREdge, Stop
|
17
|
-
from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_type_from_string
|
18
16
|
from .helpers import get_an_updated_case_copy
|
17
|
+
from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_type_from_string
|
19
18
|
|
20
19
|
|
21
|
-
|
22
|
-
|
20
|
+
@dataclass
|
21
|
+
class Rule(SubclassJSONSerializer, ABC):
|
22
|
+
conditions: Optional[CallableExpression] = field(default=None)
|
23
|
+
"""
|
24
|
+
The conditions of the rule, which is a callable expression that takes a case and returns a boolean.
|
25
|
+
"""
|
26
|
+
conclusion: Optional[CallableExpression] = field(default=None)
|
27
|
+
"""
|
28
|
+
The conclusion of the rule, which is a callable expression that takes a case and returns a value.
|
29
|
+
"""
|
30
|
+
_parent: Optional[Rule] = field(default=None)
|
31
|
+
"""
|
32
|
+
The parent rule of this rule in the ripple down rules tree.
|
33
|
+
"""
|
34
|
+
corner_case: Optional[Any] = field(default=None)
|
35
|
+
"""
|
36
|
+
The corner case for which this rule was created.
|
37
|
+
"""
|
38
|
+
_weight: Optional[RDREdge] = field(default=None)
|
39
|
+
"""
|
40
|
+
The weight of the rule, which is the type of edge connecting the rule to its parent.
|
41
|
+
"""
|
42
|
+
conclusion_name: Optional[str] = field(default=None)
|
43
|
+
"""
|
44
|
+
The name of the conclusion of the rule, which is used to identify the conclusion
|
45
|
+
"""
|
46
|
+
uid: str = field(default_factory=lambda: str(uuid4().int))
|
47
|
+
"""
|
48
|
+
A unique id for the rule using uuid4
|
49
|
+
"""
|
50
|
+
corner_case_metadata: Optional[CaseFactoryMetaData] = field(default=None)
|
51
|
+
"""
|
52
|
+
Metadata about the corner case, such as the factory that created it or the scenario it is based on.
|
53
|
+
"""
|
54
|
+
json_serialization: Optional[Dict[str, Any]] = field(init=False, default=None)
|
55
|
+
"""
|
56
|
+
The JSON serialization of the rule, which is used to serialize the rule to JSON.
|
57
|
+
"""
|
58
|
+
_name: Optional[str] = field(init=False, default=None)
|
59
|
+
"""
|
60
|
+
The name of the rule, which is the names of the conditions and the conclusion
|
61
|
+
"""
|
62
|
+
evaluated: bool = field(init=False, default=False)
|
63
|
+
"""
|
64
|
+
Whether the rule has been evaluated or not (i.e. whether the rule has been reached during evaluation of the ripple
|
65
|
+
down rules tree and the conditions have been checked).
|
66
|
+
"""
|
67
|
+
last_conclusion: Optional[Any] = field(init=False, default=None)
|
68
|
+
"""
|
69
|
+
The last conclusion of the rule, which is the conclusion of the rule when it was last evaluated.
|
70
|
+
"""
|
71
|
+
contributed: bool = field(init=False, default=False)
|
72
|
+
"""
|
73
|
+
Whether the rule has contributed by a value, meaning that it has fired and the conclusion has been added to the case.
|
74
|
+
"""
|
75
|
+
contributed_to_case_query: bool = field(init=False, default=False)
|
76
|
+
"""
|
77
|
+
Whether the rule has contributed to the case query, meaning that it has fired and the conclusion is relevant to the
|
78
|
+
case query.
|
79
|
+
"""
|
80
|
+
fired: Optional[bool] = field(init=False, default=None)
|
23
81
|
"""
|
24
82
|
Whether the rule has fired or not.
|
25
83
|
"""
|
26
|
-
mutually_exclusive: bool
|
84
|
+
mutually_exclusive: bool = field(init=False, default=True)
|
27
85
|
"""
|
28
86
|
Whether the rule is mutually exclusive with other rules.
|
29
87
|
"""
|
88
|
+
node: Optional[Node] = field(init=False, default=None)
|
89
|
+
|
90
|
+
def __post_init__(self):
    """
    Create the tree node that mirrors this rule once the dataclass fields have been assigned.

    The node is named after the rule and attached under the parent rule's node when a parent exists. It carries the
    value of the rule's weight (None when no weight is set) and a back-reference to this rule.
    """
    rule_name = self.name
    parent_node = self.parent.node if self.parent else None
    self.node = Node(rule_name, parent=parent_node)
    edge_value = self.weight.value if self.weight else None
    self.node.weight = edge_value
    # Back-reference read by `descendants` and `children` to map nodes back to rules.
    self.node._rdr_rule = self
|
94
|
+
|
95
|
+
@property
|
96
|
+
def descendants(self) -> List[Rule]:
    """
    :return: the rules attached to all descendant nodes of this rule's node, in the order the node reports them.
    """
    rules = []
    for descendant_node in self.node.descendants:
        rules.append(descendant_node._rdr_rule)
    return rules
|
30
102
|
|
31
|
-
|
32
|
-
|
33
|
-
parent: Optional[Rule] = None,
|
34
|
-
corner_case: Optional[Union[Case, SQLTable]] = None,
|
35
|
-
weight: Optional[str] = None,
|
36
|
-
conclusion_name: Optional[str] = None,
|
37
|
-
uid: Optional[str] = None,
|
38
|
-
corner_case_metadata: Optional[CaseFactoryMetaData] = None):
|
103
|
+
@property
|
104
|
+
def children(self) -> List[Rule]:
    """
    :return: the rules attached to the direct child nodes of this rule's node, in the order the node reports them.
    """
    rules = []
    for child_node in self.node.children:
        rules.append(child_node._rdr_rule)
    return rules
|
41
110
|
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
:
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
"""
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
self.
|
56
|
-
self.
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
self.
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
111
|
+
@property
def parent(self) -> Optional[Rule]:
    """
    :return: The parent rule of this rule, or None when no parent is set.
    """
    return self._parent

@parent.setter
def parent(self, new_parent: Optional[Rule]):
    """
    Set the parent rule of this rule and keep the tree node in sync with it.

    :param new_parent: The new parent rule to set, or None to leave this rule without a parent.
    """
    self._parent = new_parent
    if self.node:
        # `new_parent` is Optional: clear the node's parent instead of failing on `None.node`.
        self.node.parent = new_parent.node if new_parent is not None else None
|
127
|
+
|
128
|
+
@property
def weight(self) -> Optional[RDREdge]:
    """
    :return: The weight of the rule, which is the type of edge connecting the rule to its parent, or None when no
     weight is set.
    """
    return self._weight

@weight.setter
def weight(self, new_weight: Optional[RDREdge]):
    """
    Set the weight of the rule, which is the type of edge connecting the rule to its parent.

    :param new_weight: The new weight to set, or None to clear it.
    """
    self._weight = new_weight
    if self.node:
        # An unset weight is stored on the node as None, the same convention used when the node is created.
        self.node.weight = new_weight.value if new_weight is not None else None
|
69
141
|
|
70
142
|
def get_an_updated_case_copy(self, case: Case) -> Case:
|
71
143
|
"""
|
@@ -96,23 +168,6 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
96
168
|
else:
|
97
169
|
return "white"
|
98
170
|
|
99
|
-
@property
|
100
|
-
def user_defined_name(self) -> Optional[str]:
|
101
|
-
"""
|
102
|
-
Get the user defined name of the rule, if it exists.
|
103
|
-
"""
|
104
|
-
if self._user_defined_name is None:
|
105
|
-
if self.conditions and self.conditions.user_input and "def " in self.conditions.user_input:
|
106
|
-
# If the conditions have a user input, use it as the name
|
107
|
-
func_name = self.conditions.user_input.split('(')[0].replace('def ', '').strip()
|
108
|
-
if func_name == self.conditions.encapsulating_function_name:
|
109
|
-
self._user_defined_name = str(self.conditions)
|
110
|
-
else:
|
111
|
-
self._user_defined_name = func_name
|
112
|
-
else:
|
113
|
-
self._user_defined_name = f"Rule_{self.uid}"
|
114
|
-
return self._user_defined_name
|
115
|
-
|
116
171
|
@classmethod
|
117
172
|
def from_case_query(cls, case_query: CaseQuery, parent: Optional[Rule] = None) -> Rule:
|
118
173
|
"""
|
@@ -124,7 +179,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
124
179
|
"""
|
125
180
|
corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
126
181
|
return cls(conditions=case_query.conditions, conclusion=case_query.target,
|
127
|
-
corner_case=case_query.case,
|
182
|
+
corner_case=case_query.case, _parent=parent,
|
128
183
|
corner_case_metadata=corner_case_metadata,
|
129
184
|
conclusion_name=case_query.attribute_name)
|
130
185
|
|
@@ -291,7 +346,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
291
346
|
"conclusion": conclusion_to_json(self.conclusion),
|
292
347
|
"parent": self.parent.json_serialization if self.parent else None,
|
293
348
|
"conclusion_name": self.conclusion_name,
|
294
|
-
"weight": self.weight,
|
349
|
+
"weight": self.weight.value,
|
295
350
|
"uid": self.uid}
|
296
351
|
return json_serialization
|
297
352
|
|
@@ -299,9 +354,9 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
299
354
|
def _from_json(cls, data: Dict[str, Any]) -> Rule:
|
300
355
|
loaded_rule = cls(conditions=CallableExpression.from_json(data["conditions"]),
|
301
356
|
conclusion=CallableExpression.from_json(data["conclusion"]),
|
302
|
-
|
357
|
+
_parent=cls.from_json(data["parent"]),
|
303
358
|
conclusion_name=data["conclusion_name"],
|
304
|
-
|
359
|
+
_weight=RDREdge.from_value(data["weight"]),
|
305
360
|
uid=data["uid"])
|
306
361
|
return loaded_rule
|
307
362
|
|
@@ -318,20 +373,25 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
318
373
|
Set the name of the rule.
|
319
374
|
"""
|
320
375
|
self._name = new_name
|
376
|
+
self.node.name = new_name
|
321
377
|
|
322
378
|
@property
|
323
379
|
def semantic_condition_name(self) -> Optional[str]:
    """
    :return: the name that `expression_name` derives from the rule's conditions, or None when the conditions are not
     a CallableExpression.
    """
    if not isinstance(self.conditions, CallableExpression):
        return None
    return self.expression_name(self.conditions)
|
328
386
|
|
329
387
|
@property
|
330
388
|
def semantic_conclusion_name(self) -> Optional[str]:
    """
    :return: the name that `expression_name` derives from the rule's conclusion, or None when the conclusion is not
     a CallableExpression.
    """
    if not isinstance(self.conclusion, CallableExpression):
        return None
    return self.expression_name(self.conclusion)
|
335
395
|
|
336
396
|
@staticmethod
|
337
397
|
def expression_name(expression: CallableExpression) -> str:
|
@@ -341,7 +401,8 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
341
401
|
"""
|
342
402
|
if expression.user_defined_name is not None and expression.user_defined_name != expression.encapsulating_function_name:
|
343
403
|
return expression.user_defined_name.strip()
|
344
|
-
func_name = expression.user_input.split('(')[0].replace('def ',
|
404
|
+
func_name = expression.user_input.split('(')[0].replace('def ',
|
405
|
+
'').strip() if "def " in expression.user_input else None
|
345
406
|
if func_name is not None and func_name != expression.encapsulating_function_name:
|
346
407
|
return func_name
|
347
408
|
elif expression.user_input:
|
@@ -358,20 +419,29 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
|
|
358
419
|
def __repr__(self):
|
359
420
|
return self.__str__()
|
360
421
|
|
422
|
+
def __eq__(self, other):
    """
    Compare rules by identity of their uid.

    :param other: The object to compare with.
    :return: False when other is not a Rule, otherwise whether both rules share the same uid.
    """
    return isinstance(other, Rule) and other.uid == self.uid
|
426
|
+
|
427
|
+
def __hash__(self):
    """
    :return: the hash of the rule's uid, so that rules comparing equal by uid also hash equally.
    """
    identifier = self.uid
    return hash(identifier)
|
361
429
|
|
430
|
+
|
431
|
+
@dataclass
|
362
432
|
class HasAlternativeRule:
|
363
433
|
"""
|
364
434
|
A mixin class for rules that have an alternative rule.
|
365
435
|
"""
|
366
|
-
_alternative: Optional[Rule] = None
|
436
|
+
_alternative: Optional[Rule] = field(init=False, default=None)
|
367
437
|
"""
|
368
438
|
The alternative rule of the rule, which is evaluated when the rule doesn't fire.
|
369
439
|
"""
|
370
|
-
furthest_alternative: Optional[List[Rule]] = None
|
440
|
+
furthest_alternative: Optional[List[Rule]] = field(init=False, default=None)
|
371
441
|
"""
|
372
442
|
The furthest alternative rule of the rule, which is the last alternative rule in the chain of alternative rules.
|
373
443
|
"""
|
374
|
-
all_alternatives: Optional[List[Rule]] = None
|
444
|
+
all_alternatives: Optional[List[Rule]] = field(init=False, default=None)
|
375
445
|
"""
|
376
446
|
All alternative rules of the rule, which is all the alternative rules in the chain of alternative rules.
|
377
447
|
"""
|
@@ -380,9 +450,6 @@ class HasAlternativeRule:
|
|
380
450
|
def alternative(self) -> Optional[Rule]:
|
381
451
|
return self._alternative
|
382
452
|
|
383
|
-
def set_immediate_alternative(self, alternative: Optional[Rule]):
|
384
|
-
self._alternative = alternative
|
385
|
-
|
386
453
|
@alternative.setter
|
387
454
|
def alternative(self, new_rule: Rule):
|
388
455
|
"""
|
@@ -395,13 +462,14 @@ class HasAlternativeRule:
|
|
395
462
|
self.furthest_alternative[-1].alternative = new_rule
|
396
463
|
else:
|
397
464
|
new_rule.parent = self
|
398
|
-
new_rule.weight = RDREdge.Alternative
|
465
|
+
new_rule.weight = RDREdge.Alternative if not new_rule.weight else new_rule.weight
|
399
466
|
self._alternative = new_rule
|
400
467
|
self.furthest_alternative = [new_rule]
|
401
468
|
|
402
469
|
|
470
|
+
@dataclass
|
403
471
|
class HasRefinementRule:
|
404
|
-
_refinement: Optional[HasAlternativeRule] = None
|
472
|
+
_refinement: Optional[HasAlternativeRule] = field(init=False, default=None)
|
405
473
|
"""
|
406
474
|
The refinement rule of the rule, which is evaluated when the rule fires.
|
407
475
|
"""
|
@@ -422,16 +490,18 @@ class HasRefinementRule:
|
|
422
490
|
self.refinement.alternative = new_rule
|
423
491
|
else:
|
424
492
|
new_rule.parent = self
|
425
|
-
new_rule.weight = RDREdge.Refinement
|
493
|
+
new_rule.weight = RDREdge.Refinement if not isinstance(new_rule,
|
494
|
+
MultiClassFilterRule) else new_rule.weight
|
426
495
|
self._refinement = new_rule
|
427
496
|
|
428
497
|
|
498
|
+
@dataclass(eq=False)
|
429
499
|
class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
430
500
|
"""
|
431
501
|
A rule in the SingleClassRDR classifier, it can have a refinement or an alternative rule or both.
|
432
502
|
"""
|
433
503
|
|
434
|
-
mutually_exclusive: bool = True
|
504
|
+
mutually_exclusive: bool = field(init=False, default=True)
|
435
505
|
|
436
506
|
def evaluate_next_rule(self, x: Case) -> SingleClassRule:
|
437
507
|
if self.fired:
|
@@ -443,7 +513,7 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
|
443
513
|
def fit_rule(self, case_query: CaseQuery):
|
444
514
|
corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
|
445
515
|
new_rule = SingleClassRule(case_query.conditions, case_query.target,
|
446
|
-
corner_case=case_query.case,
|
516
|
+
corner_case=case_query.case, _parent=self,
|
447
517
|
corner_case_metadata=corner_case_metadata,
|
448
518
|
)
|
449
519
|
if self.fired:
|
@@ -465,18 +535,19 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
|
|
465
535
|
return loaded_rule
|
466
536
|
|
467
537
|
def _if_statement_source_code_clause(self) -> str:
    """
    :return: "elif" when this rule is connected to its parent by an Alternative edge, otherwise "if".
    """
    if self.weight == RDREdge.Alternative:
        return "elif"
    return "if"
|
469
539
|
|
470
540
|
|
541
|
+
@dataclass(eq=False)
|
471
542
|
class MultiClassRefinementRule(Rule, HasAlternativeRule, ABC):
|
472
543
|
"""
|
473
544
|
A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule.
|
474
545
|
"""
|
475
|
-
top_rule: Optional[MultiClassTopRule] = None
|
546
|
+
top_rule: Optional[MultiClassTopRule] = field(init=False, default=None)
|
476
547
|
"""
|
477
548
|
The top rule of the rule, which is the nearest ancestor that fired and this rule is a refinement of.
|
478
549
|
"""
|
479
|
-
mutually_exclusive: bool = False
|
550
|
+
mutually_exclusive: bool = field(init=False, default=False)
|
480
551
|
|
481
552
|
def _to_json(self) -> Dict[str, Any]:
|
482
553
|
self.json_serialization = {**Rule._to_json(self),
|
@@ -495,18 +566,20 @@ class MultiClassRefinementRule(Rule, HasAlternativeRule, ABC):
|
|
495
566
|
return loaded_rule
|
496
567
|
|
497
568
|
def _if_statement_source_code_clause(self) -> str:
    """
    :return: the conditional keyword for this rule: "elif" if its weight is the Alternative edge, "if" otherwise.
    """
    is_alternative = self.weight == RDREdge.Alternative
    if is_alternative:
        return "elif"
    return "if"
|
499
570
|
|
500
571
|
|
572
|
+
@dataclass(eq=False)
|
501
573
|
class MultiClassStopRule(MultiClassRefinementRule):
|
502
574
|
"""
|
503
575
|
A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
|
504
576
|
the conclusion of the rule is a Stop category meant to stop the parent conclusion from being made.
|
505
577
|
"""
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
578
|
+
conclusion: CallableExpression = field(default_factory=lambda: CallableExpression(conclusion_type=(Stop,),
|
579
|
+
conclusion=Stop.stop))
|
580
|
+
"""
|
581
|
+
The conclusion of the rule, which is a CallableExpression that returns a Stop category.
|
582
|
+
"""
|
510
583
|
|
511
584
|
def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassRefinementRule, MultiClassTopRule]]:
|
512
585
|
if self.fired:
|
@@ -521,15 +594,13 @@ class MultiClassStopRule(MultiClassRefinementRule):
|
|
521
594
|
return None, f"{parent_indent}{' ' * 4}pass\n"
|
522
595
|
|
523
596
|
|
597
|
+
@dataclass(eq=False)
|
524
598
|
class MultiClassFilterRule(MultiClassRefinementRule, HasRefinementRule):
|
525
599
|
"""
|
526
600
|
A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
|
527
601
|
the conclusion of the rule is a Filter category meant to filter the parent conclusion.
|
528
602
|
"""
|
529
|
-
|
530
|
-
def __init__(self, *args, **kwargs):
|
531
|
-
super(MultiClassRefinementRule, self).__init__(*args, **kwargs)
|
532
|
-
self.weight = RDREdge.Filter.value
|
603
|
+
weight: RDREdge = field(init=False, default_factory=lambda: RDREdge.Filter)
|
533
604
|
|
534
605
|
def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassRefinementRule, MultiClassTopRule]]:
|
535
606
|
if self.fired:
|
@@ -579,15 +650,13 @@ class MultiClassFilterRule(MultiClassRefinementRule, HasRefinementRule):
|
|
579
650
|
return loaded_rule
|
580
651
|
|
581
652
|
|
653
|
+
@dataclass(eq=False)
|
582
654
|
class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
583
655
|
"""
|
584
656
|
A rule in the MultiClassRDR classifier, it can have a refinement and a next rule.
|
585
657
|
"""
|
586
|
-
mutually_exclusive: bool = False
|
587
|
-
|
588
|
-
def __init__(self, *args, **kwargs):
|
589
|
-
super(MultiClassTopRule, self).__init__(*args, **kwargs)
|
590
|
-
self.weight = RDREdge.Next.value
|
658
|
+
mutually_exclusive: bool = field(init=False, default=False)
|
659
|
+
weight: RDREdge = field(init=False, default_factory=lambda: RDREdge.Next)
|
591
660
|
|
592
661
|
def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassStopRule, MultiClassTopRule]]:
|
593
662
|
if self.fired and self.refinement:
|
@@ -603,7 +672,7 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
603
672
|
if self.fired and case_query.target != self.conclusion:
|
604
673
|
if refinement_type in [None, MultiClassStopRule]:
|
605
674
|
new_rule = MultiClassStopRule(case_query.conditions, corner_case=case_query.case,
|
606
|
-
|
675
|
+
_parent=self)
|
607
676
|
elif refinement_type is MultiClassFilterRule:
|
608
677
|
new_rule = MultiClassFilterRule.from_case_query(case_query, parent=self)
|
609
678
|
else:
|
@@ -614,7 +683,7 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
|
|
614
683
|
self.alternative = MultiClassTopRule.from_case_query(case_query, parent=self)
|
615
684
|
|
616
685
|
def _to_json(self) -> Dict[str, Any]:
    """
    Serialize the rule by extending the parent class serialization with its refinement and alternative rules.

    :return: the serialization dictionary, which is also stored in `json_serialization`.
    """
    serialized = dict(super()._to_json())
    serialized["refinement"] = self.refinement.to_json() if self.refinement is not None else None
    serialized["alternative"] = self.alternative.to_json() if self.alternative is not None else None
    # Assigned only after the nested rules are serialized, so the stored value changes once the result is complete.
    self.json_serialization = serialized
    return self.json_serialization
|
@@ -24,10 +24,16 @@ class MyMagics(Magics):
|
|
24
24
|
self.case_query: Optional[CaseQuery] = case_query
|
25
25
|
self.rule_editor = TemplateFileCreator(case_query, prompt_for=prompt_for, code_to_modify=code_to_modify)
|
26
26
|
self.all_code_lines: Optional[List[str]] = None
|
27
|
+
self.edited: bool = False
|
28
|
+
self.loaded: bool = False
|
27
29
|
|
28
30
|
@line_magic
|
29
31
|
def edit(self, line):
    """
    Handle the edit magic.

    The first invocation delegates to the rule editor's `edit` and records that editing has started; later
    invocations only ask the rule editor to open the file in the editor.

    :param line: The argument line passed to the magic (not used).
    """
    if not self.edited:
        self.rule_editor.edit()
        self.edited = True
        return
    self.rule_editor.open_file_in_editor()
|
31
37
|
|
32
38
|
@line_magic
|
33
39
|
def load(self, line):
|
@@ -35,6 +41,8 @@ class MyMagics(Magics):
|
|
35
41
|
self.rule_editor.func_name,
|
36
42
|
self.rule_editor.print_func)
|
37
43
|
self.shell.user_ns.update(updates)
|
44
|
+
self.case_query.scope.update(updates)
|
45
|
+
self.loaded = True
|
38
46
|
|
39
47
|
@line_magic
|
40
48
|
def current_value(self, line):
|
@@ -23,14 +23,14 @@ from ..datastructures.dataclasses import CaseQuery
|
|
23
23
|
from ..datastructures.enums import PromptFor, ExitStatus
|
24
24
|
from .ipython_custom_shell import IPythonShell
|
25
25
|
from ..utils import make_list
|
26
|
-
from threading import
|
26
|
+
from threading import Lock
|
27
27
|
|
28
28
|
|
29
29
|
class UserPrompt:
|
30
30
|
"""
|
31
31
|
A class to handle user prompts for the RDR.
|
32
32
|
"""
|
33
|
-
shell_lock:
|
33
|
+
shell_lock: Lock = Lock() # To ensure that only one thread can access the shell at a time
|
34
34
|
|
35
35
|
def __init__(self, prompt_user: bool = True):
|
36
36
|
"""
|
@@ -49,43 +49,43 @@ class UserPrompt:
|
|
49
49
|
:param prompt_str: The prompt string to display to the user.
|
50
50
|
:return: A callable expression that takes a case and executes user expression on it.
|
51
51
|
"""
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
52
|
+
with self.shell_lock:
|
53
|
+
prev_user_input: Optional[str] = None
|
54
|
+
user_input_to_modify: Optional[str] = None
|
55
|
+
callable_expression: Optional[CallableExpression] = None
|
56
|
+
while True:
|
57
57
|
user_input, expression_tree = self.prompt_user_about_case(case_query, prompt_for, prompt_str,
|
58
58
|
code_to_modify=prev_user_input)
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
59
|
+
if user_input is None:
|
60
|
+
if prompt_for == PromptFor.Conclusion:
|
61
|
+
self.print_func(f"\n{Fore.YELLOW}No conclusion provided. Exiting.{Style.RESET_ALL}")
|
62
|
+
return None, None
|
63
|
+
else:
|
64
|
+
self.print_func(f"\n{Fore.RED}Conditions must be provided. Please try again.{Style.RESET_ALL}")
|
65
|
+
continue
|
66
|
+
elif user_input in ["exit", 'quit']:
|
67
|
+
self.print_func(f"\n{Fore.YELLOW}Exiting.{Style.RESET_ALL}")
|
68
|
+
return user_input, None
|
69
|
+
|
70
|
+
prev_user_input = '\n'.join(user_input.split('\n')[2:-1])
|
71
|
+
conclusion_type = bool if prompt_for == PromptFor.Conditions else case_query.attribute_type
|
72
|
+
callable_expression = CallableExpression(user_input, conclusion_type, expression_tree=expression_tree,
|
73
|
+
scope=case_query.scope,
|
74
|
+
mutually_exclusive=case_query.mutually_exclusive)
|
75
|
+
try:
|
76
|
+
result = callable_expression(case_query.case)
|
77
|
+
if len(make_list(result)) == 0 and (user_input_to_modify is not None
|
78
|
+
and (prev_user_input != user_input_to_modify)):
|
79
|
+
user_input_to_modify = prev_user_input
|
80
|
+
self.print_func(
|
81
|
+
f"{Fore.YELLOW}The given expression gave an empty result for case {case_query.name}."
|
82
|
+
f" Please accept or modify!{Style.RESET_ALL}")
|
83
|
+
continue
|
84
|
+
break
|
85
|
+
except Exception as e:
|
86
|
+
logging.error(e)
|
87
|
+
self.print_func(f"{Fore.RED}{e}{Style.RESET_ALL}")
|
88
|
+
return user_input, callable_expression
|
89
89
|
|
90
90
|
def prompt_user_about_case(self, case_query: CaseQuery, prompt_for: PromptFor,
|
91
91
|
prompt_str: Optional[str] = None,
|
@@ -177,6 +177,8 @@ class TemplateFileCreator:
|
|
177
177
|
if self.case_query.is_function:
|
178
178
|
func_args = {}
|
179
179
|
for k, v in self.case_query.case.items():
|
180
|
+
if k == self.case_query.attribute_name:
|
181
|
+
continue
|
180
182
|
if (self.case_query.function_args_type_hints is not None
|
181
183
|
and k in self.case_query.function_args_type_hints):
|
182
184
|
func_args[k] = stringify_hint(self.case_query.function_args_type_hints[k])
|
@@ -184,6 +186,7 @@ class TemplateFileCreator:
|
|
184
186
|
func_args[k] = type(v).__name__ if not isinstance(v, type) else f"Type[{v.__name__}]"
|
185
187
|
func_args = ', '.join([f"{k}: {v}" if str(v) not in ["NoneType", "None"] else str(k)
|
186
188
|
for k, v in func_args.items()])
|
189
|
+
func_args += ", **kwargs"
|
187
190
|
else:
|
188
191
|
func_args = f"case: {self.case_query.case_type.__name__}"
|
189
192
|
return func_args
|