ormlambda-3.7.1-py3-none-any.whl → ormlambda-3.11.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. ormlambda/__init__.py +2 -0
  2. ormlambda/caster/base_caster.py +3 -3
  3. ormlambda/common/global_checker.py +1 -1
  4. ormlambda/components/select/ISelect.py +4 -4
  5. ormlambda/components/select/__init__.py +1 -1
  6. ormlambda/databases/my_sql/caster/caster.py +1 -0
  7. ormlambda/databases/my_sql/caster/types/bytes.py +3 -3
  8. ormlambda/databases/my_sql/caster/types/datetime.py +3 -3
  9. ormlambda/databases/my_sql/caster/types/float.py +3 -3
  10. ormlambda/databases/my_sql/caster/types/int.py +3 -3
  11. ormlambda/databases/my_sql/caster/types/iterable.py +3 -3
  12. ormlambda/databases/my_sql/caster/types/none.py +3 -3
  13. ormlambda/databases/my_sql/caster/types/string.py +3 -3
  14. ormlambda/databases/my_sql/clauses/__init__.py +1 -0
  15. ormlambda/databases/my_sql/clauses/alias.py +15 -21
  16. ormlambda/databases/my_sql/clauses/group_by.py +19 -20
  17. ormlambda/databases/my_sql/clauses/having.py +16 -0
  18. ormlambda/databases/my_sql/clauses/order.py +6 -1
  19. ormlambda/databases/my_sql/clauses/update.py +1 -1
  20. ormlambda/databases/my_sql/clauses/where.py +3 -3
  21. ormlambda/databases/my_sql/functions/concat.py +8 -6
  22. ormlambda/databases/my_sql/join_context.py +3 -3
  23. ormlambda/databases/my_sql/repository/repository.py +60 -13
  24. ormlambda/databases/my_sql/statements.py +73 -22
  25. ormlambda/databases/my_sql/types.py +73 -0
  26. ormlambda/engine/__init__.py +2 -0
  27. ormlambda/engine/create.py +35 -0
  28. ormlambda/engine/url.py +744 -0
  29. ormlambda/engine/utils.py +17 -0
  30. ormlambda/repository/base_repository.py +2 -3
  31. ormlambda/repository/interfaces/IRepositoryBase.py +1 -0
  32. ormlambda/sql/column.py +27 -2
  33. ormlambda/sql/foreign_key.py +36 -4
  34. ormlambda/sql/table/table_constructor.py +2 -2
  35. ormlambda/statements/interfaces/IStatements.py +37 -25
  36. ormlambda/statements/types.py +4 -1
  37. {ormlambda-3.7.1.dist-info → ormlambda-3.11.0.dist-info}/METADATA +107 -8
  38. {ormlambda-3.7.1.dist-info → ormlambda-3.11.0.dist-info}/RECORD +40 -35
  39. {ormlambda-3.7.1.dist-info → ormlambda-3.11.0.dist-info}/LICENSE +0 -0
  40. {ormlambda-3.7.1.dist-info → ormlambda-3.11.0.dist-info}/WHEEL +0 -0
ormlambda/__init__.py CHANGED
@@ -20,3 +20,5 @@ from .model.base_model import (
20
20
  BaseModel as BaseModel,
21
21
  ORM as ORM,
22
22
  ) # COMMENT: to avoid relative import we need to import BaseModel after import Table,Column, ForeignKey, IRepositoryBase and Disassembler
23
+
24
+ from .engine import create_engine, URL # noqa: F401
@@ -14,7 +14,7 @@ class BaseCaster[TProp, TType](abc.ABC):
14
14
  def wildcard_to_select(self, value: str) -> str: ...
15
15
  @overload
16
16
  def wildcard_to_select(self) -> str: ...
17
-
17
+
18
18
  @abc.abstractmethod
19
19
  def wildcard_to_select(self) -> str: ...
20
20
 
@@ -22,7 +22,7 @@ class BaseCaster[TProp, TType](abc.ABC):
22
22
  def wildcard_to_where(self, value: str) -> str: ...
23
23
  @overload
24
24
  def wildcard_to_where(self) -> str: ...
25
-
25
+
26
26
  @abc.abstractmethod
27
27
  def wildcard_to_where(self) -> str: ...
28
28
 
@@ -30,7 +30,7 @@ class BaseCaster[TProp, TType](abc.ABC):
30
30
  def wildcard_to_insert(self, value: str) -> str: ...
31
31
  @overload
32
32
  def wildcard_to_insert(self) -> str: ...
33
-
33
+
34
34
  @abc.abstractmethod
35
35
  def wildcard_to_insert(self) -> str: ...
36
36
 
@@ -23,6 +23,6 @@ class GlobalChecker:
23
23
  except TypeError as err:
24
24
  cond1 = r"takes \d+ positional argument but \d+ were given"
25
25
  cond2 = r"missing \d+ required positional arguments:"
26
- if re.search(r"("+f"{cond1}|{cond2}"+r")", err.args[0]):
26
+ if re.search(r"(" + f"{cond1}|{cond2}" + r")", err.args[0]):
27
27
  raise UnmatchedLambdaParameterError(len(tables), obj)
28
28
  raise err
@@ -6,12 +6,12 @@ from typing import TYPE_CHECKING
6
6
  if TYPE_CHECKING:
7
7
  from ormlambda.sql.clause_info import ClauseInfo
8
8
 
9
+
9
10
  class ISelect(IQuery):
10
11
  @property
11
12
  @abc.abstractmethod
12
- def FROM(self)->ClauseInfo: ...
13
-
13
+ def FROM(self) -> ClauseInfo: ...
14
+
14
15
  @property
15
16
  @abc.abstractmethod
16
- def COLUMNS(self)->str: ...
17
-
17
+ def COLUMNS(self) -> str: ...
@@ -1 +1 @@
1
- from .ISelect import ISelect # noqa: F401
1
+ from .ISelect import ISelect # noqa: F401
@@ -23,6 +23,7 @@ class MySQLCaster(ICaster):
23
23
  NoneType: NoneTypeCaster,
24
24
  datetime: DatetimeCaster,
25
25
  bytes: BytesCaster,
26
+ bytearray: BytesCaster,
26
27
  tuple: IterableCaster,
27
28
  list: IterableCaster,
28
29
  }
@@ -6,13 +6,13 @@ class BytesCaster[TType](BaseCaster[bytes, TType]):
6
6
  def __init__(self, value: bytes, type_value: TType):
7
7
  super().__init__(value, type_value)
8
8
 
9
- def wildcard_to_select(self, value:str = PLACEHOLDER) -> str:
9
+ def wildcard_to_select(self, value: str = PLACEHOLDER) -> str:
10
10
  return value
11
11
 
12
- def wildcard_to_where(self, value:str = PLACEHOLDER) -> str:
12
+ def wildcard_to_where(self, value: str = PLACEHOLDER) -> str:
13
13
  return value
14
14
 
15
- def wildcard_to_insert(self, value:str = PLACEHOLDER) -> str:
15
+ def wildcard_to_insert(self, value: str = PLACEHOLDER) -> str:
16
16
  return value
17
17
 
18
18
  @property
@@ -8,13 +8,13 @@ class DatetimeCaster[TType](BaseCaster[datetime, TType]):
8
8
  def __init__(self, value: datetime, type_value: TType):
9
9
  super().__init__(value, type_value)
10
10
 
11
- def wildcard_to_select(self, value:str=PLACEHOLDER) -> str:
11
+ def wildcard_to_select(self, value: str = PLACEHOLDER) -> str:
12
12
  return value
13
13
 
14
- def wildcard_to_where(self, value:str=PLACEHOLDER) -> str:
14
+ def wildcard_to_where(self, value: str = PLACEHOLDER) -> str:
15
15
  return value
16
16
 
17
- def wildcard_to_insert(self, value:str=PLACEHOLDER) -> str:
17
+ def wildcard_to_insert(self, value: str = PLACEHOLDER) -> str:
18
18
  return value
19
19
 
20
20
  @property
@@ -6,13 +6,13 @@ class FloatCaster[TType](BaseCaster[float, TType]):
6
6
  def __init__(self, value: float, type_value: TType):
7
7
  super().__init__(value, type_value)
8
8
 
9
- def wildcard_to_select(self, value:str=PLACEHOLDER) -> str:
9
+ def wildcard_to_select(self, value: str = PLACEHOLDER) -> str:
10
10
  return value
11
11
 
12
- def wildcard_to_where(self, value:str=PLACEHOLDER) -> str:
12
+ def wildcard_to_where(self, value: str = PLACEHOLDER) -> str:
13
13
  return value
14
14
 
15
- def wildcard_to_insert(self, value:str=PLACEHOLDER) -> str:
15
+ def wildcard_to_insert(self, value: str = PLACEHOLDER) -> str:
16
16
  return value
17
17
 
18
18
  @property
@@ -6,13 +6,13 @@ class IntegerCaster[TType](BaseCaster[int, TType]):
6
6
  def __init__(self, value: int, type_value: TType):
7
7
  super().__init__(value, type_value)
8
8
 
9
- def wildcard_to_select(self, value:str=PLACEHOLDER) -> str:
9
+ def wildcard_to_select(self, value: str = PLACEHOLDER) -> str:
10
10
  return value
11
11
 
12
- def wildcard_to_where(self, value:str=PLACEHOLDER) -> str:
12
+ def wildcard_to_where(self, value: str = PLACEHOLDER) -> str:
13
13
  return value
14
14
 
15
- def wildcard_to_insert(self, value:str=PLACEHOLDER) -> str:
15
+ def wildcard_to_insert(self, value: str = PLACEHOLDER) -> str:
16
16
  return value
17
17
 
18
18
  @property
@@ -6,13 +6,13 @@ class IterableCaster[TType](BaseCaster[bytes, TType]):
6
6
  def __init__(self, value: bytes, type_value: TType):
7
7
  super().__init__(value, type_value)
8
8
 
9
- def wildcard_to_select(self, value:str=PLACEHOLDER) -> str:
9
+ def wildcard_to_select(self, value: str = PLACEHOLDER) -> str:
10
10
  return value
11
11
 
12
- def wildcard_to_where(self, value:str=PLACEHOLDER) -> str:
12
+ def wildcard_to_where(self, value: str = PLACEHOLDER) -> str:
13
13
  return value
14
14
 
15
- def wildcard_to_insert(self, value:str=PLACEHOLDER) -> str:
15
+ def wildcard_to_insert(self, value: str = PLACEHOLDER) -> str:
16
16
  return value
17
17
 
18
18
  @property
@@ -6,13 +6,13 @@ class NoneTypeCaster[TType](BaseCaster[NoneType, TType]):
6
6
  def __init__(self, value: NoneType, type_value: TType):
7
7
  super().__init__(value, type_value)
8
8
 
9
- def wildcard_to_select(self, value:str=PLACEHOLDER) -> str:
9
+ def wildcard_to_select(self, value: str = PLACEHOLDER) -> str:
10
10
  return value
11
11
 
12
- def wildcard_to_where(self, value:str=PLACEHOLDER) -> str:
12
+ def wildcard_to_where(self, value: str = PLACEHOLDER) -> str:
13
13
  return value
14
14
 
15
- def wildcard_to_insert(self, value:str=PLACEHOLDER) -> str:
15
+ def wildcard_to_insert(self, value: str = PLACEHOLDER) -> str:
16
16
  return value
17
17
 
18
18
  # TODOL: cheched if it's right
@@ -6,13 +6,13 @@ class StringCaster[TType](BaseCaster[str, TType]):
6
6
  def __init__(self, value: str, type_value: TType):
7
7
  super().__init__(value, type_value)
8
8
 
9
- def wildcard_to_select(self, value:str = PLACEHOLDER) -> str:
9
+ def wildcard_to_select(self, value: str = PLACEHOLDER) -> str:
10
10
  return value
11
11
 
12
- def wildcard_to_where(self, value:str = PLACEHOLDER) -> str:
12
+ def wildcard_to_where(self, value: str = PLACEHOLDER) -> str:
13
13
  return value
14
14
 
15
- def wildcard_to_insert(self, value:str = PLACEHOLDER) -> str:
15
+ def wildcard_to_insert(self, value: str = PLACEHOLDER) -> str:
16
16
  return value
17
17
 
18
18
  @property
@@ -12,6 +12,7 @@ from .order import Order as Order
12
12
  from .update import UpdateQuery as UpdateQuery
13
13
  from .upsert import UpsertQuery as UpsertQuery
14
14
  from .where import Where as Where
15
+ from .having import Having as Having
15
16
  from .count import Count as Count
16
17
  from .group_by import GroupBy as GroupBy
17
18
  from .alias import Alias as Alias
@@ -1,10 +1,10 @@
1
1
  from __future__ import annotations
2
2
  import typing as tp
3
3
 
4
- from ormlambda import Table, Column
4
+ from ormlambda import Table
5
5
  from ormlambda.sql.clause_info import ClauseInfo
6
- from ormlambda.common.interfaces.IQueryCommand import IQuery
7
- from ormlambda.sql.clause_info import ClauseInfoContext, IAggregate
6
+ from ormlambda.sql.clause_info.clause_info_context import ClauseContextType
7
+ from ormlambda.sql.types import TableType
8
8
 
9
9
  if tp.TYPE_CHECKING:
10
10
  from ormlambda.sql.types import ColumnType
@@ -12,23 +12,17 @@ if tp.TYPE_CHECKING:
12
12
  from ormlambda.sql.types import AliasType
13
13
 
14
14
 
15
- class Alias[T: Table, TProp](ClauseInfo, IQuery):
16
- def __init__(
15
+ class Alias[T: Table](ClauseInfo[T]):
16
+ def __init__[TProp](
17
17
  self,
18
- element: IAggregate | ClauseInfo[T] | ColumnType[TProp],
19
- alias_clause: tp.Optional[AliasType[ClauseInfo[T]]],
18
+ table: TableType[T],
19
+ column: tp.Optional[ColumnType[TProp]] = None,
20
+ alias_table: tp.Optional[AliasType[ClauseInfo[T]]] = None,
21
+ alias_clause: tp.Optional[AliasType[ClauseInfo[T]]] = None,
22
+ context: ClauseContextType = None,
23
+ keep_asterisk: bool = False,
24
+ preserve_context: bool = False,
20
25
  ):
21
- if isinstance(element, Column):
22
- context = ClauseInfoContext()
23
- element = ClauseInfo(table=element.table, column=element, alias_clause=alias_clause)
24
- else:
25
- context = ClauseInfoContext(table_context=element.context._table_context, clause_context={})
26
- element.context = context
27
- element._alias_clause = alias_clause
28
-
29
- self._element = element
30
-
31
- @tp.override
32
- @property
33
- def query(self) -> str:
34
- return self._element.query
26
+ if not alias_clause:
27
+ raise TypeError
28
+ super().__init__(table, column, alias_table, alias_clause, context, keep_asterisk, preserve_context)
@@ -1,32 +1,31 @@
1
1
  import typing as tp
2
- from ormlambda.common.enums.join_type import JoinType
3
- from ormlambda.sql.clause_info import IAggregate
4
2
  from ormlambda import Table
5
- from ormlambda.common.abstract_classes.decomposition_query import DecompositionQueryBase
3
+ from ormlambda.sql.clause_info.clause_info import AggregateFunctionBase
4
+ from ormlambda.sql.clause_info.clause_info_context import ClauseInfoContext
5
+ from ormlambda.sql.types import ColumnType
6
6
 
7
7
 
8
- class GroupBy[T: tp.Type[Table], *Ts, TProp](DecompositionQueryBase[T], IAggregate):
9
- CLAUSE: str = "GROUP BY"
8
+ class GroupBy[T: tp.Type[Table], *Ts, TProp](AggregateFunctionBase):
9
+ @classmethod
10
+ def FUNCTION_NAME(self) -> str:
11
+ return "GROUP BY"
10
12
 
11
13
  def __init__(
12
14
  self,
13
- table: T,
14
- column: tp.Callable[[T, *Ts], TProp],
15
- *,
16
- alias: bool = True,
17
- alias_name: str | None = None,
18
- by: JoinType = JoinType.INNER_JOIN,
19
- ) -> None:
15
+ column: ColumnType,
16
+ context: ClauseInfoContext,
17
+ **kwargs,
18
+ ):
20
19
  super().__init__(
21
- table,
22
- columns=column,
23
- alias=alias,
24
- alias_name=alias_name,
25
- by=by,
20
+ table=column.table,
21
+ column=column,
22
+ alias_table=None,
23
+ alias_clause=None,
24
+ context=context,
25
+ **kwargs,
26
26
  )
27
27
 
28
28
  @property
29
29
  def query(self) -> str:
30
- col: str = ", ".join([x.query for x in self.all_clauses])
31
-
32
- return f"{self.CLAUSE} {col}"
30
+ column = self._create_query()
31
+ return f"{self.FUNCTION_NAME()} {column}"
@@ -0,0 +1,16 @@
1
+ from __future__ import annotations
2
+
3
+ from .where import Where
4
+
5
+
6
+ class Having(Where):
7
+ """
8
+ The purpose of this class is to create 'WHERE' condition queries properly.
9
+ """
10
+
11
+ def __init__(self, *comparer, restrictive=True, context=None):
12
+ super().__init__(*comparer, restrictive=restrictive, context=context)
13
+
14
+ @staticmethod
15
+ def FUNCTION_NAME() -> str:
16
+ return "HAVING"
@@ -48,13 +48,18 @@ class Order(AggregateFunctionBase):
48
48
  columns = self.unresolved_column
49
49
 
50
50
  # if this attr is not iterable means that we only pass one column without wrapped in a list or tuple
51
+ if isinstance(columns, str):
52
+ string_columns = f"{columns} {str(self._order_type[0])}"
53
+ return f"{self.FUNCTION_NAME()} {string_columns}"
54
+
51
55
  if not isinstance(columns, tp.Iterable):
52
56
  columns = (columns,)
57
+
53
58
  assert len(columns) == len(self._order_type)
54
59
 
55
60
  context = ClauseInfoContext(table_context=self._context._table_context, clause_context=None) if self._context else None
56
61
  for index, clause in enumerate(self._convert_into_clauseInfo(columns, context)):
57
62
  clause.alias_clause = None
58
- string_columns.append(f"{clause.query} {self._order_type[index].value}")
63
+ string_columns.append(f"{clause.query} {str(self._order_type[index])}")
59
64
 
60
65
  return f"{self.FUNCTION_NAME()} {', '.join(string_columns)}"
@@ -56,7 +56,7 @@ class UpdateQuery[T: Type[Table]](UpdateQueryBase[T, IRepositoryBase]):
56
56
 
57
57
  if self.__is_valid__(col):
58
58
  clean_data = CASTER.for_value(value)
59
- col_names.append((col.column_name,clean_data.wildcard_to_insert()))
59
+ col_names.append((col.column_name, clean_data.wildcard_to_insert()))
60
60
  self._values.append(clean_data.to_database)
61
61
 
62
62
  set_query: str = ",".join(["=".join(col_data) for col_data in col_names])
@@ -32,8 +32,8 @@ class Where(AggregateFunctionBase):
32
32
  def alias_clause(self) -> None:
33
33
  return None
34
34
 
35
- @staticmethod
36
- def join_condition(wheres: tp.Iterable[Where], restrictive: bool, context: ClauseInfoContext) -> str:
35
+ @classmethod
36
+ def join_condition(cls, wheres: tp.Iterable[Where], restrictive: bool, context: ClauseInfoContext) -> str:
37
37
  if not isinstance(wheres, tp.Iterable):
38
38
  wheres = (wheres,)
39
39
 
@@ -42,4 +42,4 @@ class Where(AggregateFunctionBase):
42
42
  for c in where._comparer:
43
43
  c.set_context(context)
44
44
  comparers.append(c)
45
- return Where(*comparers, restrictive=restrictive, context=context).query
45
+ return cls(*comparers, restrictive=restrictive, context=context).query
@@ -1,20 +1,22 @@
1
+ import typing as tp
2
+
1
3
  from ormlambda.sql.clause_info import AggregateFunctionBase
2
4
  from ormlambda.sql.clause_info.clause_info_context import ClauseInfoContext, ClauseContextType
3
-
4
-
5
- import typing as tp
6
5
  from ormlambda.sql.types import ColumnType, AliasType
7
6
  from ormlambda.sql.clause_info import ClauseInfo
8
7
 
9
8
 
10
- class Concat[*Ts](AggregateFunctionBase):
9
+ type ConcatResponse[TProp] = tuple[str | ColumnType[TProp]]
10
+
11
+
12
+ class Concat[T](AggregateFunctionBase):
11
13
  @staticmethod
12
14
  def FUNCTION_NAME() -> str:
13
15
  return "CONCAT"
14
16
 
15
17
  def __init__[TProp](
16
18
  self,
17
- values: ColumnType[Ts] | tuple[ColumnType[Ts], ...],
19
+ values: ConcatResponse[TProp],
18
20
  alias_clause: AliasType[ColumnType[TProp]] = "concat",
19
21
  context: ClauseContextType = None,
20
22
  ) -> None:
@@ -33,6 +35,6 @@ class Concat[*Ts](AggregateFunctionBase):
33
35
  context = ClauseInfoContext(table_context=self._context._table_context, clause_context=None) if self._context else None
34
36
 
35
37
  for clause in self._convert_into_clauseInfo(self.unresolved_column, context=context):
36
- clause.alias_clause = None
38
+ clause.alias_clause = self.alias_clause
37
39
  columns.append(clause)
38
40
  return self._concat_alias_and_column(f"{self.FUNCTION_NAME()}({ClauseInfo.join_clauses(columns)})", self._alias_aggregate)
@@ -27,12 +27,12 @@ class JoinContext[TParent: Table, TRepo]:
27
27
  for comparer, by in self._joins:
28
28
  fk_clause, alias = self.get_fk_clause(comparer)
29
29
 
30
- foreign_key: ForeignKey = ForeignKey(comparer=comparer, clause_name=alias)
30
+ foreign_key: ForeignKey = ForeignKey(comparer=comparer, clause_name=alias, keep_alive=True)
31
31
  fk_clause.alias_table = foreign_key.alias
32
32
  self._context.add_clause_to_context(fk_clause)
33
33
  setattr(self._parent, alias, foreign_key)
34
34
 
35
- # TODOH []: We need to preserve the 'foreign_key' variable while inside the 'with' clause.
35
+ # TODOH [x]: We need to preserve the 'foreign_key' variable while inside the 'with' clause.
36
36
  # Keep in mind that 'ForeignKey.stored_calls' is cleared every time we call methods like
37
37
  # .select(), .select_one(), .insert(), .update(), or .count(). This means we only retain
38
38
  # the context from the first call of any of these methods.
@@ -49,6 +49,7 @@ class JoinContext[TParent: Table, TRepo]:
49
49
  fk: ForeignKey = getattr(self._parent, attribute)
50
50
  delattr(self._parent, attribute)
51
51
  del self._context._table_context[fk.tright]
52
+ ForeignKey.stored_calls.remove(fk)
52
53
  return None
53
54
 
54
55
  def __getattr__(self, name: str) -> TParent:
@@ -72,4 +73,3 @@ class JoinContext[TParent: Table, TRepo]:
72
73
  parent_table = conditions.difference(model).pop()
73
74
 
74
75
  return clause_dicc[parent_table], clause_dicc[parent_table].table.__name__
75
-
@@ -1,7 +1,8 @@
1
1
  from __future__ import annotations
2
2
  import contextlib
3
3
  from pathlib import Path
4
- from typing import Any, Generator, Iterable, Optional, Type, override, TYPE_CHECKING
4
+ from typing import Any, Generator, Iterable, Optional, Type, override, TYPE_CHECKING, Unpack
5
+ import uuid
5
6
  import shapely as shp
6
7
 
7
8
  # from mysql.connector.pooling import MySQLConnectionPool
@@ -16,12 +17,14 @@ from ormlambda.caster import Caster
16
17
  from ..clauses import CreateDatabase, TypeExists
17
18
  from ..clauses import DropDatabase
18
19
  from ..clauses import DropTable
20
+ from ..clauses import Alias
19
21
 
20
22
 
21
23
  if TYPE_CHECKING:
22
24
  from ormlambda.common.abstract_classes.decomposition_query import ClauseInfo
23
25
  from ormlambda import Table
24
26
  from ormlambda.databases.my_sql.clauses.select import Select
27
+ from ..types import MySQLArgs
25
28
 
26
29
  type TResponse[TFlavour, *Ts] = TFlavour | tuple[dict[str, tuple[*Ts]]] | tuple[tuple[*Ts]] | tuple[TFlavour]
27
30
 
@@ -67,12 +70,15 @@ class Response[TFlavour, *Ts]:
67
70
 
68
71
  def _cast_to_flavour(self, data: list[tuple[*Ts]], **kwargs) -> list[dict[str, tuple[*Ts]]] | list[tuple[*Ts]] | list[TFlavour]:
69
72
  def _dict(**kwargs) -> list[dict[str, tuple[*Ts]]]:
73
+ nonlocal data
70
74
  return [dict(zip(self._columns, x)) for x in data]
71
75
 
72
76
  def _tuple(**kwargs) -> list[tuple[*Ts]]:
77
+ nonlocal data
73
78
  return data
74
79
 
75
80
  def _set(**kwargs) -> list[set]:
81
+ nonlocal data
76
82
  for d in data:
77
83
  n = len(d)
78
84
  for i in range(n):
@@ -83,12 +89,19 @@ class Response[TFlavour, *Ts]:
83
89
  return [set(x) for x in data]
84
90
 
85
91
  def _list(**kwargs) -> list[list]:
92
+ nonlocal data
86
93
  return [list(x) for x in data]
87
94
 
88
95
  def _default(**kwargs) -> list[TFlavour]:
89
- replacer_dicc: dict[str, str] = {x.alias_clause: x.column for x in self._select.all_clauses}
96
+ nonlocal data
97
+ replacer_dicc: dict[str, str] = {}
90
98
 
91
- cleaned_column_names = [replacer_dicc[col] for col in self._columns]
99
+ for col in self._select.all_clauses:
100
+ if hasattr(col, "_alias_aggregate") or col.alias_clause is None or isinstance(col, Alias):
101
+ continue
102
+ replacer_dicc[col.alias_clause] = col.column
103
+
104
+ cleaned_column_names = [replacer_dicc.get(col, col) for col in self._columns]
92
105
 
93
106
  result = []
94
107
  for attr in data:
@@ -155,8 +168,35 @@ class MySQLRepository(BaseRepository[MySQLConnectionPool]):
155
168
 
156
169
  #
157
170
 
158
- def __init__(self, **kwargs):
159
- super().__init__(MySQLConnectionPool, **kwargs)
171
+ def __init__(self, **kwargs: Unpack[MySQLArgs]):
172
+ timeout = self.__add_connection_timeout(kwargs)
173
+ name = self.__add_pool_name(kwargs)
174
+ size = self.__add_pool_size(kwargs)
175
+ attr = kwargs.copy()
176
+ attr["connection_timeout"] = timeout
177
+ attr["pool_name"] = name
178
+ attr["pool_size"] = size
179
+
180
+ super().__init__(MySQLConnectionPool, **attr)
181
+
182
+ @staticmethod
183
+ def __add_connection_timeout(kwargs: MySQLArgs) -> int:
184
+ if "connection_timeout" not in kwargs.keys():
185
+ return 60
186
+ return int(kwargs.pop("connection_timeout"))
187
+
188
+ @staticmethod
189
+ def __add_pool_name(kwargs: MySQLArgs) -> str:
190
+ if "pool_name" not in kwargs.keys():
191
+ return str(uuid.uuid4())
192
+
193
+ return kwargs.pop("pool_name")
194
+
195
+ @staticmethod
196
+ def __add_pool_size(kwargs: MySQLArgs) -> int:
197
+ if "pool_size" not in kwargs.keys():
198
+ return 5
199
+ return int(kwargs.pop("pool_size"))
160
200
 
161
201
  @contextlib.contextmanager
162
202
  def get_connection(self) -> Generator[MySQLConnection, None, None]:
@@ -253,7 +293,6 @@ class MySQLRepository(BaseRepository[MySQLConnectionPool]):
253
293
 
254
294
  @override
255
295
  def database_exists(self, name: str) -> bool:
256
- query = "SHOW DATABASES LIKE %s;"
257
296
  temp_config = self._pool._cnx_config
258
297
 
259
298
  config_without_db = temp_config.copy()
@@ -261,10 +300,11 @@ class MySQLRepository(BaseRepository[MySQLConnectionPool]):
261
300
  if "database" in config_without_db:
262
301
  config_without_db.pop("database")
263
302
  self._pool.set_config(**config_without_db)
303
+
264
304
  with self.get_connection() as cnx:
265
305
  with cnx.cursor(buffered=True) as cursor:
266
- cursor.execute(query, (name,))
267
- res = cursor.fetchmany(1)
306
+ cursor.execute("SHOW DATABASES LIKE %s;", (name,))
307
+ res = cursor.fetchmany(1)
268
308
 
269
309
  self._pool.set_config(**temp_config)
270
310
  return len(res) > 0
@@ -275,12 +315,11 @@ class MySQLRepository(BaseRepository[MySQLConnectionPool]):
275
315
 
276
316
  @override
277
317
  def table_exists(self, name: str) -> bool:
278
- query = "SHOW TABLES LIKE %s;"
279
318
  with self.get_connection() as cnx:
280
319
  if not cnx.database:
281
320
  raise Exception("No database selected")
282
321
  with cnx.cursor(buffered=True) as cursor:
283
- cursor.execute(query, (name,))
322
+ cursor.execute("SHOW TABLES LIKE %s;", (name,))
284
323
  res = cursor.fetchmany(1)
285
324
  return len(res) > 0
286
325
 
@@ -296,9 +335,17 @@ class MySQLRepository(BaseRepository[MySQLConnectionPool]):
296
335
 
297
336
  @property
298
337
  def database(self) -> Optional[str]:
299
- return self._data_config.get("database", None)
338
+ return self._pool._cnx_config.get("database", None)
300
339
 
301
340
  @database.setter
302
341
  def database(self, value: str) -> None:
303
- self._data_config["database"] = value
304
- self._pool.set_config(**self._data_config)
342
+ """Change the current database using USE statement"""
343
+
344
+ if not self.database_exists(value):
345
+ raise ValueError(f"You cannot set the non-existent '{value}' database.")
346
+
347
+ old_config: MySQLArgs = self._pool._cnx_config.copy()
348
+ old_config["database"] = value
349
+
350
+ self._pool._remove_connections()
351
+ self._pool = type(self)(**old_config)._pool