ecodev_core-0.0.67-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ecodev_core/__init__.py +129 -0
- ecodev_core/app_activity.py +126 -0
- ecodev_core/app_rights.py +24 -0
- ecodev_core/app_user.py +92 -0
- ecodev_core/auth_configuration.py +24 -0
- ecodev_core/authentication.py +316 -0
- ecodev_core/backup.py +105 -0
- ecodev_core/check_dependencies.py +179 -0
- ecodev_core/custom_equal.py +27 -0
- ecodev_core/db_connection.py +94 -0
- ecodev_core/db_filters.py +142 -0
- ecodev_core/db_i18n.py +211 -0
- ecodev_core/db_insertion.py +128 -0
- ecodev_core/db_retrieval.py +193 -0
- ecodev_core/db_upsertion.py +382 -0
- ecodev_core/deployment.py +16 -0
- ecodev_core/email_sender.py +60 -0
- ecodev_core/encryption.py +46 -0
- ecodev_core/enum_utils.py +21 -0
- ecodev_core/es_connection.py +79 -0
- ecodev_core/list_utils.py +134 -0
- ecodev_core/logger.py +122 -0
- ecodev_core/pandas_utils.py +69 -0
- ecodev_core/permissions.py +21 -0
- ecodev_core/pydantic_utils.py +33 -0
- ecodev_core/read_write.py +52 -0
- ecodev_core/rest_api_client.py +211 -0
- ecodev_core/rest_api_configuration.py +25 -0
- ecodev_core/safe_utils.py +241 -0
- ecodev_core/settings.py +51 -0
- ecodev_core/sqlmodel_utils.py +16 -0
- ecodev_core/token_banlist.py +18 -0
- ecodev_core/version.py +144 -0
- ecodev_core-0.0.67.dist-info/LICENSE.md +11 -0
- ecodev_core-0.0.67.dist-info/METADATA +87 -0
- ecodev_core-0.0.67.dist-info/RECORD +37 -0
- ecodev_core-0.0.67.dist-info/WHEEL +4 -0
ecodev_core/db_retrieval.py
@@ -0,0 +1,193 @@
+"""
+Low level methods to retrieve data from db in a paginated way
+"""
+from math import ceil
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import pandas as pd
+from sqlalchemy import func
+from sqlmodel import col
+from sqlmodel import or_
+from sqlmodel import select
+from sqlmodel import Session
+from sqlmodel.sql.expression import Select
+from sqlmodel.sql.expression import SelectOfScalar
+
+from ecodev_core.db_connection import engine
+from ecodev_core.db_filters import SERVER_SIDE_FILTERS
+from ecodev_core.db_filters import ServerSideFilter
+from ecodev_core.db_upsertion import FILTER_ON
+from ecodev_core.list_utils import first_or_default
+from ecodev_core.pydantic_utils import Frozen
+
+SelectOfScalar.inherit_cache = True  # type: ignore
+Select.inherit_cache = True  # type: ignore
+OPERATORS = ['>=', '<=', '!=', '=', '<', '>', 'contains ']
+
+
+class ServerSideField(Frozen):
+    """
+    Simple class used for server side data retrieval
+
+    Attributes are:
+        - col_name: the name as it will appear on the frontend interface
+        - field_name: the SQLModel attribute name associated with this field
+        - field: the SQLModel attribute associated with this field
+        - filter: the filtering mechanism to use for this field
+    """
+    col_name: str
+    field_name: str
+    field: Any = None
+    filter: ServerSideFilter
+
+
+def count_rows(fields: List[ServerSideField],
+               model: Any,
+               limit: Union[int, None] = None,
+               filter_str: str = '',
+               search_str: str = '',
+               search_cols: Optional[List] = None) -> int:
+    """
+    Count the total number of rows in the db model, with statically defined field_filters fed with
+    dynamically set frontend filters. Divide this total number by limit to account for pagination.
+    """
+    with Session(engine) as session:
+        count = session.exec(_get_full_query(fields, model, filter_str, True, search_str,
+                                             search_cols)).one()
+
+    return ceil(count / limit) if limit else count
+
+
+def get_rows(fields: List[ServerSideField],
+             model: Any,
+             limit: Union[int, None] = None,
+             offset: Union[int, None] = None,
+             filter_str: str = '',
+             search_str: str = '',
+             search_cols: Optional[List] = None,
+             fields_order: Optional[Callable] = None
+             ) -> pd.DataFrame:
+    """
+    Select relevant row lines from model db. Select the whole db if no limit or offset is provided.
+    Convert the rows to a dataframe in order to show the result in a dash data_table.
+
+    NB:
+        * 'fields_order' specifies how to order the result rows
+        * 'limit' and 'offset' correspond to the pagination of the results.
+        * 'search_str' corresponds to the search string from the search input.
+    """
+    with Session(engine) as session:
+        rows = _paginate_db_lines(fields, model, session, limit, offset, filter_str,
+                                  search_str, search_cols, fields_order)
+        if len(raw_df := pd.DataFrame.from_records([row.model_dump() for row in rows])) > 0:
+            return raw_df.rename(columns={field.field_name: field.col_name for field in fields}
+                                 )[[field.col_name for field in fields]]
+    return pd.DataFrame(columns=[field.col_name for field in fields])
+
+
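The two public entry points above pair naturally: count_rows returns the page count for the same filter and search arguments that get_rows paginates. A minimal usage sketch follows; the Product model and the ServerSideFilter.STR / ServerSideFilter.NUM members are assumptions, since db_filters.py is not shown in this excerpt.

from sqlmodel import Field, SQLModel

from ecodev_core.db_filters import ServerSideFilter
from ecodev_core.db_retrieval import count_rows, get_rows, ServerSideField


class Product(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str
    price: float


FIELDS = [
    ServerSideField(col_name='Name', field_name='name',
                    field=Product.name, filter=ServerSideFilter.STR),   # assumed member
    ServerSideField(col_name='Price', field_name='price',
                    field=Product.price, filter=ServerSideFilter.NUM),  # assumed member
]

# Page 0 with 20 rows per page; count_rows returns ceil(matching rows / 20).
df = get_rows(FIELDS, Product, limit=20, offset=0, filter_str='{Name} contains chair')
num_pages = count_rows(FIELDS, Product, limit=20, filter_str='{Name} contains chair')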
+def _paginate_db_lines(fields: List[ServerSideField],
+                       model: Any,
+                       session: Session,
+                       limit: Union[int, None],
+                       offset: Union[int, None],
+                       filter_str: str,
+                       search_str: str = '',
+                       search_cols: Optional[List] = None,
+                       fields_order: Optional[Callable] = None,
+                       ) -> List:
+    """
+    Select relevant row lines from model db. Select the whole db if no limit or offset is provided.
+    """
+    if fields_order is None:
+        fields_order = _get_default_field_order(fields)
+
+    query = fields_order(_get_full_query(fields, model, filter_str, count=False,
+                                         search_str=search_str, search_cols=search_cols))
+    if limit is not None and offset is not None:
+        return list(session.exec(query.offset(offset * limit).limit(limit)))
+    return list(session.exec(query).all())
+
+
+def _get_full_query(fields: List[ServerSideField],
+                    model: Any,
+                    filter_str: str,
+                    count: bool = False,
+                    search_str: str = '',
+                    search_cols: Optional[List] = None
+                    ) -> SelectOfScalar:
+    """
+    Forge a complete select query given both search and filter strings
+
+    NB:
+        * This relies on the passed statically defined field_filters corresponding to the model.
+        * The field_filters are used jointly with the dynamically set frontend filters.
+    """
+    filter_query = _get_filter_query(fields, model, _get_frontend_filters(filter_str), count)
+
+    if not search_str or not search_cols:
+        return filter_query
+
+    return filter_query.where(or_(col(field).ilike(f'%{search_str.strip()}%')
+                                  for field in search_cols))
+
+
+def _get_frontend_filters(raw_filters: str) -> Dict[str, Tuple[str, str]]:
+    """
+    Forge a dictionary of field keys, (operator, value) values in order to filter a db model.
+    """
+    split_filters = raw_filters.split(' && ')
+    return {elt[elt.find('{') + 1: elt.rfind('}')]: _forge_filter(elt) for elt in split_filters}
+
+
+def _forge_filter(elt: str) -> Tuple[str, str]:
+    """
+    Forge the operator and value associated with the passed element. Do so by scanning the ordered
+    sequence of OPERATORS and returning the first match (the value sits to the right of it).
+    """
+    return next(((key, elt.split(key)[-1]) for key in OPERATORS if key in elt), ('', ''))
+
+
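A worked example of the two parsers above. The '{Column} operator value' segments joined by ' && ' follow the filter_query syntax that Dash DataTables emit (an inference from the format; the frontend itself is not part of this diff).

from ecodev_core.db_retrieval import _get_frontend_filters

# OPERATORS is scanned in order, so '>=' is matched before '>' and '='.
_get_frontend_filters('{Name} contains chair && {Price} >= 10')
# -> {'Name': ('contains ', 'chair'), 'Price': ('>=', ' 10')}
# NB: values keep their surrounding whitespace; stripping and casting is
# presumably left to the SERVER_SIDE_FILTERS implementations.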
+def _get_filter_query(fields: List[ServerSideField],
+                      model: Any,
+                      frontend_filters: Dict[str, Tuple[str, str]],
+                      count: bool = False
+                      ) -> SelectOfScalar:
+    """
+    Filter a model given backend static field_filters called with dynamically set frontend_filters.
+
+    Returns:
+        * either the query fetching the filtered rows (count = False)
+        * or the query counting the filtered rows (count = True).
+    """
+    query = select(func.count(model.id)) if count else select(model)
+    if not frontend_filters or not all(frontend_filters.keys()):
+        return query
+
+    for key, (operator, value) in frontend_filters.items():
+        if field := first_or_default(fields, lambda x: x.col_name == key):
+            query = SERVER_SIDE_FILTERS[field.filter](query=query, operator=operator,
+                                                      value=value, field=field.field)
+
+    return query
+
+
+def _get_default_field_order(fields: List[ServerSideField]) -> Callable:
+    """
+    Recover the default field order from the list of fields
+    """
+    def fields_order(query):
+        """
+        Default field ordering
+
+        Take the initial query as input and specify the order to use.
+        """
+        return query.order_by(*[field.field for field in fields])
+
+    return fields_order
ecodev_core/db_upsertion.py
@@ -0,0 +1,382 @@
+"""
+Module handling CRUD and version operations
+"""
+import enum
+import json
+import types
+from datetime import datetime
+from enum import EnumType
+from functools import partial
+from typing import Any
+from typing import get_args
+from typing import get_origin
+from typing import Iterator
+from typing import Union
+
+import pandas as pd
+import progressbar
+from pydantic_core._pydantic_core import PydanticUndefined
+from sqlmodel import and_
+from sqlmodel import Field
+from sqlmodel import inspect
+from sqlmodel import select
+from sqlmodel import Session
+from sqlmodel import SQLModel
+from sqlmodel import text
+from sqlmodel import update
+from sqlmodel.main import SQLModelMetaclass
+from sqlmodel.sql.expression import SelectOfScalar
+
+from ecodev_core.version import get_row_versions
+from ecodev_core.version import Version
+
+BATCH_SIZE = 5000
+FILTER_ON = 'filter_on'
+INFO = 'info'
+SA_COLUMN_KWARGS = 'sa_column_kwargs'
+
+
+def add_missing_enum_values(enum: EnumType, session: Session, new_vals: list | None = None) -> None:
+    """
+    Add to an existing enum its missing db values. Do so by retrieving what is already in db, and
+    inserting what is new.
+
+    NB: the new_vals argument is there for testing purposes
+    """
+    for val in [e.name for e in new_vals or enum if e.name not in get_enum_values(enum, session)]:
+        session.execute(text(f"ALTER TYPE {enum.__name__.lower()} ADD VALUE IF NOT EXISTS '{val}'"))
+    session.commit()
+
+
+def get_enum_values(enum: EnumType, session: Session) -> set[str]:
+    """
+    Return all enum values in db for the passed enum.
+    """
+    result = session.execute(text(
+        """
+        SELECT enumlabel FROM pg_enum
+        JOIN pg_type ON pg_enum.enumtypid = pg_type.oid
+        WHERE pg_type.typname = :enum_name
+        """
+    ), {'enum_name': enum.__name__.lower()})
+    return {x[0] for x in result}
+
+
+def sfield(**kwargs):
+    """
+    Field constructor for columns not to be versioned. Those are the columns on which to select.
+    They act, morally, as a unique identifier of a row (like id, but more business meaningful)
+    """
+    sa_column_kwargs = _get_sa_column_kwargs(kwargs, sfield=True)
+    return Field(**kwargs, sa_column_kwargs=sa_column_kwargs)
+
+
+def field(**kwargs):
+    """
+    Field constructor for columns to be versioned.
+    """
+    sa_column_kwargs = _get_sa_column_kwargs(kwargs, sfield=False)
+    return Field(**kwargs, sa_column_kwargs=sa_column_kwargs)
+
+
+def _get_sa_column_kwargs(kwargs, sfield: bool) -> dict:
+    """
+    Combine existing sa_column_kwargs with the new info field necessary for versioning
+    """
+    if not (additional_vals := kwargs.get(SA_COLUMN_KWARGS)):
+        return {INFO: {FILTER_ON: sfield}}
+    kwargs.pop(SA_COLUMN_KWARGS)
+    return additional_vals | {INFO: {FILTER_ON: sfield}}
+
+
+def upsert_selector(values: SQLModel, db_schema: SQLModelMetaclass) -> SelectOfScalar:
+    """
+    Return the query selecting rows on the values of the columns not to be versioned.
+    """
+    conditions = [getattr(db_schema, x.name) == getattr(values, x.name)
+                  for x in inspect(db_schema).c if x.info.get(FILTER_ON) is True]
+    return select(db_schema).where(and_(*conditions))
+
+
+def upsert_updator(values: SQLModel,
+                   row_id: int,
+                   session: Session,
+                   db_schema: SQLModelMetaclass
+                   ) -> Any:
+    """
+    Update the passed row_id from db_schema db with the passed new values.
+    Only update columns to be versioned.
+
+    At the same time, store previous (column, row_id) versions for all columns that changed values.
+    """
+    to_update = {col.name: getattr(values, col.name)
+                 for col in inspect(db_schema).c if col.info.get(FILTER_ON) is False}
+    db = session.exec(select(db_schema).where(db_schema.id == row_id)).first().model_dump()
+    col_types = {x: y.annotation for x, y in db_schema.__fields__.items()}
+    table = db_schema.__tablename__
+
+    for col, val in {k: v for k, v in db.items() if k in to_update and _value_comparator(
+            v, to_update[k])}.items():
+        session.add(Version.from_table_row(table, col, row_id, col_types[col], val))
+
+    return update(db_schema).where(db_schema.id == row_id).values(**to_update)
+
+
+def _value_comparator(v: Any, to_update: Any) -> bool:
+    """
+    Perform a comparison between the value in db and the value to be upserted
+    """
+    return v.date() != to_update.date() if isinstance(v, datetime) else v != to_update
+
+
+def upsert_deletor(values: SQLModel, session: Session):
+    """
+    Delete the row in db corresponding to the passed values, selecting on columns not to be
+    versioned.
+    """
+    db_schema = values.__class__
+    if in_db := session.exec(upsert_selector(values, db_schema=db_schema)).first():
+        for version in get_row_versions(db_schema.__tablename__, in_db.id, session):
+            session.delete(version)
+        session.delete(in_db)
+        session.commit()
+
+
+def upsert_df_data(df: pd.DataFrame, db_schema: SQLModelMetaclass, session: Session) -> None:
+    """
+    Upsert the passed df into db_schema db.
+    """
+    upsert_data([x.to_dict() for _, x in df.iterrows()], session, raw_db_schema=db_schema)
+
+
+def upsert_data(data: list[dict | SQLModelMetaclass],
+                session: Session,
+                raw_db_schema: SQLModelMetaclass | None = None) -> None:
+    """
+    Upsert the passed list of dicts (corresponding to db_schema) into db_schema db.
+    """
+    db_schema = raw_db_schema or data[0].__class__
+    selector = partial(upsert_selector, db_schema=db_schema)
+    updator = partial(upsert_updator, db_schema=db_schema)
+    batches = [data[i:i + BATCH_SIZE] for i in range(0, len(data), BATCH_SIZE)]
+
+    for batch in progressbar.progressbar(batches, redirect_stdout=False):
+        for row in batch:
+            new_object = db_schema(**row) if isinstance(row, dict) else row
+            if in_db := session.exec(selector(new_object)).first():
+                session.exec(updator(new_object, in_db.id, session))
+            else:
+                session.add(new_object)
+        session.commit()
+
+
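A minimal sketch of how sfield, field and upsert_data fit together (the Measure model is hypothetical, not part of the package). sfield columns form the business key that upsert_selector matches on; field columns are the ones upsert_updator overwrites, archiving every changed value as a Version row.

from sqlmodel import Field, Session, SQLModel

from ecodev_core.db_connection import engine
from ecodev_core.db_upsertion import field, sfield, upsert_data


class Measure(SQLModel, table=True):
    id: int | None = Field(default=None, primary_key=True)
    country: str = sfield()   # selection key: never versioned
    year: int = sfield()      # selection key: never versioned
    value: float = field()    # versioned on every changed update


with Session(engine) as session:
    # The first call inserts; the second finds the ('FR', 2024) row, updates
    # `value` in place and archives the previous 1.0 as a Version row.
    upsert_data([Measure(country='FR', year=2024, value=1.0)], session)
    upsert_data([Measure(country='FR', year=2024, value=2.0)], session)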
+def get_sfield_columns(db_model: SQLModelMetaclass) -> list[str]:
+    """
+    Get all the columns flagged as sfields from the schema
+    Args:
+        db_model (SQLModelMetaclass): db_model
+    Returns:
+        list of str with the names of the columns
+    """
+    return [
+        x.name
+        for x in inspect(db_model).c
+        if x.info.get(FILTER_ON) is True
+    ]
+
+
+def filter_to_sfield_dict(row: dict | SQLModelMetaclass,
+                          db_schema: SQLModelMetaclass | None = None) \
+        -> dict[str, dict | SQLModelMetaclass]:
+    """
+    Return a dict with only the sfields of the passed object
+    Args:
+        row: any object with ecodev_core field and sfield
+        db_schema (SQLModelMetaclass): db_schema. Use the schema of row if not specified
+    Returns:
+        dict
+    """
+    return {pk: getattr(row, pk)
+            for pk in get_sfield_columns(db_schema or row.__class__)}
+
+
+def add_missing_columns(model: Any, session: Session) -> None:
+    """
+    Create all columns corresponding to fields in the passed model that are not yet columns in the
+    corresponding db table.
+
+    NB: the ORM does not permit creating new columns, so we unfortunately have to rely on raw
+    sqlalchemy text sql statements.
+
+    NB2: as of 2025/10/01, handles the creation of int, float, str, bool, bytes, JSONB, Enum columns
+
+    NB3: it is possible to index columns, and to add a foreign key.
+
+    NB4: it is possible to have a non NULL default value
+    """
+    table = model.__tablename__
+    current_cols = get_existing_columns(table, session)
+    for col, py_type, fld in [(c, p, f) for c, p, f in _get_cols(model) if c not in current_cols]:
+        is_null = _is_type_nullable(py_type)
+        default = _get_default_value(fld, is_null)
+        _add_column(table, col, _py_type_to_sql(_clean_py_type(py_type)), default, is_null, session)
+        if getattr(fld, 'index', False):
+            _add_index(table, col, session)
+        if isinstance((fk := getattr(fld, 'foreign_key', None)), str) and fk.strip():
+            _add_foreign_key(f"{fk.split('.')[0]}(id)", table, col, session)
+    session.commit()
+
+
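add_missing_columns thus acts as a tiny forward-only migration helper: after adding a field to a model, one call brings the table in line. A sketch, reusing the hypothetical Measure model from the previous example and assuming a Postgres database (the pg_enum and ALTER TYPE statements in this module are Postgres-specific):

from sqlmodel import Session

from ecodev_core.db_connection import engine
from ecodev_core.db_upsertion import add_missing_columns

# Suppose `source: str = field(default='survey')` was just added to Measure:
# the missing column is created as VARCHAR DEFAULT 'survey' NOT NULL, and is
# indexed / given a foreign key only if the field declares index/foreign_key.
with Session(engine) as session:
    add_missing_columns(Measure, session)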
+def _get_default_value(fld: Any, nullable: bool) -> Any:
+    """
+    Return the field default value, if any
+    """
+    if not nullable and hasattr(fld, 'default') and fld.default is not None:
+        return fld.default
+    return None
+
+
+def _add_column(table: str,
+                col: str,
+                sql_type: str,
+                default: Any,
+                nullable: bool,
+                session: Session
+                ) -> None:
+    """
+    Add the new column with sql_type to the passed table
+    """
+    session.execute(text(f'ALTER TABLE {table} ADD COLUMN {col} {sql_type} '
+                         f'{_get_additional_request(col, sql_type, default, nullable)}'))
+
+
+def _get_additional_request(col: str, sql_type: str, default_value: Any, nullable: bool) -> str:
+    """
+    Forge the NULL/DEFAULT clause for the passed col, adding the default value if any.
+    """
+    if nullable:
+        return 'NULL'
+
+    if default_value is not None:
+        if (default_sql := _python_default_to_sql(default_value, sql_type)) == 'NULL':
+            raise ValueError(f'Non-nullable column {col} requires a default_value')
+        return f'DEFAULT {default_sql} NOT NULL'
+
+    raise ValueError(f'Non-nullable column {col} requires a default_value')
+
+
+def _add_index(table: str, col: str, session: Session):
+    """
+    Index the new table column
+    """
+    session.execute(text(f'CREATE INDEX IF NOT EXISTS ix_{table}_{col} ON {table} ({col})'))
+
+
+def _add_foreign_key(fk: str, table: str, col: str, session: Session):
+    """
+    Add a fk foreign key on the passed table column
+    """
+    session.execute(text(
+        f'ALTER TABLE {table} ADD CONSTRAINT fk_{table}_{col} FOREIGN KEY ({col}) REFERENCES {fk}'))
+
+
+def _get_cols(model: Any) -> Iterator[tuple[str, Any, Any]]:
+    """
+    Retrieve all fields and their corresponding python types from the passed model
+    """
+    for col, field in model.model_fields.items():
+        if (col_type := getattr(field, 'annotation', None)) is not None:
+            yield col, col_type, field
+
+
+def get_existing_columns(table_name: str, session: Session) -> set[str]:
+    """
+    Retrieve all column names from the passed table
+    """
+    result = session.execute(text('SELECT column_name FROM information_schema.columns WHERE '
+                                  'table_name = :table_name'), {'table_name': table_name})
+    return {r[0] for r in result}
+
+
+def _clean_py_type(col_type: Any) -> Any:
+    """
+    Convert union and optional types to their non-None type; return the passed type unchanged
+    otherwise.
+    - Handle Python 3.10+ UnionType (aka X | Y)
+    - Unpack Optional types (Union[X, NoneType])
+    """
+    if isinstance(col_type, types.UnionType):
+        if len((args := [t for t in col_type.__args__ if t is not type(None)])) == 1:
+            return args[0]
+
+    if get_origin(col_type) is Union:
+        if len((args := [t for t in get_args(col_type) if t is not type(None)])) == 1:
+            return args[0]
+
+    return col_type
+
+
+def _is_type_nullable(col_type: Any) -> bool:
+    """
+    Return True if col_type is Optional or Union[..., None].
+    """
+    if isinstance(col_type, types.UnionType):
+        return type(None) in col_type.__args__
+
+    if get_origin(col_type) is Union:
+        return type(None) in get_args(col_type)
+
+    return col_type is type(None)
+
+
+def _python_default_to_sql(value: Any, sql_type: str) -> str:
+    """
+    Convert a Python default to a SQL literal, handling common types.
+    """
+    if value is None or value == PydanticUndefined:
+        return 'NULL'
+    if sql_type in ('VARCHAR', 'TEXT', 'CHAR'):
+        safe_value = value.replace("'", "''")
+        return f"'{safe_value}'"
+    if sql_type in ('INTEGER', 'FLOAT', 'NUMERIC', 'DOUBLE PRECISION'):
+        return str(value)
+    if sql_type == 'BOOLEAN':
+        return 'TRUE' if value else 'FALSE'
+    if sql_type == 'BYTEA':
+        if isinstance(value, bytes):
+            return f"decode('{value.hex()}', 'hex')"
+        raise ValueError('Default for BYTEA must be bytes')
+    if sql_type == 'JSONB':
+        json_str = json.dumps(value).replace("'", "''")
+        return f"'{json_str}'::jsonb"
+    if isinstance(value, enum.Enum):
+        return f"'{str(value.name)}'"
+    return str(value)
+
+
+def _py_type_to_sql(col_type: type) -> str:
+    """
+    Convert a python type to a sql one. Only working (as of 2025/10/01) for:
+    - int
+    - float
+    - str
+    - bool
+    - bytes
+    - dict (JSONB)
+    - Enum
+    NB: for enum, assumes the type is already created in DB
+    """
+    if col_type is str:
+        return 'VARCHAR'
+    if col_type is int:
+        return 'INTEGER'
+    if col_type is float:
+        return 'FLOAT'
+    if col_type is bool:
+        return 'BOOLEAN'
+    if col_type is bytes:
+        return 'BYTEA'
+    if col_type is dict:
+        return 'JSONB'
+    if hasattr(col_type, '__members__'):
+        return col_type.__name__.lower()
+    raise ValueError(f'Unsupported column type: {col_type}')
ecodev_core/deployment.py
@@ -0,0 +1,16 @@
+"""
+Module implementing all types of deployment
+"""
+from enum import Enum
+from enum import unique
+
+
+@unique
+class Deployment(str, Enum):
+    """
+    Enum listing all types of deployment
+    """
+    LOCAL = 'local'
+    NON_PROD = 'nonprod'
+    PREPROD = 'preprod'
+    PROD = 'prod'
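A usage sketch (assumed, not shown in this diff): the str mixin lets members be built from and compared against raw strings, for instance a deployment name coming from an environment variable.

import os

from ecodev_core.deployment import Deployment

deployment = Deployment(os.environ.get('DEPLOYMENT', 'local'))  # hypothetical env var name
is_prod = deployment == Deployment.PROD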
ecodev_core/email_sender.py
@@ -0,0 +1,60 @@
+"""
+Module implementing generic email send
+"""
+from email.mime.image import MIMEImage
+from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
+from pathlib import Path
+from smtplib import SMTP
+from ssl import create_default_context
+
+from pydantic_settings import BaseSettings
+from pydantic_settings import SettingsConfigDict
+
+from ecodev_core.settings import SETTINGS
+
+
+class EmailAuth(BaseSettings):
+    """
+    Simple authentication configuration class
+    """
+    email_smtp: str = ''
+    email_sender: str = ''
+    email_password: str = ''
+    email_port: int = 587
+    model_config = SettingsConfigDict(env_file='.env')
+
+
+EMAIL_AUTH, EMAIL_SETTINGS = EmailAuth(), SETTINGS.smtp  # type: ignore[attr-defined]
+_SENDER = EMAIL_SETTINGS.email_sender or EMAIL_AUTH.email_sender
+_SMTP = EMAIL_SETTINGS.email_smtp or EMAIL_AUTH.email_smtp
+_PASSWD = EMAIL_SETTINGS.email_password or EMAIL_AUTH.email_password
+_PORT = EMAIL_SETTINGS.email_port or EMAIL_AUTH.email_port
+
+
+def send_email(email: str, body: str, topic: str, images: dict[str, Path] | None = None) -> None:
+    """
+    Generic email sender.
+
+    Arguments are:
+        - email: the address to which to send
+        - body: the email body
+        - topic: the email topic (subject)
+        - images: if any, the dict of image tags to image paths to embed in the email
+    """
+    em = MIMEMultipart('related')
+    em['From'] = _SENDER
+    em['To'] = email
+    em['Subject'] = topic
+    em.attach(MIMEText(body, 'html'))
+    for tag, img_path in (images or {}).items():
+        with open(img_path, 'rb') as fp:
+            img = MIMEImage(fp.read())
+            img.add_header('Content-ID', f'<{tag}>')
+            em.attach(img)
+
+    with SMTP(_SMTP, _PORT) as server:
+        server.ehlo()
+        server.starttls(context=create_default_context())
+        server.login(_SENDER, _PASSWD)
+        server.sendmail(_SENDER, email, em.as_string())
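A usage sketch with hypothetical addresses and paths. Credentials are resolved at import time from SETTINGS.smtp or the .env-backed EmailAuth, so email_smtp, email_sender and email_password must be configured for this to actually send anything.

from pathlib import Path

from ecodev_core.email_sender import send_email

# The 'chart' key becomes the image Content-ID, referenced as cid:chart in the HTML body.
send_email(email='jane.doe@example.com',
           body='<p>Hello,</p><img src="cid:chart">',
           topic='Weekly report',
           images={'chart': Path('/tmp/chart.png')})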
ecodev_core/encryption.py
@@ -0,0 +1,46 @@
+"""
+Module implementing simple fernet AES128 encryption/decryption
+"""
+from cryptography.fernet import Fernet
+from pydantic_settings import BaseSettings
+from pydantic_settings import SettingsConfigDict
+
+from ecodev_core.settings import SETTINGS
+
+
+class EncryptionConf(BaseSettings):
+    """
+    Simple encryption configuration class
+    """
+    fernet_key: str = ''
+    model_config = SettingsConfigDict(env_file='.env')
+
+
+SECRET_KEY = SETTINGS.fernet_key or EncryptionConf().fernet_key  # type: ignore[attr-defined]
+FERNET = Fernet(SECRET_KEY.encode())
+
+
+def encrypt_value(value):
+    """
+    Encrypt a value using Fernet symmetric encryption.
+
+    Args:
+        value: value to encrypt (will be converted to string)
+
+    Returns:
+        Encrypted bytes
+    """
+    return FERNET.encrypt(str(value).encode())
+
+
+def decrypt_value(encrypted):
+    """
+    Decrypt an encrypted value and convert it to float.
+
+    Args:
+        encrypted: encrypted bytes to decrypt
+
+    Returns:
+        Decrypted value as float
+    """
+    return float(FERNET.decrypt(encrypted).decode())
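A round-trip sketch. FERNET is built at import time, so a valid Fernet key (urlsafe base64, e.g. generated once with Fernet.generate_key()) must be available through SETTINGS.fernet_key or a FERNET_KEY entry in .env before the module is imported.

from ecodev_core.encryption import decrypt_value, encrypt_value

token = encrypt_value(42.5)          # bytes token, safe to persist in db
assert decrypt_value(token) == 42.5  # decrypt_value always casts back to float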