ripple-down-rules 0.4.88__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ripple_down_rules/rdr.py CHANGED
@@ -2,6 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  import copyreg
4
4
  import importlib
5
+ import os
6
+
5
7
  from . import logger
6
8
  import sys
7
9
  from abc import ABC, abstractmethod
@@ -34,7 +36,8 @@ except ImportError as e:
34
36
  RDRCaseViewer = None
35
37
  from .utils import draw_tree, make_set, copy_case, \
36
38
  SubclassJSONSerializer, make_list, get_type_from_string, \
37
- is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name
39
+ is_conflicting, update_case, get_imports_from_scope, extract_function_source, extract_imports, get_full_class_name, \
40
+ is_iterable, str_to_snake_case
38
41
 
39
42
 
40
43
  class RippleDownRules(SubclassJSONSerializer, ABC):
@@ -61,17 +64,90 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
61
64
  """
62
65
  The type of the case (input) to the RDR classifier.
63
66
  """
67
+ case_name: Optional[str] = None
68
+ """
69
+ The name of the case type.
70
+ """
71
+ metadata_folder: str = "rdr_metadata"
72
+ """
73
+ The folder to save the metadata of the RDR classifier.
74
+ """
75
+ model_name: Optional[str] = None
76
+ """
77
+ The name of the model. If None, the model name will be the generated python file name.
78
+ """
64
79
 
65
- def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None):
80
+ def __init__(self, start_rule: Optional[Rule] = None, viewer: Optional[RDRCaseViewer] = None,
81
+ save_dir: Optional[str] = None, ask_always: bool = True, model_name: Optional[str] = None):
66
82
  """
67
83
  :param start_rule: The starting rule for the classifier.
84
+ :param viewer: The viewer gui to use for the classifier. If None, no viewer is used.
85
+ :param save_dir: The directory to save the classifier to.
86
+ :param ask_always: Whether to always ask the expert (True) or only ask when classification fails (False).
68
87
  """
88
+ self.ask_always: bool = ask_always
89
+ self.model_name: Optional[str] = model_name
90
+ self.save_dir = save_dir
69
91
  self.start_rule = start_rule
70
92
  self.fig: Optional[Figure] = None
71
93
  self.viewer: Optional[RDRCaseViewer] = viewer
72
94
  if self.viewer is not None:
73
95
  self.viewer.set_save_function(self.save)
74
96
 
97
+ def save(self, save_dir: Optional[str] = None, model_name: Optional[str] = None) -> str:
98
+ """
99
+ Save the classifier to a file.
100
+
101
+ :param save_dir: The directory to save the classifier to.
102
+ :param model_name: The name of the model to save. If None, a default name is generated.
103
+ :param postfix: The postfix to add to the file name.
104
+ :return: The name of the saved model.
105
+ """
106
+ save_dir = save_dir or self.save_dir
107
+ if save_dir is None:
108
+ raise ValueError("The save directory cannot be None. Please provide a valid directory to save"
109
+ " the classifier.")
110
+ if not os.path.exists(save_dir + '/__init__.py'):
111
+ os.makedirs(save_dir, exist_ok=True)
112
+ with open(save_dir + '/__init__.py', 'w') as f:
113
+ f.write("# This is an empty __init__.py file to make the directory a package.\n")
114
+ if model_name is not None:
115
+ self.model_name = model_name
116
+ elif self.model_name is None:
117
+ self.model_name = self.generated_python_file_name
118
+ model_dir = os.path.join(save_dir, self.model_name)
119
+ os.makedirs(model_dir, exist_ok=True)
120
+ json_dir = os.path.join(model_dir, self.metadata_folder)
121
+ os.makedirs(json_dir, exist_ok=True)
122
+ self.to_json_file(os.path.join(json_dir, self.model_name))
123
+ self._write_to_python(model_dir)
124
+ return self.model_name
125
+
126
+ @classmethod
127
+ def load(cls, load_dir: str, model_name: str) -> Self:
128
+ """
129
+ Load the classifier from a file.
130
+
131
+ :param load_dir: The path to the model directory to load the classifier from.
132
+ :param model_name: The name of the model to load.
133
+ """
134
+ model_dir = os.path.join(load_dir, model_name)
135
+ json_file = os.path.join(model_dir, cls.metadata_folder, model_name)
136
+ rdr = cls.from_json_file(json_file)
137
+ rdr.update_from_python(model_dir)
138
+ rdr.save_dir = load_dir
139
+ rdr.model_name = model_name
140
+ return rdr
141
+
142
+ @abstractmethod
143
+ def _write_to_python(self, model_dir: str):
144
+ """
145
+ Write the tree of rules as source code to a file.
146
+
147
+ :param model_dir: The path to the directory to write the source code to.
148
+ """
149
+ pass
150
+
75
151
  def set_viewer(self, viewer: RDRCaseViewer):
76
152
  """
77
153
  Set the viewer for the classifier.
@@ -160,19 +236,29 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
160
236
  """
161
237
  if case_query is None:
162
238
  raise ValueError("The case query cannot be None.")
239
+
163
240
  self.name = case_query.attribute_name if self.name is None else self.name
164
241
  self.case_type = case_query.case_type if self.case_type is None else self.case_type
242
+ self.case_name = case_query.case_name if self.case_name is None else self.case_name
243
+
165
244
  if case_query.target is None:
166
245
  case_query_cp = copy(case_query)
167
- self.classify(case_query_cp.case, modify_case=True)
168
- expert.ask_for_conclusion(case_query_cp)
169
- case_query.target = case_query_cp.target
246
+ conclusions = self.classify(case_query_cp.case, modify_case=True)
247
+ if self.ask_always or conclusions is None or is_iterable(conclusions) and len(conclusions) == 0:
248
+ expert.ask_for_conclusion(case_query_cp)
249
+ case_query.target = case_query_cp.target
170
250
  if case_query.target is None:
171
251
  return self.classify(case_query.case)
172
252
 
173
253
  self.update_start_rule(case_query, expert)
174
254
 
175
- return self._fit_case(case_query, expert=expert, **kwargs)
255
+ fit_case_result = self._fit_case(case_query, expert=expert, **kwargs)
256
+
257
+ if self.save_dir is not None:
258
+ self.save()
259
+ expert.clear_answers()
260
+
261
+ return fit_case_result
176
262
 
177
263
  @abstractmethod
178
264
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
@@ -238,28 +324,54 @@ class RippleDownRules(SubclassJSONSerializer, ABC):
238
324
  pass
239
325
 
240
326
  @abstractmethod
241
- def update_from_python_file(self, package_dir: str):
327
+ def update_from_python(self, model_dir: str):
242
328
  """
243
329
  Update the rules from the generated python file, that might have been modified by the user.
244
330
 
245
- :param package_dir: The directory of the package that contains the generated python file.
331
+ :param model_dir: The directory where the generated python file is located.
246
332
  """
247
333
  pass
248
334
 
335
+ @classmethod
336
+ def get_acronym(cls) -> str:
337
+ """
338
+ :return: The acronym of the classifier.
339
+ """
340
+ if cls.__name__ == "GeneralRDR":
341
+ return "RDR"
342
+ elif cls.__name__ == "MultiClassRDR":
343
+ return "MCRDR"
344
+ else:
345
+ return "SCRDR"
346
+
347
+ def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
348
+ """
349
+ :param package_name: The name of the package that contains the RDR classifier function.
350
+ :return: The module that contains the rdr classifier function.
351
+ """
352
+ # remove from imports if exists first
353
+ name = f"{package_name.strip('./').replace('/', '.')}.{self.generated_python_file_name}"
354
+ try:
355
+ module = importlib.import_module(name)
356
+ del sys.modules[name]
357
+ except ModuleNotFoundError:
358
+ pass
359
+ return importlib.import_module(name).classify
360
+
249
361
 
250
362
  class RDRWithCodeWriter(RippleDownRules, ABC):
251
363
 
252
- def update_from_python_file(self, package_dir: str):
364
+ def update_from_python(self, model_dir: str):
253
365
  """
254
366
  Update the rules from the generated python file, that might have been modified by the user.
255
367
 
256
- :param package_dir: The directory of the package that contains the generated python file.
368
+ :param model_dir: The directory where the generated python file is located.
257
369
  """
258
- rule_ids = [r.uid for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None]
259
- condition_func_names = [f'conditions_{rid}' for rid in rule_ids]
260
- conclusion_func_names = [f'conclusion_{rid}' for rid in rule_ids]
370
+ rules_dict = {r.uid: r for r in [self.start_rule] + list(self.start_rule.descendants) if r.conditions is not None}
371
+ condition_func_names = [f'conditions_{rid}' for rid in rules_dict.keys()]
372
+ conclusion_func_names = [f'conclusion_{rid}' for rid in rules_dict.keys() if not isinstance(rules_dict[rid], MultiClassStopRule)]
261
373
  all_func_names = condition_func_names + conclusion_func_names
262
- filepath = f"{package_dir}/{self.generated_python_defs_file_name}.py"
374
+ filepath = f"{model_dir}/{self.generated_python_defs_file_name}.py"
263
375
  functions_source = extract_function_source(filepath, all_func_names, include_signature=False)
264
376
  # get the scope from the imports in the file
265
377
  scope = extract_imports(filepath)
@@ -267,7 +379,7 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
267
379
  if rule.conditions is not None:
268
380
  rule.conditions.user_input = functions_source[f"conditions_{rule.uid}"]
269
381
  rule.conditions.scope = scope
270
- if rule.conclusion is not None:
382
+ if rule.conclusion is not None and not isinstance(rule, MultiClassStopRule):
271
383
  rule.conclusion.user_input = functions_source[f"conclusion_{rule.uid}"]
272
384
  rule.conclusion.scope = scope
273
385
 
@@ -284,17 +396,19 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
284
396
  """
285
397
  pass
286
398
 
287
- def write_to_python_file(self, file_path: str, postfix: str = ""):
399
+ def _write_to_python(self, model_dir: str):
288
400
  """
289
401
  Write the tree of rules as source code to a file.
290
402
 
291
- :param file_path: The path to the file to write the source code to.
292
- :param postfix: The postfix to add to the file name.
403
+ :param model_dir: The path to the directory to write the source code to.
293
404
  """
294
- self.generated_python_file_name = self._default_generated_python_file_name + postfix
405
+ os.makedirs(model_dir, exist_ok=True)
406
+ if not os.path.exists(model_dir + '/__init__.py'):
407
+ with open(model_dir + '/__init__.py', 'w') as f:
408
+ f.write("# This is an empty __init__.py file to make the directory a package.\n")
295
409
  func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
296
- file_name = file_path + f"/{self.generated_python_file_name}.py"
297
- defs_file_name = file_path + f"/{self.generated_python_defs_file_name}.py"
410
+ file_name = model_dir + f"/{self.generated_python_file_name}.py"
411
+ defs_file_name = model_dir + f"/{self.generated_python_defs_file_name}.py"
298
412
  imports, defs_imports = self._get_imports()
299
413
  # clear the files first
300
414
  with open(defs_file_name, "w") as f:
@@ -345,20 +459,6 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
345
459
  imports = "\n".join(imports) + "\n"
346
460
  return imports, defs_imports
347
461
 
348
- def get_rdr_classifier_from_python_file(self, package_name: str) -> Callable[[Any], Any]:
349
- """
350
- :param package_name: The name of the package that contains the RDR classifier function.
351
- :return: The module that contains the rdr classifier function.
352
- """
353
- # remove from imports if exists first
354
- name = f"{package_name.strip('./')}.{self.generated_python_file_name}"
355
- try:
356
- module = importlib.import_module(name)
357
- del sys.modules[name]
358
- except ModuleNotFoundError:
359
- pass
360
- return importlib.import_module(name).classify
361
-
362
462
  @property
363
463
  def _default_generated_python_file_name(self) -> Optional[str]:
364
464
  """
@@ -366,23 +466,12 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
366
466
  """
367
467
  if self.start_rule is None or self.start_rule.conclusion is None:
368
468
  return None
369
- return f"{self.case_type.__name__.lower()}_{self.attribute_name}_{self.acronym.lower()}"
469
+ return f"{str_to_snake_case(self.case_name)}_{self.attribute_name}_{self.get_acronym().lower()}"
370
470
 
371
471
  @property
372
472
  def generated_python_defs_file_name(self) -> str:
373
473
  return f"{self.generated_python_file_name}_defs"
374
474
 
375
- @property
376
- def acronym(self) -> str:
377
- """
378
- :return: The acronym of the classifier.
379
- """
380
- if self.__class__.__name__ == "GeneralRDR":
381
- return "GRDR"
382
- elif self.__class__.__name__ == "MultiClassRDR":
383
- return "MCRDR"
384
- else:
385
- return "SCRDR"
386
475
 
387
476
  @property
388
477
  def conclusion_type(self) -> Tuple[Type]:
@@ -403,7 +492,9 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
403
492
 
404
493
  def _to_json(self) -> Dict[str, Any]:
405
494
  return {"start_rule": self.start_rule.to_json(), "generated_python_file_name": self.generated_python_file_name,
406
- "name": self.name, "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None}
495
+ "name": self.name,
496
+ "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
497
+ "case_name": self.case_name}
407
498
 
408
499
  @classmethod
409
500
  def _from_json(cls, data: Dict[str, Any]) -> Self:
@@ -411,13 +502,15 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
411
502
  Create an instance of the class from a json
412
503
  """
413
504
  start_rule = cls.start_rule_type().from_json(data["start_rule"])
414
- new_rdr = cls(start_rule)
505
+ new_rdr = cls(start_rule=start_rule)
415
506
  if "generated_python_file_name" in data:
416
507
  new_rdr.generated_python_file_name = data["generated_python_file_name"]
417
508
  if "name" in data:
418
509
  new_rdr.name = data["name"]
419
510
  if "case_type" in data:
420
511
  new_rdr.case_type = get_type_from_string(data["case_type"])
512
+ if "case_name" in data:
513
+ new_rdr.case_name = data["case_name"]
421
514
  return new_rdr
422
515
 
423
516
  @staticmethod
@@ -431,12 +524,12 @@ class RDRWithCodeWriter(RippleDownRules, ABC):
431
524
 
432
525
  class SingleClassRDR(RDRWithCodeWriter):
433
526
 
434
- def __init__(self, start_rule: Optional[SingleClassRule] = None, default_conclusion: Optional[Any] = None):
527
+ def __init__(self, default_conclusion: Optional[Any] = None, **kwargs):
435
528
  """
436
529
  :param start_rule: The starting rule for the classifier.
437
530
  :param default_conclusion: The default conclusion for the classifier if no rules fire.
438
531
  """
439
- super(SingleClassRDR, self).__init__(start_rule)
532
+ super(SingleClassRDR, self).__init__(**kwargs)
440
533
  self.default_conclusion: Optional[Any] = default_conclusion
441
534
 
442
535
  def _fit_case(self, case_query: CaseQuery, expert: Optional[Expert] = None, **kwargs) \
@@ -488,10 +581,10 @@ class SingleClassRDR(RDRWithCodeWriter):
488
581
  matched_rule = self.start_rule(case) if self.start_rule is not None else None
489
582
  return matched_rule if matched_rule is not None else self.start_rule
490
583
 
491
- def write_to_python_file(self, file_path: str, postfix: str = ""):
492
- super().write_to_python_file(file_path, postfix)
584
+ def _write_to_python(self, model_dir: str):
585
+ super()._write_to_python(model_dir)
493
586
  if self.default_conclusion is not None:
494
- with open(file_path + f"/{self.generated_python_file_name}.py", "a") as f:
587
+ with open(model_dir + f"/{self.generated_python_file_name}.py", "a") as f:
495
588
  f.write(f"{' ' * 4}else:\n{' ' * 8}return {self.default_conclusion}\n")
496
589
 
497
590
  def write_rules_as_source_code_to_file(self, rule: SingleClassRule, file: TextIOWrapper, parent_indent: str = "",
@@ -892,7 +985,8 @@ class GeneralRDR(RippleDownRules):
892
985
  return {"start_rules": {name: rdr.to_json() for name, rdr in self.start_rules_dict.items()}
893
986
  , "generated_python_file_name": self.generated_python_file_name,
894
987
  "name": self.name,
895
- "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None}
988
+ "case_type": get_full_class_name(self.case_type) if self.case_type is not None else None,
989
+ "case_name": self.case_name}
896
990
 
897
991
  @classmethod
898
992
  def _from_json(cls, data: Dict[str, Any]) -> GeneralRDR:
@@ -902,37 +996,37 @@ class GeneralRDR(RippleDownRules):
902
996
  start_rules_dict = {}
903
997
  for k, v in data["start_rules"].items():
904
998
  start_rules_dict[k] = get_type_from_string(v['_type']).from_json(v)
905
- new_rdr = cls(start_rules_dict)
999
+ new_rdr = cls(category_rdr_map=start_rules_dict)
906
1000
  if "generated_python_file_name" in data:
907
1001
  new_rdr.generated_python_file_name = data["generated_python_file_name"]
908
1002
  if "name" in data:
909
1003
  new_rdr.name = data["name"]
910
1004
  if "case_type" in data:
911
1005
  new_rdr.case_type = get_type_from_string(data["case_type"])
1006
+ if "case_name" in data:
1007
+ new_rdr.case_name = data["case_name"]
912
1008
  return new_rdr
913
1009
 
914
- def update_from_python_file(self, package_dir: str) -> None:
1010
+ def update_from_python(self, model_dir: str) -> None:
915
1011
  """
916
1012
  Update the rules from the generated python file, that might have been modified by the user.
917
1013
 
918
- :param package_dir: The directory of the package that contains the generated python file.
1014
+ :param model_dir: The directory where the model is stored.
919
1015
  """
920
1016
  for rdr in self.start_rules_dict.values():
921
- rdr.update_from_python_file(package_dir)
1017
+ rdr.update_from_python(model_dir)
922
1018
 
923
- def write_to_python_file(self, file_path: str, postfix: str = "") -> None:
1019
+ def _write_to_python(self, model_dir: str) -> None:
924
1020
  """
925
1021
  Write the tree of rules as source code to a file.
926
1022
 
927
- :param file_path: The path to the file to write the source code to.
928
- :param postfix: The postfix to add to the file name.
1023
+ :param model_dir: The directory where the model is stored.
929
1024
  """
930
- self.generated_python_file_name = self._default_generated_python_file_name + postfix
931
1025
  for rdr in self.start_rules_dict.values():
932
- rdr.write_to_python_file(file_path, postfix=f"_of_grdr{postfix}")
1026
+ rdr._write_to_python(model_dir)
933
1027
  func_def = f"def classify(case: {self.case_type.__name__}) -> {self.conclusion_type_hint}:\n"
934
- with open(file_path + f"/{self.generated_python_file_name}.py", "w") as f:
935
- f.write(self._get_imports(file_path) + "\n\n")
1028
+ with open(model_dir + f"/{self.generated_python_file_name}.py", "w") as f:
1029
+ f.write(self._get_imports() + "\n\n")
936
1030
  f.write("classifiers_dict = dict()\n")
937
1031
  for rdr_key, rdr in self.start_rules_dict.items():
938
1032
  f.write(f"classifiers_dict['{rdr_key}'] = {self.rdr_key_to_function_name(rdr_key)}\n")
@@ -942,13 +1036,6 @@ class GeneralRDR(RippleDownRules):
942
1036
  f"{' ' * 4} case = create_case(case, max_recursion_idx=3)\n""")
943
1037
  f.write(f"{' ' * 4}return GeneralRDR._classify(classifiers_dict, case)\n")
944
1038
 
945
- def get_rdr_classifier_from_python_file(self, file_path: str) -> Callable[[Any], Any]:
946
- """
947
- :param file_path: The path to the file that contains the RDR classifier function.
948
- :return: The module that contains the rdr classifier function.
949
- """
950
- return importlib.import_module(f"{file_path.strip('./')}.{self.generated_python_file_name}").classify
951
-
952
1039
  @property
953
1040
  def _default_generated_python_file_name(self) -> Optional[str]:
954
1041
  """
@@ -956,17 +1043,16 @@ class GeneralRDR(RippleDownRules):
956
1043
  """
957
1044
  if self.start_rule is None or self.start_rule.conclusion is None:
958
1045
  return None
959
- return f"{self.case_type.__name__.lower()}_rdr".lower()
1046
+ return f"{str_to_snake_case(self.case_name)}_rdr".lower()
960
1047
 
961
1048
  @property
962
1049
  def conclusion_type_hint(self) -> str:
963
1050
  return f"Dict[str, Any]"
964
1051
 
965
- def _get_imports(self, file_path: str) -> str:
1052
+ def _get_imports(self) -> str:
966
1053
  """
967
1054
  Get the imports needed for the generated python file.
968
1055
 
969
- :param file_path: The path to the file that contains the RDR classifier function.
970
1056
  :return: The imports needed for the generated python file.
971
1057
  """
972
1058
  imports = ""
@@ -5,15 +5,17 @@ of the RDRs.
5
5
  """
6
6
  import os.path
7
7
  from functools import wraps
8
+
9
+ from pyparsing.tools.cvt_pyparsing_pep8_names import camel_to_snake
8
10
  from typing_extensions import Callable, Optional, Type, Tuple, Dict, Any, Self, get_type_hints, List, Union
9
11
 
10
- from ripple_down_rules.datastructures.case import create_case
12
+ from ripple_down_rules.datastructures.case import create_case, Case
11
13
  from ripple_down_rules.datastructures.dataclasses import CaseQuery
12
14
  from ripple_down_rules.datastructures.enums import Category
13
15
  from ripple_down_rules.experts import Expert, Human
14
16
  from ripple_down_rules.rdr import GeneralRDR, RippleDownRules
15
17
  from ripple_down_rules.utils import get_method_args_as_dict, get_func_rdr_model_name, make_set, \
16
- get_method_class_if_exists
18
+ get_method_class_if_exists, get_method_name, str_to_snake_case
17
19
 
18
20
 
19
21
  class RDRDecorator:
@@ -41,99 +43,118 @@ class RDRDecorator:
41
43
  :return: A decorator to use a GeneralRDR as a classifier that monitors and modifies the function's output.
42
44
  """
43
45
  self.rdr_models_dir = models_dir
46
+ self.model_name: Optional[str] = None
44
47
  self.output_type = output_type
45
48
  self.parsed_output_type: List[Type] = []
46
49
  self.mutual_exclusive = mutual_exclusive
47
50
  self.rdr_python_path: Optional[str] = python_dir
48
51
  self.output_name = output_name
49
52
  self.fit: bool = fit
50
- self.expert = expert if expert else Human()
51
- self.rdr_model_path: Optional[str] = None
53
+ self.expert: Optional[Expert] = expert
52
54
  self.load()
53
55
 
54
56
  def decorator(self, func: Callable) -> Callable:
55
57
 
56
58
  @wraps(func)
57
59
  def wrapper(*args, **kwargs) -> Optional[Any]:
60
+
58
61
  if len(self.parsed_output_type) == 0:
59
- self.parse_output_type(func, *args)
60
- if self.rdr_model_path is None:
61
- self.initialize_rdr_model_path_and_load(func)
62
- case_dict = get_method_args_as_dict(func, *args, **kwargs)
63
- func_output = func(*args, **kwargs)
64
- case_dict.update({self.output_name: func_output})
65
- case = create_case(case_dict, obj_name=get_func_rdr_model_name(func), max_recursion_idx=3)
62
+ self.parsed_output_type = self.parse_output_type(func, self.output_type, *args)
63
+ if self.model_name is None:
64
+ self.initialize_rdr_model_name_and_load(func)
65
+
66
66
  if self.fit:
67
- scope = func.__globals__
68
- scope.update(case_dict)
69
- func_args_type_hints = get_type_hints(func)
70
- func_args_type_hints.update({self.output_name: Union[tuple(self.parsed_output_type)]})
71
- case_query = CaseQuery(case, self.output_name, Union[tuple(self.parsed_output_type)],
72
- self.mutual_exclusive,
73
- scope=scope, is_function=True, function_args_type_hints=func_args_type_hints)
67
+ expert_answers_path = os.path.join(self.rdr_models_dir, self.model_name, "expert_answers")
68
+ self.expert = self.expert or Human(answers_save_path=expert_answers_path)
69
+ case_query = self.create_case_query_from_method(func, self.parsed_output_type,
70
+ self.mutual_exclusive, self.output_name,
71
+ *args, **kwargs)
74
72
  output = self.rdr.fit_case(case_query, expert=self.expert)
75
73
  return output[self.output_name]
76
74
  else:
75
+ case, case_dict = self.create_case_from_method(func, self.output_name, *args, **kwargs)
77
76
  return self.rdr.classify(case)[self.output_name]
78
77
 
79
78
  return wrapper
80
79
 
81
- def initialize_rdr_model_path_and_load(self, func: Callable) -> None:
80
+ @staticmethod
81
+ def create_case_query_from_method(func: Callable, output_type, mutual_exclusive: bool,
82
+ output_name: str = 'output_', *args, **kwargs) -> CaseQuery:
83
+ """
84
+ Create a CaseQuery from the function and its arguments.
85
+
86
+ :param func: The function to create a case from.
87
+ :param output_type: The type of the output.
88
+ :param mutual_exclusive: If True, the output types are mutually exclusive.
89
+ :param output_name: The name of the output in the case. Defaults to 'output_'.
90
+ :param args: The positional arguments of the function.
91
+ :param kwargs: The keyword arguments of the function.
92
+ :return: A CaseQuery object representing the case.
93
+ """
94
+ output_type = make_set(output_type)
95
+ case, case_dict = RDRDecorator.create_case_from_method(func, output_name, *args, **kwargs)
96
+ scope = func.__globals__
97
+ scope.update(case_dict)
98
+ func_args_type_hints = get_type_hints(func)
99
+ func_args_type_hints.update({output_name: Union[tuple(output_type)]})
100
+ return CaseQuery(case, output_name, Union[tuple(output_type)],
101
+ mutual_exclusive, scope=scope,
102
+ is_function=True, function_args_type_hints=func_args_type_hints)
103
+
104
+ @staticmethod
105
+ def create_case_from_method(func: Callable, output_name: str = "output_", *args, **kwargs) -> Tuple[Case, Dict[str, Any]]:
106
+ """
107
+ Create a Case from the function and its arguments.
108
+
109
+ :param func: The function to create a case from.
110
+ :param output_name: The name of the output in the case. Defaults to 'output_'.
111
+ :param args: The positional arguments of the function.
112
+ :param kwargs: The keyword arguments of the function.
113
+ :return: A Case object representing the case.
114
+ """
115
+ case_dict = get_method_args_as_dict(func, *args, **kwargs)
116
+ func_output = func(*args, **kwargs)
117
+ case_dict.update({output_name: func_output})
118
+ case_name = get_func_rdr_model_name(func)
119
+ return create_case(case_dict, obj_name=case_name, max_recursion_idx=3), case_dict
120
+
121
+ def initialize_rdr_model_name_and_load(self, func: Callable) -> None:
82
122
  model_file_name = get_func_rdr_model_name(func, include_file_name=True)
83
- model_file_name = (''.join(['_' + c.lower() if c.isupper() else c for c in model_file_name]).lstrip('_')
84
- .replace('__', '_') + ".json")
85
- self.rdr_model_path = os.path.join(self.rdr_models_dir, model_file_name)
123
+ self.model_name = str_to_snake_case(model_file_name)
86
124
  self.load()
87
125
 
88
- def parse_output_type(self, func: Callable, *args) -> None:
89
- for ot in make_set(self.output_type):
126
+ @staticmethod
127
+ def parse_output_type(func: Callable, output_type: Any, *args) -> List[Type]:
128
+ parsed_output_type = []
129
+ for ot in make_set(output_type):
90
130
  if ot is Self:
91
131
  func_class = get_method_class_if_exists(func, *args)
92
132
  if func_class is not None:
93
- self.parsed_output_type.append(func_class)
133
+ parsed_output_type.append(func_class)
94
134
  else:
95
135
  raise ValueError(f"The function {func} is not a method of a class,"
96
136
  f" and the output type is {Self}.")
97
137
  else:
98
- self.parsed_output_type.append(ot)
138
+ parsed_output_type.append(ot)
139
+ return parsed_output_type
99
140
 
100
141
  def save(self):
101
142
  """
102
143
  Save the RDR model to the specified directory.
103
144
  """
104
- self.rdr.save(self.rdr_model_path)
105
-
106
- if self.rdr_python_path is not None:
107
- if not os.path.exists(self.rdr_python_path):
108
- os.makedirs(self.rdr_python_path)
109
- if not os.path.exists(os.path.join(self.rdr_python_path, "__init__.py")):
110
- # add __init__.py file to the directory
111
- with open(os.path.join(self.rdr_python_path, "__init__.py"), "w") as f:
112
- f.write("# This is an empty __init__.py file to make the directory a package.")
113
- # write the RDR model to a python file
114
- self.rdr.write_to_python_file(self.rdr_python_path)
145
+ self.rdr.save(self.rdr_models_dir)
115
146
 
116
147
  def load(self):
117
148
  """
118
149
  Load the RDR model from the specified directory.
119
150
  """
120
- if self.rdr_model_path is not None and os.path.exists(self.rdr_model_path):
121
- self.rdr = GeneralRDR.load(self.rdr_model_path)
151
+ if self.model_name is not None and os.path.exists(os.path.join(self.rdr_models_dir, self.model_name)):
152
+ self.rdr = GeneralRDR.load(self.rdr_models_dir, self.model_name)
122
153
  else:
123
- self.rdr = GeneralRDR()
124
-
125
- def write_to_python_file(self, package_dir: str, file_name_postfix: str = ""):
126
- """
127
- Write the RDR model to a python file.
128
-
129
- :param package_dir: The path to the directory to write the python file.
130
- """
131
- self.rdr.write_to_python_file(package_dir, postfix=file_name_postfix)
154
+ self.rdr = GeneralRDR(save_dir=self.rdr_models_dir, model_name=self.model_name)
132
155
 
133
- def update_from_python_file(self, package_dir: str):
156
+ def update_from_python(self):
134
157
  """
135
158
  Update the RDR model from a python file.
136
-
137
- :param package_dir: The directory of the package that contains the generated python file.
138
159
  """
139
- self.rdr.update_from_python_file(package_dir)
160
+ self.rdr.update_from_python(self.rdr_models_dir, self.model_name)
@@ -118,7 +118,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
118
118
  func_call = f"{parent_indent} return {new_function_name.replace('def ', '')}(case)\n"
119
119
  return "\n".join(conclusion_lines).strip(' '), func_call
120
120
  else:
121
- raise ValueError(f"Conclusion is format is not valid, it should be contain a function definition."
121
+ raise ValueError(f"Conclusion format is not valid, it should contain a function definition."
122
122
  f" Instead got:\n{conclusion}\n")
123
123
 
124
124
  def write_condition_as_source_code(self, parent_indent: str = "", defs_file: Optional[str] = None) -> str:
@@ -129,9 +129,7 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
129
129
  :param defs_file: The file to write the conditions to if they are a definition.
130
130
  """
131
131
  if_clause = self._if_statement_source_code_clause()
132
- if '\n' not in self.conditions.user_input:
133
- return f"{parent_indent}{if_clause} {self.conditions.user_input}:\n"
134
- elif "def " in self.conditions.user_input:
132
+ if "def " in self.conditions.user_input:
135
133
  if defs_file is None:
136
134
  raise ValueError("Cannot write conditions to source code as definitions python file was not given.")
137
135
  # This means the conditions are a definition that should be written and then called
@@ -143,6 +141,9 @@ class Rule(NodeMixin, SubclassJSONSerializer, ABC):
143
141
  with open(defs_file, 'a') as f:
144
142
  f.write(def_code.strip() + "\n\n\n")
145
143
  return f"\n{parent_indent}{if_clause} {new_function_name.replace('def ', '')}(case):\n"
144
+ else:
145
+ raise ValueError(f"Conditions format is not valid, it should contain a function definition"
146
+ f" Instead got:\n{self.conditions.user_input}\n")
146
147
 
147
148
  @abstractmethod
148
149
  def _if_statement_source_code_clause(self) -> str: