ripple-down-rules 0.6.1__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,74 +1,185 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
3
  import re
5
4
  from abc import ABC, abstractmethod
6
- from pathlib import Path
5
+ from dataclasses import dataclass, field
7
6
  from types import NoneType
8
7
  from uuid import uuid4
9
8
 
10
- from anytree import NodeMixin
11
- from sqlalchemy.orm import DeclarativeBase as SQLTable
9
+ from anytree import Node
12
10
  from typing_extensions import List, Optional, Self, Union, Dict, Any, Tuple, Type, Set
13
11
 
14
12
  from .datastructures.callable_expression import CallableExpression
15
13
  from .datastructures.case import Case
16
14
  from .datastructures.dataclasses import CaseFactoryMetaData, CaseQuery
17
15
  from .datastructures.enums import RDREdge, Stop
18
- from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_imports_from_types
16
+ from .helpers import get_an_updated_case_copy
17
+ from .utils import SubclassJSONSerializer, conclusion_to_json, get_full_class_name, get_type_from_string
19
18
 
20
19
 
21
- class Rule(NodeMixin, SubclassJSONSerializer, ABC):
22
- fired: Optional[bool] = None
20
+ @dataclass
21
+ class Rule(SubclassJSONSerializer, ABC):
22
+ conditions: Optional[CallableExpression] = field(default=None)
23
+ """
24
+ The conditions of the rule, which is a callable expression that takes a case and returns a boolean.
25
+ """
26
+ conclusion: Optional[CallableExpression] = field(default=None)
27
+ """
28
+ The conclusion of the rule, which is a callable expression that takes a case and returns a value.
29
+ """
30
+ _parent: Optional[Rule] = field(default=None)
31
+ """
32
+ The parent rule of this rule in the ripple down rules tree.
33
+ """
34
+ corner_case: Optional[Any] = field(default=None)
35
+ """
36
+ The corner case for which this rule was created.
37
+ """
38
+ _weight: Optional[RDREdge] = field(default=None)
39
+ """
40
+ The weight of the rule, which is the type of edge connecting the rule to its parent.
41
+ """
42
+ conclusion_name: Optional[str] = field(default=None)
43
+ """
44
+ The name of the conclusion of the rule, which is used to identify the conclusion
45
+ """
46
+ uid: str = field(default_factory=lambda: str(uuid4().int))
47
+ """
48
+ A unique id for the rule using uuid4
49
+ """
50
+ corner_case_metadata: Optional[CaseFactoryMetaData] = field(default=None)
51
+ """
52
+ Metadata about the corner case, such as the factory that created it or the scenario it is based on.
53
+ """
54
+ json_serialization: Optional[Dict[str, Any]] = field(init=False, default=None)
55
+ """
56
+ The JSON serialization of the rule, which is used to serialize the rule to JSON.
57
+ """
58
+ _name: Optional[str] = field(init=False, default=None)
59
+ """
60
+ The name of the rule, which is the names of the conditions and the conclusion
61
+ """
62
+ evaluated: bool = field(init=False, default=False)
63
+ """
64
+ Whether the rule has been evaluated or not (i.e. whether the rule has been reached during evaluation of the ripple
65
+ down rules tree and the conditions have been checked).
66
+ """
67
+ last_conclusion: Optional[Any] = field(init=False, default=None)
68
+ """
69
+ The last conclusion of the rule, which is the conclusion of the rule when it was last evaluated.
70
+ """
71
+ contributed: bool = field(init=False, default=False)
72
+ """
73
+ Whether the rule has contributed by a value, meaning that it has fired and the conclusion has been added to the case.
74
+ """
75
+ contributed_to_case_query: bool = field(init=False, default=False)
76
+ """
77
+ Whether the rule has contributed to the case query, meaning that it has fired and the conclusion is relevant to the
78
+ case query.
79
+ """
80
+ fired: Optional[bool] = field(init=False, default=None)
23
81
  """
24
82
  Whether the rule has fired or not.
25
83
  """
84
+ mutually_exclusive: bool = field(init=False, default=True)
85
+ """
86
+ Whether the rule is mutually exclusive with other rules.
87
+ """
88
+ node: Optional[Node] = field(init=False, default=None)
26
89
 
27
- def __init__(self, conditions: Optional[CallableExpression] = None,
28
- conclusion: Optional[CallableExpression] = None,
29
- parent: Optional[Rule] = None,
30
- corner_case: Optional[Union[Case, SQLTable]] = None,
31
- weight: Optional[str] = None,
32
- conclusion_name: Optional[str] = None,
33
- uid: Optional[str] = None,
34
- corner_case_metadata: Optional[CaseFactoryMetaData] = None):
90
+ def __post_init__(self):
91
+ self.node = Node(self.name, parent=self.parent.node if self.parent else None)
92
+ self.node.weight = self.weight.value if self.weight else None
93
+ self.node._rdr_rule = self
94
+
95
+ @property
96
+ def descendants(self) -> List[Rule]:
35
97
  """
36
- A rule in the ripple down rules classifier.
98
+ :return: the descendants of this rule, which are the rules that are children of this rule in the ripple down
99
+ rules tree.
100
+ """
101
+ return [child._rdr_rule for child in self.node.descendants]
37
102
 
38
- :param conditions: The conditions of the rule.
39
- :param conclusion: The conclusion of the rule when the conditions are met.
40
- :param parent: The parent rule of this rule.
41
- :param corner_case: The corner case that this rule is based on/created from.
42
- :param weight: The weight of the rule, which is the type of edge connecting the rule to its parent.
43
- :param conclusion_name: The name of the conclusion of the rule.
44
- :param uid: The unique id of the rule.
45
- :param corner_case_metadata: Metadata about the corner case, such as the factory that created it or the
46
- scenario it is based on.
47
- """
48
- super(Rule, self).__init__()
49
- self.conclusion = conclusion
50
- self.corner_case = corner_case
51
- self.corner_case_metadata: Optional[CaseFactoryMetaData] = corner_case_metadata
52
- self.parent = parent
53
- self.weight: Optional[str] = weight
54
- self.conditions = conditions if conditions else None
55
- self.conclusion_name: Optional[str] = conclusion_name
56
- self.json_serialization: Optional[Dict[str, Any]] = None
57
- self._name: Optional[str] = None
58
- # generate a unique id for the rule using uuid4
59
- self.uid: str = uid if uid else str(uuid4().int)
103
+ @property
104
+ def children(self) -> List[Rule]:
105
+ """
106
+ :return: the children of this rule, which are the rules that are direct children of this rule in the ripple down
107
+ rules tree.
108
+ """
109
+ return [child._rdr_rule for child in self.node.children]
110
+
111
+ @property
112
+ def parent(self):
113
+ """
114
+ :return: The parent rule of this rule.
115
+ """
116
+ return self._parent
117
+
118
+ @parent.setter
119
+ def parent(self, new_parent: Optional[Rule]):
120
+ """
121
+ Set the parent rule of this rule.
122
+ :param new_parent: The new parent rule to set.
123
+ """
124
+ self._parent = new_parent
125
+ if self.node:
126
+ self.node.parent = new_parent.node
127
+
128
+ @property
129
+ def weight(self) -> RDREdge:
130
+ return self._weight
131
+
132
+ @weight.setter
133
+ def weight(self, new_weight: RDREdge):
134
+ """
135
+ Set the weight of the rule, which is the type of edge connecting the rule to its parent.
136
+ :param new_weight: The new weight to set.
137
+ """
138
+ self._weight = new_weight
139
+ if self.node:
140
+ self.node.weight = new_weight.value
141
+
142
+ def get_an_updated_case_copy(self, case: Case) -> Case:
143
+ """
144
+ :param case: The case to copy and update.
145
+ :return: A copy of the case updated with this rule conclusion.
146
+ """
147
+ return get_an_updated_case_copy(case, self.conclusion, self.conclusion_name, self.conclusion.conclusion_type,
148
+ self.mutually_exclusive)
149
+
150
+ def reset(self):
151
+ self.evaluated = False
152
+ self.fired = False
153
+ self.contributed = False
154
+ self.contributed_to_case_query = False
155
+ self.last_conclusion = None
156
+
157
+ @property
158
+ def color(self) -> str:
159
+ if self.evaluated:
160
+ if self.contributed_to_case_query:
161
+ return "green"
162
+ elif self.contributed:
163
+ return "yellow"
164
+ elif self.fired:
165
+ return "orange"
166
+ else:
167
+ return "red"
168
+ else:
169
+ return "white"
60
170
 
61
171
  @classmethod
62
- def from_case_query(cls, case_query: CaseQuery) -> Rule:
172
+ def from_case_query(cls, case_query: CaseQuery, parent: Optional[Rule] = None) -> Rule:
63
173
  """
64
174
  Create a SingleClassRule from a CaseQuery.
65
175
 
66
176
  :param case_query: The CaseQuery to create the rule from.
177
+ :param parent: The parent rule of this rule.
67
178
  :return: A SingleClassRule instance.
68
179
  """
69
180
  corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
70
181
  return cls(conditions=case_query.conditions, conclusion=case_query.target,
71
- corner_case=case_query.case, parent=None,
182
+ corner_case=case_query.case, _parent=parent,
72
183
  corner_case_metadata=corner_case_metadata,
73
184
  conclusion_name=case_query.attribute_name)
74
185
 
@@ -91,6 +202,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
91
202
  :param x: The case to evaluate the rule on.
92
203
  :return: The rule that fired or the last evaluated rule if no rule fired.
93
204
  """
205
+ self.evaluated = True
94
206
  if not self.conditions:
95
207
  raise ValueError("Rule has no conditions")
96
208
  self.fired = self.conditions(x)
@@ -149,6 +261,18 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
149
261
  f.write(conclusion_func.strip() + "\n\n\n")
150
262
  return conclusion_func_call
151
263
 
264
+ @property
265
+ def generated_conclusion_function_name(self) -> str:
266
+ return f"conclusion_{self.uid}"
267
+
268
+ @property
269
+ def generated_conditions_function_name(self) -> str:
270
+ return f"conditions_{self.uid}"
271
+
272
+ @property
273
+ def generated_corner_case_object_name(self) -> str:
274
+ return f"corner_case_{self.uid}"
275
+
152
276
  def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[Optional[str], str]:
153
277
  """
154
278
  Convert the conclusion of a rule to source code.
@@ -161,23 +285,24 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
161
285
  # This means the conclusion is a definition that should be written and then called
162
286
  conclusion_lines = conclusion.split('\n')
163
287
  # use regex to replace the function name
164
- new_function_name = f"def conclusion_{self.uid}"
288
+ new_function_name = f"def {self.generated_conclusion_function_name}"
165
289
  conclusion_lines[0] = re.sub(r"def (\w+)", new_function_name, conclusion_lines[0])
166
290
  # add type hint
167
- if len(self.conclusion.conclusion_type) == 1:
168
- hint = self.conclusion.conclusion_type[0].__name__
291
+ if not self.conclusion.mutually_exclusive:
292
+ type_names = [t.__name__ for t in self.conclusion.conclusion_type if t not in [list, set]]
293
+ if len(type_names) == 1:
294
+ hint = f"List[{type_names[0]}]"
295
+ else:
296
+ hint = f"List[Union[{', '.join(type_names)}]]"
169
297
  else:
170
- if (all(t in self.conclusion.conclusion_type for t in [list, set])
171
- and len(self.conclusion.conclusion_type) > 2):
172
- type_names = [t.__name__ for t in self.conclusion.conclusion_type if t not in [list, set]]
173
- hint = f"List[{', '.join(type_names)}]"
298
+ if NoneType in self.conclusion.conclusion_type:
299
+ type_names = [t.__name__ for t in self.conclusion.conclusion_type if t is not NoneType]
300
+ hint = f"Optional[{', '.join(type_names)}]"
301
+ elif len(self.conclusion.conclusion_type) == 1:
302
+ hint = self.conclusion.conclusion_type[0].__name__
174
303
  else:
175
- if NoneType in self.conclusion.conclusion_type:
176
- type_names = [t.__name__ for t in self.conclusion.conclusion_type if t is not NoneType]
177
- hint = f"Optional[{', '.join(type_names)}]"
178
- else:
179
- type_names = [t.__name__ for t in self.conclusion.conclusion_type]
180
- hint = f"Union[{', '.join(type_names)}]"
304
+ type_names = [t.__name__ for t in self.conclusion.conclusion_type]
305
+ hint = f"Union[{', '.join(type_names)}]"
181
306
  conclusion_lines[0] = conclusion_lines[0].replace("):", f") -> {hint}:")
182
307
  func_call = f"{parent_indent} return {new_function_name.replace('def ', '')}(case)\n"
183
308
  return "\n".join(conclusion_lines).strip(' '), func_call
@@ -199,7 +324,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
199
324
  # This means the conditions are a definition that should be written and then called
200
325
  conditions_lines = self.conditions.user_input.split('\n')
201
326
  # use regex to replace the function name
202
- new_function_name = f"def conditions_{self.uid}"
327
+ new_function_name = f"def {self.generated_conditions_function_name}"
203
328
  conditions_lines[0] = re.sub(r"def (\w+)", new_function_name, conditions_lines[0])
204
329
  # add type hint
205
330
  conditions_lines[0] = conditions_lines[0].replace('):', ') -> bool:')
@@ -216,34 +341,22 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
216
341
  pass
217
342
 
218
343
  def _to_json(self) -> Dict[str, Any]:
219
- try:
220
- corner_case = SubclassJSONSerializer.to_json_static(self.corner_case) if self.corner_case else None
221
- except Exception as e:
222
- logging.debug("Failed to serialize corner case to json, setting it to None. Error: %s", e)
223
- corner_case = None
224
344
  json_serialization = {"_type": get_full_class_name(type(self)),
225
345
  "conditions": self.conditions.to_json(),
226
346
  "conclusion": conclusion_to_json(self.conclusion),
227
347
  "parent": self.parent.json_serialization if self.parent else None,
228
- "corner_case": corner_case,
229
348
  "conclusion_name": self.conclusion_name,
230
- "weight": self.weight,
349
+ "weight": self.weight.value,
231
350
  "uid": self.uid}
232
351
  return json_serialization
233
352
 
234
353
  @classmethod
235
354
  def _from_json(cls, data: Dict[str, Any]) -> Rule:
236
- try:
237
- corner_case = Case.from_json(data["corner_case"])
238
- except Exception as e:
239
- logging.debug("Failed to load corner case from json, setting it to None.")
240
- corner_case = None
241
355
  loaded_rule = cls(conditions=CallableExpression.from_json(data["conditions"]),
242
356
  conclusion=CallableExpression.from_json(data["conclusion"]),
243
- parent=cls.from_json(data["parent"]),
244
- corner_case=corner_case,
357
+ _parent=cls.from_json(data["parent"]),
245
358
  conclusion_name=data["conclusion_name"],
246
- weight=data["weight"],
359
+ _weight=RDREdge.from_value(data["weight"]),
247
360
  uid=data["uid"])
248
361
  return loaded_rule
249
362
 
@@ -260,30 +373,75 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
260
373
  Set the name of the rule.
261
374
  """
262
375
  self._name = new_name
376
+ self.node.name = new_name
377
+
378
+ @property
379
+ def semantic_condition_name(self) -> Optional[str]:
380
+ """
381
+ Get the name of the conditions of the rule, which is the user input of the conditions.
382
+ """
383
+ if isinstance(self.conditions, CallableExpression):
384
+ return self.expression_name(self.conditions)
385
+ return None
386
+
387
+ @property
388
+ def semantic_conclusion_name(self) -> Optional[str]:
389
+ """
390
+ Get the name of the conclusion of the rule, which is the user input of the conclusion.
391
+ """
392
+ if isinstance(self.conclusion, CallableExpression):
393
+ return self.expression_name(self.conclusion)
394
+ return None
395
+
396
+ @staticmethod
397
+ def expression_name(expression: CallableExpression) -> str:
398
+ """
399
+ Get the name of the expression, which is the user input of the expression if it exists,
400
+ otherwise it is the conclusion or conditions of the rule.
401
+ """
402
+ if expression.user_defined_name is not None and expression.user_defined_name != expression.encapsulating_function_name:
403
+ return expression.user_defined_name.strip()
404
+ func_name = expression.user_input.split('(')[0].replace('def ',
405
+ '').strip() if "def " in expression.user_input else None
406
+ if func_name is not None and func_name != expression.encapsulating_function_name:
407
+ return func_name
408
+ elif expression.user_input:
409
+ return expression.user_input.strip()
410
+ else:
411
+ return str(expression)
263
412
 
264
413
  def __str__(self, sep="\n"):
265
414
  """
266
415
  Get the string representation of the rule, which is the conditions and the conclusion.
267
416
  """
268
- return f"{self.conditions}{sep}=> {self.conclusion}"
417
+ return f"{self.semantic_condition_name}{sep}=> {self.semantic_conclusion_name}"
269
418
 
270
419
  def __repr__(self):
271
420
  return self.__str__()
272
421
 
422
+ def __eq__(self, other):
423
+ if not isinstance(other, Rule):
424
+ return False
425
+ return other.uid == self.uid
426
+
427
+ def __hash__(self):
428
+ return hash(self.uid)
273
429
 
430
+
431
+ @dataclass
274
432
  class HasAlternativeRule:
275
433
  """
276
434
  A mixin class for rules that have an alternative rule.
277
435
  """
278
- _alternative: Optional[Rule] = None
436
+ _alternative: Optional[Rule] = field(init=False, default=None)
279
437
  """
280
438
  The alternative rule of the rule, which is evaluated when the rule doesn't fire.
281
439
  """
282
- furthest_alternative: Optional[List[Rule]] = None
440
+ furthest_alternative: Optional[List[Rule]] = field(init=False, default=None)
283
441
  """
284
442
  The furthest alternative rule of the rule, which is the last alternative rule in the chain of alternative rules.
285
443
  """
286
- all_alternatives: Optional[List[Rule]] = None
444
+ all_alternatives: Optional[List[Rule]] = field(init=False, default=None)
287
445
  """
288
446
  All alternative rules of the rule, which is all the alternative rules in the chain of alternative rules.
289
447
  """
@@ -304,13 +462,14 @@ class HasAlternativeRule:
304
462
  self.furthest_alternative[-1].alternative = new_rule
305
463
  else:
306
464
  new_rule.parent = self
307
- new_rule.weight = RDREdge.Alternative.value if not new_rule.weight else new_rule.weight
465
+ new_rule.weight = RDREdge.Alternative if not new_rule.weight else new_rule.weight
308
466
  self._alternative = new_rule
309
467
  self.furthest_alternative = [new_rule]
310
468
 
311
469
 
470
+ @dataclass
312
471
  class HasRefinementRule:
313
- _refinement: Optional[HasAlternativeRule] = None
472
+ _refinement: Optional[HasAlternativeRule] = field(init=False, default=None)
314
473
  """
315
474
  The refinement rule of the rule, which is evaluated when the rule fires.
316
475
  """
@@ -327,20 +486,23 @@ class HasRefinementRule:
327
486
  """
328
487
  if new_rule is None:
329
488
  return
330
- new_rule.top_rule = self
331
489
  if self.refinement:
332
490
  self.refinement.alternative = new_rule
333
491
  else:
334
492
  new_rule.parent = self
335
- new_rule.weight = RDREdge.Refinement.value
493
+ new_rule.weight = RDREdge.Refinement if not isinstance(new_rule,
494
+ MultiClassFilterRule) else new_rule.weight
336
495
  self._refinement = new_rule
337
496
 
338
497
 
498
+ @dataclass(eq=False)
339
499
  class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
340
500
  """
341
501
  A rule in the SingleClassRDR classifier, it can have a refinement or an alternative rule or both.
342
502
  """
343
503
 
504
+ mutually_exclusive: bool = field(init=False, default=True)
505
+
344
506
  def evaluate_next_rule(self, x: Case) -> SingleClassRule:
345
507
  if self.fired:
346
508
  returned_rule = self.refinement(x) if self.refinement else self
@@ -351,7 +513,7 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
351
513
  def fit_rule(self, case_query: CaseQuery):
352
514
  corner_case_metadata = CaseFactoryMetaData.from_case_query(case_query)
353
515
  new_rule = SingleClassRule(case_query.conditions, case_query.target,
354
- corner_case=case_query.case, parent=self,
516
+ corner_case=case_query.case, _parent=self,
355
517
  corner_case_metadata=corner_case_metadata,
356
518
  )
357
519
  if self.fired:
@@ -373,31 +535,19 @@ class SingleClassRule(Rule, HasAlternativeRule, HasRefinementRule):
373
535
  return loaded_rule
374
536
 
375
537
  def _if_statement_source_code_clause(self) -> str:
376
- return "elif" if self.weight == RDREdge.Alternative.value else "if"
538
+ return "elif" if self.weight == RDREdge.Alternative else "if"
377
539
 
378
540
 
379
- class MultiClassStopRule(Rule, HasAlternativeRule):
541
+ @dataclass(eq=False)
542
+ class MultiClassRefinementRule(Rule, HasAlternativeRule, ABC):
380
543
  """
381
- A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
382
- the conclusion of the rule is a Stop category meant to stop the parent conclusion from being made.
544
+ A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule.
383
545
  """
384
- top_rule: Optional[MultiClassTopRule] = None
546
+ top_rule: Optional[MultiClassTopRule] = field(init=False, default=None)
385
547
  """
386
548
  The top rule of the rule, which is the nearest ancestor that fired and this rule is a refinement of.
387
549
  """
388
-
389
- def __init__(self, *args, **kwargs):
390
- super(MultiClassStopRule, self).__init__(*args, **kwargs)
391
- self.conclusion = CallableExpression(conclusion_type=(Stop,), conclusion=Stop.stop)
392
-
393
- def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassStopRule, MultiClassTopRule]]:
394
- if self.fired:
395
- self.top_rule.fired = False
396
- return self.top_rule.alternative
397
- elif self.alternative:
398
- return self.alternative(x)
399
- else:
400
- return self.top_rule.alternative
550
+ mutually_exclusive: bool = field(init=False, default=False)
401
551
 
402
552
  def _to_json(self) -> Dict[str, Any]:
403
553
  self.json_serialization = {**Rule._to_json(self),
@@ -405,47 +555,135 @@ class MultiClassStopRule(Rule, HasAlternativeRule):
405
555
  return self.json_serialization
406
556
 
407
557
  @classmethod
408
- def _from_json(cls, data: Dict[str, Any]) -> MultiClassStopRule:
409
- loaded_rule = super(MultiClassStopRule, cls)._from_json(data)
558
+ def _from_json(cls, data: Dict[str, Any]) -> MultiClassRefinementRule:
559
+ loaded_rule = super(MultiClassRefinementRule, cls)._from_json(data)
410
560
  # The following is done to prevent re-initialization of the top rule,
411
561
  # so the top rule that is already initialized is passed in the data instead of its json serialization.
412
562
  loaded_rule.top_rule = data['top_rule']
413
563
  if data['alternative'] is not None:
414
564
  data['alternative']['top_rule'] = data['top_rule']
415
- loaded_rule.alternative = MultiClassStopRule.from_json(data["alternative"])
565
+ loaded_rule.alternative = SubclassJSONSerializer.from_json(data["alternative"])
416
566
  return loaded_rule
417
567
 
568
+ def _if_statement_source_code_clause(self) -> str:
569
+ return "elif" if self.weight == RDREdge.Alternative else "if"
570
+
571
+
572
+ @dataclass(eq=False)
573
+ class MultiClassStopRule(MultiClassRefinementRule):
574
+ """
575
+ A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
576
+ the conclusion of the rule is a Stop category meant to stop the parent conclusion from being made.
577
+ """
578
+ conclusion: CallableExpression = field(default_factory=lambda: CallableExpression(conclusion_type=(Stop,),
579
+ conclusion=Stop.stop))
580
+ """
581
+ The conclusion of the rule, which is a CallableExpression that returns a Stop category.
582
+ """
583
+
584
+ def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassRefinementRule, MultiClassTopRule]]:
585
+ if self.fired:
586
+ self.top_rule.fired = False
587
+ return self.top_rule.alternative
588
+ elif self.alternative:
589
+ return self.alternative(x)
590
+ else:
591
+ return self.top_rule.alternative
592
+
418
593
  def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
419
594
  return None, f"{parent_indent}{' ' * 4}pass\n"
420
595
 
421
- def _if_statement_source_code_clause(self) -> str:
422
- return "elif" if self.weight == RDREdge.Alternative.value else "if"
423
596
 
597
+ @dataclass(eq=False)
598
+ class MultiClassFilterRule(MultiClassRefinementRule, HasRefinementRule):
599
+ """
600
+ A rule in the MultiClassRDR classifier, it can have an alternative rule and a top rule,
601
+ the conclusion of the rule is a Filter category meant to filter the parent conclusion.
602
+ """
603
+ weight: RDREdge = field(init=False, default_factory=lambda: RDREdge.Filter)
604
+
605
+ def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassRefinementRule, MultiClassTopRule]]:
606
+ if self.fired:
607
+ if self.refinement:
608
+ case_cp = x
609
+ if isinstance(self.refinement, MultiClassFilterRule):
610
+ case_cp = self.get_an_updated_case_copy(case_cp)
611
+ return self.refinement(case_cp)
612
+ else:
613
+ return self.top_rule.alternative
614
+ elif self.alternative:
615
+ return self.alternative(x)
616
+ else:
617
+ return self.top_rule.alternative
618
+
619
+ def get_conclusion_as_source_code(self, conclusion: Any, parent_indent: str = "") -> Tuple[None, str]:
620
+ func, func_call = super().get_conclusion_as_source_code(str(conclusion), parent_indent=parent_indent)
621
+ conclusion_str = func_call.replace("return ", "").strip()
622
+ conclusion_str = conclusion_str.replace("(case)", "(case_cp)")
424
623
 
624
+ parent_to_filter = self.get_parent_to_filter()
625
+ statement = (
626
+ f"\n{parent_indent} case_cp = get_an_updated_case_copy(case, {parent_to_filter.generated_conclusion_function_name},"
627
+ f" attribute_name, conclusion_type, mutually_exclusive)")
628
+ statement += f"\n{parent_indent} conclusions.update(make_set({conclusion_str}))\n"
629
+ return func, statement
630
+
631
+ def get_parent_to_filter(self, parent: Union[None, MultiClassRefinementRule, MultiClassTopRule] = None) \
632
+ -> Union[MultiClassFilterRule, MultiClassTopRule]:
633
+ parent = self.parent if parent is None else parent
634
+ if isinstance(parent, (MultiClassFilterRule, MultiClassTopRule)) and parent.fired:
635
+ return parent
636
+ else:
637
+ return parent.parent
638
+
639
+ def _to_json(self) -> Dict[str, Any]:
640
+ self.json_serialization = super(MultiClassFilterRule, self)._to_json()
641
+ self.json_serialization['refinement'] = self.refinement.to_json() if self.refinement is not None else None
642
+ return self.json_serialization
643
+
644
+ @classmethod
645
+ def _from_json(cls, data: Dict[str, Any]) -> MultiClassFilterRule:
646
+ loaded_rule = super(MultiClassFilterRule, cls)._from_json(data)
647
+ if data['refinement'] is not None:
648
+ data['refinement']['top_rule'] = data['top_rule']
649
+ loaded_rule.refinement = cls.from_json(data["refinement"]) if data["refinement"] is not None else None
650
+ return loaded_rule
651
+
652
+
653
+ @dataclass(eq=False)
425
654
  class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
426
655
  """
427
656
  A rule in the MultiClassRDR classifier, it can have a refinement and a next rule.
428
657
  """
429
-
430
- def __init__(self, *args, **kwargs):
431
- super(MultiClassTopRule, self).__init__(*args, **kwargs)
432
- self.weight = RDREdge.Next.value
658
+ mutually_exclusive: bool = field(init=False, default=False)
659
+ weight: RDREdge = field(init=False, default_factory=lambda: RDREdge.Next)
433
660
 
434
661
  def evaluate_next_rule(self, x: Case) -> Optional[Union[MultiClassStopRule, MultiClassTopRule]]:
435
662
  if self.fired and self.refinement:
436
- return self.refinement(x)
663
+ case_cp = x
664
+ if isinstance(self.refinement, MultiClassFilterRule):
665
+ case_cp = self.get_an_updated_case_copy(case_cp)
666
+ return self.refinement(case_cp)
437
667
  elif self.alternative: # Here alternative refers to next rule in MultiClassRDR
438
668
  return self.alternative
669
+ return None
439
670
 
440
- def fit_rule(self, case_query: CaseQuery):
671
+ def fit_rule(self, case_query: CaseQuery, refinement_type: Optional[Type[MultiClassRefinementRule]] = None):
441
672
  if self.fired and case_query.target != self.conclusion:
442
- self.refinement = MultiClassStopRule(case_query.conditions, corner_case=case_query.case, parent=self)
673
+ if refinement_type in [None, MultiClassStopRule]:
674
+ new_rule = MultiClassStopRule(case_query.conditions, corner_case=case_query.case,
675
+ _parent=self)
676
+ elif refinement_type is MultiClassFilterRule:
677
+ new_rule = MultiClassFilterRule.from_case_query(case_query, parent=self)
678
+ else:
679
+ raise ValueError(f"Unknown refinement type {refinement_type}")
680
+ new_rule.top_rule = self
681
+ self.refinement = new_rule
443
682
  elif not self.fired:
444
- self.alternative = MultiClassTopRule(case_query.conditions, case_query.target,
445
- corner_case=case_query.case, parent=self)
683
+ self.alternative = MultiClassTopRule.from_case_query(case_query, parent=self)
446
684
 
447
685
  def _to_json(self) -> Dict[str, Any]:
448
- self.json_serialization = {**Rule._to_json(self),
686
+ self.json_serialization = {**super()._to_json(),
449
687
  "refinement": self.refinement.to_json() if self.refinement is not None else None,
450
688
  "alternative": self.alternative.to_json() if self.alternative is not None else None}
451
689
  return self.json_serialization
@@ -457,7 +695,8 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
457
695
  # so the top rule that is already initialized is passed in the data instead of its json serialization.
458
696
  if data['refinement'] is not None:
459
697
  data['refinement']['top_rule'] = loaded_rule
460
- loaded_rule.refinement = MultiClassStopRule.from_json(data["refinement"])
698
+ data_type = get_type_from_string(data["refinement"]["_type"])
699
+ loaded_rule.refinement = data_type.from_json(data["refinement"])
461
700
  loaded_rule.alternative = MultiClassTopRule.from_json(data["alternative"])
462
701
  return loaded_rule
463
702
 
@@ -466,8 +705,6 @@ class MultiClassTopRule(Rule, HasRefinementRule, HasAlternativeRule):
466
705
  conclusion_str = func_call.replace("return ", "").strip()
467
706
 
468
707
  statement = f"{parent_indent} conclusions.update(make_set({conclusion_str}))\n"
469
- if self.alternative is None:
470
- statement += f"{parent_indent}return conclusions\n"
471
708
  return func, statement
472
709
 
473
710
  def _if_statement_source_code_clause(self) -> str: