ecodev-core 0.0.67__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,316 @@
1
+ """
2
+ Module implementing all jwt security logic
3
+ """
4
+ from datetime import datetime
5
+ from datetime import timedelta
6
+ from datetime import timezone
7
+ from typing import Any
8
+ from typing import Dict
9
+ from typing import List
10
+ from typing import Optional
11
+ from typing import Union
12
+
13
+ from fastapi import APIRouter
14
+ from fastapi import Depends
15
+ from fastapi import HTTPException
16
+ from fastapi import status
17
+ from fastapi.security import OAuth2PasswordBearer
18
+ from jose import jwt
19
+ from jose import JWTError
20
+ from passlib.context import CryptContext
21
+ from sqladmin.authentication import AuthenticationBackend
22
+ from sqlmodel import col
23
+ from sqlmodel import select
24
+ from sqlmodel import Session
25
+ from starlette.requests import Request
26
+ from starlette.responses import RedirectResponse
27
+
28
+ from ecodev_core.app_user import AppUser
29
+ from ecodev_core.auth_configuration import ALGO
30
+ from ecodev_core.auth_configuration import EXPIRATION_LENGTH
31
+ from ecodev_core.auth_configuration import SECRET_KEY
32
+ from ecodev_core.db_connection import engine
33
+ from ecodev_core.logger import logger_get
34
+ from ecodev_core.permissions import Permission
35
+ from ecodev_core.pydantic_utils import Frozen
36
+ from ecodev_core.token_banlist import TokenBanlist
37
+
38
log = logger_get(__name__)

# Security plumbing shared by the fastapi routes and the sqladmin backend
SCHEME = OAuth2PasswordBearer(tokenUrl='login')
auth_router = APIRouter(tags=['authentication'])
CONTEXT = CryptContext(schemes=['bcrypt'], deprecated='auto')

# Login name of the dedicated monitoring account (compared to AppUser.user)
MONITORING = 'monitoring'

# Details attached to the HTTPExceptions raised by this module
MONITORING_ERROR = 'Could not validate credentials. You need to be the monitoring user to call this'
INVALID_USER = 'Invalid User'
INVALID_TFA = 'Invalid TFA code'
ADMIN_ERROR = 'Could not validate credentials. You need admin rights to call this'
INVALID_CREDENTIALS = 'Invalid Credentials'
REVOKED_TOKEN = 'This token has been revoked (by a logout action), please login again.'
49
+
50
+
51
class Token(Frozen):
    """
    Immutable pair made of an encoded access token and its type (e.g. 'bearer').

    Attributes:
        access_token: the encoded jwt
        token_type: the kind of token, as announced to the client
    """
    access_token: str
    token_type: str
57
+
58
+
59
class TokenData(Frozen):
    """
    Immutable information extracted from a verified jwt.

    Attributes:
        id: the AppUser id carried by the token ('user_id' claim)
    """
    id: int
64
+
65
+
66
+ def get_access_token(token: Dict[str, Any]) -> str | None:
67
+ """
68
+ Robust method to return access token or None
69
+ """
70
+ try:
71
+ return token.get('token', {}).get('access_token')
72
+ except AttributeError:
73
+ return None
74
+
75
+
76
def get_app_services(user: AppUser, session: Session) -> List[str]:
    """
    List the app services the passed user has rights on, as stored in db.

    Returns an empty list when no AppUser row has the id of the passed user.
    """
    query = select(AppUser).where(col(AppUser.id) == user.id)
    db_user = session.exec(query).first()
    if not db_user:
        return []
    return [right.app_service for right in db_user.rights]
83
+
84
+
85
class JwtAuth(AuthenticationBackend):
    """
    Sqladmin security class. Implements the login/logout procedure as well as the
    authentication check, reusing the fastapi jwt logic of this module.

    The session stores the dict produced by attempt_to_log
    ({'access_token': <jwt>, 'token_type': 'bearer'}).
    """

    async def login(self, request: Request) -> bool:
        """
        Login procedure: store the token in the session if the submitted form
        ('username' / 'password') belongs to an admin user. Returns whether it succeeded.
        """
        form = await request.form()
        if token := self.authorized(form):
            request.session.update(token)
        return bool(token)

    def authorized(self, form: Any) -> Union[Dict[str, str], None]:
        """
        Return the token of the user described in the form if it is an admin user, None
        otherwise (unknown user, wrong password, revoked token or missing admin rights).
        """
        with Session(engine) as session:
            try:
                return self.admin_token(form, session)
            except HTTPException:
                return None

    def admin_token(self, form: Any, session: Session) -> Union[Dict[str, str], None]:
        """
        Unsafe attempt to retrieve the token, only returned if the user has admin rights.

        Raises:
            HTTPException: from attempt_to_log (invalid credentials) or is_admin_user
                (revoked token / not an admin).
        """
        token = attempt_to_log(form.get('username', ''), form.get('password', ''), session)
        return token if is_admin_user(token['access_token']) else None

    async def logout(self, request: Request) -> bool:
        """
        Logout procedure: clears the session.

        NOTE(review): the jwt itself is not added to the TokenBanlist here, so it stays
        usable elsewhere until it expires.
        """
        request.session.clear()
        return True

    async def authenticate(self, request: Request) -> bool:
        """
        Authentication procedure: True iff the session holds a valid, non revoked token of an
        admin user.

        Expired/invalid/revoked/non admin tokens make is_admin_user raise an HTTPException;
        it is turned into False so that sqladmin redirects to the login page instead of
        answering with an error.
        """
        if not (token := request.session.get('access_token')):
            return False
        try:
            return bool(is_admin_user(token))
        except HTTPException:
            return False
128
+
129
+
130
def attempt_to_log(user: str,
                   password: str,
                   session: Session,
                   tfa_value: Optional[str] = None
                   ) -> Dict[str, str]:
    """
    Factorized security logic. Ensure that the user is a legit one with a valid password.
    If so, generate a token (with or without an encoded tfa_value depending on whether this
    argument is passed). If not, raise an HTTP exception with an intelligible error message.

    Args:
        user: the user as expected to be found in the AppUser db
        password: the plain password, to be compared with the hashed one in the AppUser db
        session: db connection
        tfa_value: if filled, its hash is added to the generated token

    Returns:
        {'access_token': <jwt carrying the user id>, 'token_type': 'bearer'}

    Raises:
        HTTPException: 403 with INVALID_USER if the user is unknown, 403 with
            INVALID_CREDENTIALS if the password does not match.
    """
    # NOTE(review): the distinct details for unknown user / wrong password allow user
    # enumeration; kept as is since clients may rely on them.
    selector = select(AppUser).where(col(AppUser.user) == user)
    if not (db_user := session.exec(selector).first()):
        log.warning('unauthorized user')
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=INVALID_USER)
    if not _check_password(password, db_user.password):
        log.warning('invalid password')
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=INVALID_CREDENTIALS)

    return {'access_token': _create_access_token(data={'user_id': db_user.id}, tfa_value=tfa_value),
            'token_type': 'bearer'}
156
+
157
+
158
def is_authorized_user(token: str = Depends(SCHEME)) -> bool:
    """
    Tell whether the passed token maps to an existing user.

    A revoked token, or any exception raised while resolving the user (invalid/expired jwt,
    ...), yields False.
    """
    if not is_banned(token):
        try:
            user = get_current_user(token)
        except Exception:
            return False
        return user is not None
    return False
169
+
170
+
171
def safe_get_user(token: Dict, tfa_check: bool = False) -> Union[AppUser, None]:
    """
    Safe method returning the user matching the passed token dict, or None.

    Args:
        token: dict holding the jwt under token['token']['access_token'] and, when tfa_check
            is set, the plain tfa value under token['tfa']
        tfa_check: whether the tfa value must match the one encoded in the jwt

    Returns None instead of raising when the token is not a dict, is revoked or invalid, or
    when the tfa value is missing or wrong.
    """
    try:
        # .get: a missing 'tfa' key then fails the tfa check (HTTPException -> None) instead of
        # escaping as a KeyError; on a non dict token it raises AttributeError (-> None).
        tfa_value = token.get('tfa') if tfa_check else None
        return get_user(get_access_token(token), tfa_value, tfa_check)
    except (HTTPException, AttributeError):
        return None
179
+
180
+
181
def get_user(token: str = Depends(SCHEME),
             tfa_value: Optional[str] = None,
             tfa_check: bool = False) -> AppUser:
    """
    Return the db user owning the passed token.

    Raises:
        HTTPException: 401 with REVOKED_TOKEN if the token has been banned, 401 with
            INVALID_CREDENTIALS if it maps to no user. Errors raised while verifying the token
            (invalid jwt, wrong tfa, ...) propagate unchanged.
    """
    bearer_header = {'WWW-Authenticate': 'Bearer'}
    if is_banned(token):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=REVOKED_TOKEN,
                            headers=bearer_header)
    user = get_current_user(token, tfa_value, tfa_check)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=INVALID_CREDENTIALS,
                            headers=bearer_header)
    return user
194
+
195
+
196
def ban_token(token: str, session: Session) -> None:
    """
    Revoke the passed token by recording it in the TokenBanlist table (committed right away).
    """
    banned_entry = TokenBanlist(token=token)
    session.add(banned_entry)
    session.commit()
202
+
203
+
204
def is_banned(token: str) -> bool:
    """
    Check if the passed token is banned, i.e. present in the TokenBanlist table.

    NB: Clean the TokenBanlist table on the fly by deleting entries created more than
    EXPIRATION_LENGTH minutes ago: a token is banned after its creation and expires
    EXPIRATION_LENGTH minutes after it, so such entries presumably only reference expired
    tokens.
    """
    with Session(engine) as session:
        # NOTE(review): naive local time; assumes TokenBanlist.created_at is stored the same way
        threshold = datetime.now() - timedelta(minutes=EXPIRATION_LENGTH)
        for token_banned in session.exec(
                select(TokenBanlist).where(TokenBanlist.created_at <= threshold)).all():
            session.delete(token_banned)
        session.commit()
        # Let the db look the token up instead of loading the whole banlist in memory
        match = session.exec(select(TokenBanlist).where(col(TokenBanlist.token) == token)).first()
        return match is not None
217
+
218
+
219
def get_current_user(token: str,
                     tfa_value: Optional[str] = None,
                     tfa_check: bool = False
                     ) -> Union[AppUser, None]:
    """
    Verify the passed token and fetch the AppUser whose id it carries.

    Returns None if no AppUser row has that id. Verification errors (HTTPException raised by
    _verify_access_token) propagate.
    """
    token_data = _verify_access_token(token, tfa_value, tfa_check)
    query = select(AppUser).where(col(AppUser.id) == token_data.id)
    with Session(engine) as session:
        return session.exec(query).first()
229
+
230
+
231
def is_admin_user(token: str = Depends(SCHEME)) -> AppUser:
    """
    Return the user owning the passed token, provided it has the ADMIN permission.

    Raises:
        HTTPException: 401 with REVOKED_TOKEN if the token was banned, 401 with ADMIN_ERROR if
            it maps to no user or to a non admin one. Token verification errors propagate.
    """
    bearer_header = {'WWW-Authenticate': 'Bearer'}
    if is_banned(token):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=REVOKED_TOKEN,
                            headers=bearer_header)
    user = get_current_user(token)
    if not user or user.permission != Permission.ADMIN:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ADMIN_ERROR,
                            headers=bearer_header)
    return user
243
+
244
+
245
def is_monitoring_user(token: str = Depends(SCHEME)) -> AppUser:
    """
    Return the user owning the passed token, provided it is the MONITORING account.

    Raises:
        HTTPException: 401 with REVOKED_TOKEN if the token was banned, 401 with
            MONITORING_ERROR if it maps to no user or to another user. Token verification
            errors propagate.
    """
    bearer_header = {'WWW-Authenticate': 'Bearer'}
    if is_banned(token):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=REVOKED_TOKEN,
                            headers=bearer_header)
    user = get_current_user(token)
    if not user or user.user != MONITORING:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail=MONITORING_ERROR, headers=bearer_header)
    return user
257
+
258
+
259
def upsert_new_user(token: str, user: str, password: str = '') -> None:
    """
    Insert a Consultant AppUser whose id is the one carried by the passed token, unless a user
    with that id already exists (in which case nothing is done).

    NB: this method RAISES an HTTP error if the token is invalid.

    NOTE(review): the password is stored as passed; with the default '' it is not a hash that
    the password check can match. Confirm such users are not meant to log in with a password.
    """
    user_id = _verify_access_token(token).id
    with Session(engine) as session:
        existing = session.exec(select(AppUser).where(col(AppUser.id) == user_id)).first()
        if existing:
            return
        session.add(AppUser(user=user, password=password, permission=Permission.Consultant,
                            id=user_id))
        session.commit()
271
+
272
+
273
def _create_access_token(data: Dict, tfa_value: Optional[str] = None) -> str:
    """
    Encode the passed claims into a signed jwt. Only called once credentials are validated.

    The token expires EXPIRATION_LENGTH minutes from now (UTC, 'exp' claim). When tfa_value is
    given, its hash is embedded under the 'tfa' claim. The passed dict is not modified.
    """
    claims = {**data, 'exp': datetime.now(timezone.utc) + timedelta(minutes=EXPIRATION_LENGTH)}
    if tfa_value:
        claims['tfa'] = _hash_password(tfa_value)
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGO)
283
+
284
+
285
def _verify_access_token(token: str,
                         tfa_value: Optional[str] = None,
                         tfa_check: bool = False) -> TokenData:
    """
    Decode the passed jwt and return the TokenData it carries.

    Raises:
        HTTPException (401): INVALID_CREDENTIALS if the jwt cannot be decoded/verified,
            INVALID_TFA if tfa_check is set and tfa_value is missing or does not match the
            'tfa' claim, INVALID_USER if the 'user_id' claim is absent.
    """
    bearer_header = {'WWW-Authenticate': 'Bearer'}
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGO])
    except JWTError as e:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=INVALID_CREDENTIALS,
                            headers=bearer_header) from e

    tfa_ok = not tfa_check or (tfa_value and _check_password(tfa_value, payload.get('tfa')))
    if not tfa_ok:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=INVALID_TFA,
                            headers=bearer_header)

    user_id = payload.get('user_id')
    if user_id is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=INVALID_USER,
                            headers=bearer_header)
    return TokenData(id=user_id)
303
+
304
+
305
def _hash_password(password: str) -> str:
    """
    Return the hash of the passed plain password, computed by the module CryptContext.
    """
    hashed = CONTEXT.hash(password)
    return hashed
310
+
311
+
312
def _check_password(plain_password: Optional[str], hashed_password: Optional[str]) -> bool:
    """
    Check the passed plain password against the passed hashed one.

    Returns False instead of raising when the stored value is not a hash the CryptContext can
    identify or parse. An example is the '' password that upsert_new_user stores by default.
    Such users then get an authentication failure rather than a server error.

    NB: _verify_access_token may pass a None hash (token without 'tfa' claim); passlib's
    CryptContext.verify treats it as a failed match.
    """
    try:
        return CONTEXT.verify(plain_password, hashed_password)
    except ValueError:
        # passlib signals empty/unknown/malformed hashes with ValueError subclasses
        # (UnknownHashError, MalformedHashError)
        return False
ecodev_core/backup.py ADDED
@@ -0,0 +1,105 @@
1
+ """
2
+ Module implementing backup mechanism on a ftp server.
3
+ """
4
import tarfile
from datetime import datetime
from pathlib import Path
from subprocess import PIPE
from subprocess import Popen
from subprocess import run
from subprocess import STDOUT
from typing import List
from urllib.parse import quote

from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict

from ecodev_core.db_connection import DB_URL
from ecodev_core.logger import logger_get
from ecodev_core.settings import SETTINGS
19
+
20
log = logger_get(__name__)


class BackUpSettings(BaseSettings):
    """
    Settings class used to connect to the ftp backup server (read from the environment / the
    .env file).
    """
    backup_username: str = ''
    backup_password: str = ''
    backup_url: str = ''
    model_config = SettingsConfigDict(env_file='.env')


# Non empty values from the project SETTINGS win; BackUpSettings is the fallback.
BCK, SETTINGS_BCK = BackUpSettings(), SETTINGS.backup  # type: ignore[attr-defined]
_USER = SETTINGS_BCK.backup_username or BCK.backup_username
_PASSWD = SETTINGS_BCK.backup_password or BCK.backup_password
_URL = SETTINGS_BCK.backup_url or BCK.backup_url
# Percent-encode the credentials. A reserved character (':', '@', '/', ';', ' ', ...) in the
# user name or password would otherwise corrupt the URL or split the 'lftp -c' command string.
BACKUP_URL = f'ftp://{quote(_USER, safe="")}:{quote(_PASSWD, safe="")}@{_URL}'
38
+
39
+
40
def backup(backed_folder: Path, nb_saves: int = 5, additional_id: str = 'default') -> None:
    """
    Back up both the db and backed_folder. Each is dumped/archived in the current working
    directory under a timestamped name prefixed by additional_id, then sent to the backup
    server, which keeps only the nb_saves most recent copies of each.
    """
    stamp = datetime.now().strftime('%Y_%m_%d_%Hh_%Mmn_%Ss')
    workdir = Path.cwd()
    db_dump = workdir / f'{additional_id}_db.{stamp}.dump'
    files_archive = workdir / f'{additional_id}_files.{stamp}.tgz'
    _backup_db(db_dump, nb_saves)
    _backup_files(backed_folder, files_archive, nb_saves)
47
+
48
+
49
def retrieve_most_recent_backup(name: str = 'default_files') -> None:
    """
    Download, into the current directory, the most recent backup whose listing line contains
    name. Logs an error and returns when no such backup exists on the server.

    Backup names embed a zero-padded '%Y_%m_%d_%Hh_%Mmn_%Ss' timestamp, so the lexicographic
    order of a family is its chronological order.
    """
    output = run(['lftp', '-c', f'open {BACKUP_URL}; ls'], capture_output=True, text=True)
    all_backups = sorted(x.split(' ')[-1] for x in output.stdout.splitlines() if name in x)
    if not all_backups:
        # Used to crash with an IndexError on all_backups[-1]
        log.error(f'no backup matching {name} found on the backup server')
        return None
    log.info(f'most recent backup {all_backups[-1]}')
    run(['lftp', '-c', f'open {BACKUP_URL}; get {all_backups[-1]}'])
    return None
58
+
59
+
60
def _backup_db(db_dump_path: Path, nb_saves: int) -> None:
    """
    pg_dump the DB_URL db into db_dump_path, then write the dump on the backup server.

    Success is decided by pg_dump's exit code. The old check treated any output as a failure,
    so harmless warnings skipped the backup while silent failures uploaded a missing or partial
    dump. The merged stdout/stderr output is logged either way.
    """
    process = run(['pg_dump', f'--dbname={DB_URL}', '-f', db_dump_path.name],
                  stdout=PIPE, stderr=STDOUT, cwd=db_dump_path.parent)
    if process.returncode == 0:
        if process.stdout:
            log.warning(f'pg_dump output : {process.stdout}')
        _backup_content(db_dump_path, nb_saves)
    else:
        log.critical(f'something went wrong : {process.stdout}')
70
+
71
+
72
+
73
def _backup_files(backed_folder: Path, backup_file: Path, nb_saves: int) -> None:
    """
    Archive backed_folder into the gzipped tarball backup_file (stored under the folder's own
    name inside the archive), then ship it to the backup server.
    """
    folder_name_in_archive = backed_folder.name
    with tarfile.open(backup_file, mode='w:gz') as archive:
        archive.add(backed_folder, arcname=folder_name_in_archive)
    _backup_content(backup_file, nb_saves)
80
+
81
+
82
def _backup_content(file_to_backup: Path, nb_saves: int) -> None:
    """
    Write file_to_backup on the backup server, delete old remote versions so as to keep only
    nb_saves copies, then delete the local file.

    If the upload fails (non zero lftp exit code), nothing is deleted and the local file is
    kept, so that a failed transfer never loses the freshly made backup.
    """
    log.warning(f'Transferring backup to server: {file_to_backup}')
    upload = run(['lftp', '-c', f'open {BACKUP_URL}; put {file_to_backup}'])
    if upload.returncode != 0:
        log.critical(f'could not transfer {file_to_backup} to the backup server '
                     f'(lftp exit code {upload.returncode}), local copy kept')
        return
    backups_to_delete = _get_old_backups(file_to_backup, nb_saves)
    log.info(f'deleting remote backups {backups_to_delete}')
    for to_rm in backups_to_delete:
        run(['lftp', '-c', f'open {BACKUP_URL}; rm {to_rm}'], capture_output=True, text=True)
    log.info(f'deleting local {file_to_backup}')
    file_to_backup.unlink()
94
+
95
+
96
def _get_old_backups(file_to_backup: Path, nb_saves: int) -> List[str]:
    """
    Retrieve the remote versions of file_to_backup to erase: all of them except the nb_saves
    most recent ones (nb_saves == 0 gives an empty slice, so nothing is erased).

    Versions of a backup are named '<family>.<timestamp>.<ext>'. The zero-padded timestamp
    makes the lexicographic order of a family its chronological order.
    """
    output = run(['lftp', '-c', f'open {BACKUP_URL}; ls'], capture_output=True, text=True)
    # Match the '<family>.' prefix rather than a substring: otherwise backing up 'prod_db'
    # would also select (and delete) the 'preprod_db' backups.
    family_prefix = f"{file_to_backup.name.rsplit('.', 2)[0]}."
    remote_names = (line.split(' ')[-1] for line in output.stdout.splitlines())
    all_backups = sorted(name for name in remote_names if name.startswith(family_prefix))
    log.info(f'existing remote backups {all_backups}')
    return all_backups[:-nb_saves]
@@ -0,0 +1,179 @@
1
+ """
2
+ Module computing and checking high level dependencies in the code (based on pydeps)
3
+ """
4
+ from pathlib import Path
5
+ from subprocess import run
6
+ from typing import Dict
7
+ from typing import Iterator
8
+ from typing import List
9
+
10
+ from ecodev_core.logger import logger_get
11
+
12
+
13
log = logger_get(__name__)

# Adjacency dict: module name -> names of the modules it imports from
Dependency = Dict[str, List[str]]
CONF_FILE = 'dependencies.json'
16
+
17
+
18
def check_dependencies(code_base: Path, theoretical_deps: Path):
    """
    Hook preserving the pre-established high level dependencies of a solution.

    Computes the current module dependencies of code_base and compares them to the
    pre-computed matrix stored in theoretical_deps. Returns whether they match, logging an
    error when they do not.
    """
    current = _get_current_dependencies(_valid_modules(code_base), code_base, code_base.name)
    expected = _get_allowed_dependencies(theoretical_deps)
    ok_deps = _test_dependency(expected, current)
    if not ok_deps:
        log.error('you changed high level solution dependencies. Intended?')
    return ok_deps
29
+
30
+
31
def compute_dependencies(code_base: Path, output_folder: Path, plot: bool = True):
    """
    Compute the dependencies between the high level modules of code_base and store in
    output_folder:
        - one stub '<module>.py' per module, holding one relative import per module it
          depends on, plus an empty '__init__.py' (this stub package is what pydeps plots),
        - the dependency matrix, as 'dependencies_<code_base name>.txt',
        - if plot, the png of the dependency graph, generated by running pydeps in
          output_folder.
    """
    code_folder = code_base.name
    modules = _valid_modules(code_base)
    deps: Dependency = _get_current_dependencies(modules, code_base, code_folder)

    for module_name, imported in deps.items():
        stub_lines = [f'from .{other_module} import to\n' for other_module in imported]
        with open(output_folder / f'{module_name}.py', 'w') as stub:
            stub.writelines(stub_lines)

    with open(output_folder / '__init__.py', 'w') as init_file:
        init_file.writelines([])

    with open(output_folder / f'dependencies_{code_folder}.txt', 'w') as matrix_file:
        matrix_file.writelines([f'{row}\n' for row in _get_dep_matrix(modules, deps)])

    if plot:
        run(['pydeps', '.', '-T', 'png', '--no-show', '--rmprefix', f'{output_folder.name}.'],
            cwd=output_folder)
54
+
55
+
56
+ def _test_dependency(allowed_deps: Dependency, dependencies: Dependency) -> bool:
57
+ """
58
+ For each modules stored in a dependencies.json file, check whether the current
59
+ module dependencies are the same as the config ones.
60
+ """
61
+ for module in dependencies:
62
+ for dep in allowed_deps[module]:
63
+ if dep not in dependencies[module]:
64
+ log.error(f'{module} no longer imported in {dep}. Intended ?')
65
+ for dep in dependencies[module]:
66
+ if dep not in allowed_deps[module]:
67
+ log.error(f'{dep} now imported in {module}. Intended ?')
68
+ for dep in dependencies[module]:
69
+ if module in dependencies[dep] and dep != module:
70
+ log.error(f'Circular ref created between {module} and {dep}.')
71
+ return dependencies == allowed_deps
72
+
73
+
74
def _get_allowed_dependencies(config_path: Path) -> Dependency:
    """
    Read the pre-established dependency matrix stored at config_path and return it as an
    adjacency dict: each key is a module, its values the modules it depends on.

    Expected format (as written by compute_dependencies): a header line
    'module x depends on,<module_1>,...,<module_n>' then one line per module
    '<module>,<Yes|No>,...,<Yes|No>'. Column labels are taken from the header rather than
    assumed to follow the row order. Blank lines are skipped, and an empty file gives {}.
    """
    rows = [line.split(',') for line in _safe_read_lines(config_path) if line]
    if not rows:
        return {}
    header, *matrix = rows
    columns = header[1:]
    return {row[0]: [column for column, cell in zip(columns, row[1:]) if cell == 'Yes']
            for row in matrix}
92
+
93
+
94
def _get_current_dependencies(modules: List[str],
                              code_base: Path,
                              code_folder: str) -> Dependency:
    """
    Compute the current dependencies between the passed modules as an adjacency dict. Each key
    is a module, and its values are the modules (among modules, possibly itself) that at least
    one of its python files imports from, listed in the order of modules.

    Args:
        modules: the high level modules (sub folders of code_base) to inspect
        code_base: the directory containing the modules
        code_folder: the package name under which the modules are imported
            ('from <code_folder>.<module> ...')
    """
    module_dependencies: Dependency = {}
    for module in modules:
        # glob already yields paths rooted at code_base / module: use them as is (joining them
        # again to the module dir produced wrong paths for a relative code_base). Listed once
        # per module instead of once per (module, other_module) pair.
        py_files = list(_get_recursively_all_files_in_dir(code_base / module, 'py'))
        module_dependencies[module] = [
            other_module for other_module in modules
            if any(_depends_on_module(py_file, other_module, code_folder) for py_file in py_files)
        ]
    return module_dependencies
116
+
117
+
118
def _depends_on_module(file: Path, module: str, code_folder: str) -> bool:
    """
    Check if the python file imports from module, i.e. contains a (stripped) line
    'from <code_folder>.<module>[.<sub>] import ...' or 'import <code_folder>.<module>[.<sub>]'.

    The module name must match exactly: module 'app' is not considered imported by
    'from <code_folder>.apple import x'.
    """
    target = f'{code_folder}.{module}'
    return any(_line_imports(line, target) for line in _safe_read_lines(file))


def _line_imports(line: str, target: str) -> bool:
    """
    Return True if the stripped source line imports the dotted name target or a sub module.
    """
    def refers_to(name: str) -> bool:
        return name == target or name.startswith(f'{target}.')

    if line.startswith('import '):
        # 'import a.b as c, d' -> ['a.b', 'as', 'c'], ['d']
        aliases = (alias.split() for alias in line[len('import '):].split(','))
        return any(refers_to(words[0]) for words in aliases if words)
    tokens = line.split()
    return 'import' in line and any(
        word == 'from' and refers_to(following) for word, following in zip(tokens, tokens[1:]))
127
+
128
+
129
+ def _safe_read_lines(filename: Path) -> Iterator[str]:
130
+ """
131
+ read all lines in file, erase the final special \n character
132
+ """
133
+ with open(filename, 'r') as f:
134
+ lines = f.readlines()
135
+ yield from [line.strip() for line in lines]
136
+
137
+
138
+ def _get_recursively_all_files_in_dir(code_folder: Path, extension: str) -> Iterator[Path]:
139
+ """
140
+ Given a folder, recursively return all files of the given extension in the folder
141
+ """
142
+ yield from code_folder.glob(f'**/*.{extension}')
143
+
144
+
145
+ def _valid_folder(folder: Path):
146
+ """
147
+ Return True if folder is a python regroupment of module to be considered for dependency analysis
148
+ """
149
+ return (
150
+ folder.is_dir()
151
+ and not folder.name.startswith('.')
152
+ and not folder.name.startswith('_')
153
+ and folder.name != 'data'
154
+ )
155
+
156
+
157
def _valid_modules(root_folder: Path):
    """
    Return, sorted by name, the solution modules: the valid folders found directly under
    root_folder.
    """
    names = [child.name for child in root_folder.iterdir() if _valid_folder(child)]
    names.sort()
    return names
162
+
163
+
164
def _get_dep_matrix(modules: List[str], deps: Dependency) -> List[str]:
    """
    Render the dependency matrix as csv-like lines. The header is
    'module x depends on,<modules>'; then each module gets '<module>,' followed by a
    'Yes'/'No' cell per module of modules.
    """
    header = 'module x depends on,' + ','.join(modules)
    rows = []
    for module in modules:
        cells = [_depends_on(module, other_module, deps) for other_module in modules]
        rows.append(f'{module},{",".join(cells)}')
    return [header, *rows]
173
+
174
+
175
+ def _depends_on(module, other_module, deps):
176
+ """
177
+ Write correct input in the dependency matrix ("Yes" if other_module is in deps of module key)
178
+ """
179
+ return 'Yes' if other_module in deps[module] else 'No'
@@ -0,0 +1,27 @@
1
+ """
2
+ Module comparing whether two elements are both None, or both not None and equal
3
+ """
4
+ from typing import Optional
5
+
6
+ import numpy as np
7
+
8
+
9
def custom_equal(element_1: Optional[object], element_2: Optional[object], element_type: type):
    """
    Compare whether two elements are both None, or both not None and equal (same type/same value)

    Float-like elements (element_type being float or a subclass, e.g. numpy.float64) are
    compared with numpy.isclose (default tolerances) rather than strict equality, and the
    result is cast to a plain bool (numpy returns numpy.bool_, which fails `is True` checks).

    Args:
        element_1: the first element of the comparison
        element_2: the second element of the comparison
        element_type: the expected element type for both elements (anything accepted by
            isinstance, e.g. a class or a tuple of classes)

    Returns:
        True if both None, or both instances of element_type and equal. For non float types
        the result of `==` is returned as is.
    """
    if element_1 is None:
        return element_2 is None

    if not isinstance(element_1, element_type) or not isinstance(element_2, element_type):
        return False

    # isinstance(element_type, type) guard: issubclass rejects tuples of types
    if isinstance(element_type, type) and issubclass(element_type, float):
        return bool(np.isclose(element_1, element_2))
    return element_1 == element_2