ormlambda 2.8.0__py3-none-any.whl → 2.9.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,8 @@
1
1
  from __future__ import annotations
2
2
  from pathlib import Path
3
- from typing import Any, Optional, Type, override, Callable
3
+ from typing import Any, Optional, Type, override, Callable, TYPE_CHECKING
4
4
  import functools
5
+ import shapely as shp
5
6
 
6
7
  # from mysql.connector.pooling import MySQLConnectionPool
7
8
  from mysql.connector import MySQLConnection, Error # noqa: F401
@@ -16,12 +17,21 @@ from .clauses import DropDatabase
16
17
  from .clauses import DropTable
17
18
 
18
19
 
20
+ if TYPE_CHECKING:
21
+ from src.ormlambda.common.abstract_classes.decomposition_query import ClauseInfo
22
+ from ormlambda import Table
23
+ from src.ormlambda.databases.my_sql.clauses.select import Select
24
+
25
+ type TResponse[TFlavour, *Ts] = TFlavour | tuple[dict[str, tuple[*Ts]]] | tuple[tuple[*Ts]] | tuple[TFlavour]
26
+
27
+
19
28
  class Response[TFlavour, *Ts]:
20
- def __init__(self, response_values: list[tuple[*Ts]], columns: tuple[str], flavour: Type[TFlavour], **kwargs) -> None:
29
+ def __init__(self, response_values: list[tuple[*Ts]], columns: tuple[str], flavour: Type[TFlavour], model: Optional[Table] = None, select: Optional[Select] = None) -> None:
21
30
  self._response_values: list[tuple[*Ts]] = response_values
22
31
  self._columns: tuple[str] = columns
23
32
  self._flavour: Type[TFlavour] = flavour
24
- self._kwargs: dict[str, Any] = kwargs
33
+ self._model: Table = model
34
+ self._select: Select = select
25
35
 
26
36
  self._response_values_index: int = len(self._response_values)
27
37
  # self.select_values()
@@ -38,21 +48,28 @@ class Response[TFlavour, *Ts]:
38
48
  def is_many(self) -> bool:
39
49
  return self._response_values_index > 1
40
50
 
41
- @property
42
- def response(self) -> tuple[dict[str, tuple[*Ts]]] | tuple[tuple[*Ts]] | tuple[TFlavour]:
51
+ def response(self, _tuple: bool, **kwargs) -> TResponse[TFlavour, *Ts]:
43
52
  if not self.is_there_response:
44
53
  return tuple([])
54
+ cleaned_response = self._response_values
55
+
56
+ if self._select is not None:
57
+ cleaned_response = self._parser_response()
45
58
 
46
- return tuple(self._cast_to_flavour(self._response_values))
59
+ cast_flavour = self._cast_to_flavour(cleaned_response, **kwargs)
60
+ if _tuple is not True:
61
+ return cast_flavour
47
62
 
48
- def _cast_to_flavour(self, data: list[tuple[*Ts]]) -> list[dict[str, tuple[*Ts]]] | list[tuple[*Ts]] | list[TFlavour]:
49
- def _dict() -> list[dict[str, tuple[*Ts]]]:
63
+ return tuple(cast_flavour)
64
+
65
+ def _cast_to_flavour(self, data: list[tuple[*Ts]], **kwargs) -> list[dict[str, tuple[*Ts]]] | list[tuple[*Ts]] | list[TFlavour]:
66
+ def _dict(**kwargs) -> list[dict[str, tuple[*Ts]]]:
50
67
  return [dict(zip(self._columns, x)) for x in data]
51
68
 
52
- def _tuple() -> list[tuple[*Ts]]:
69
+ def _tuple(**kwargs) -> list[tuple[*Ts]]:
53
70
  return data
54
71
 
55
- def _set() -> list[set]:
72
+ def _set(**kwargs) -> list[set]:
56
73
  for d in data:
57
74
  n = len(d)
58
75
  for i in range(n):
@@ -62,16 +79,52 @@ class Response[TFlavour, *Ts]:
62
79
  raise TypeError(f"unhashable type '{type(d[i])}' found in '{type(d)}' when attempting to cast the result into a '{set.__name__}' object")
63
80
  return [set(x) for x in data]
64
81
 
65
- def _default() -> list[TFlavour]:
66
- return [self._flavour(x, **self._kwargs) for x in data]
82
+ def _list(**kwargs) -> list[list]:
83
+ return [list(x) for x in data]
84
+
85
+ def _default(**kwargs) -> list[TFlavour]:
86
+ return self._flavour(data, **kwargs)
67
87
 
68
88
  selector: dict[Type[object], Any] = {
69
89
  dict: _dict,
70
90
  tuple: _tuple,
71
91
  set: _set,
92
+ list: _list,
72
93
  }
73
94
 
74
- return selector.get(self._flavour, _default)()
95
+ return selector.get(self._flavour, _default)(**kwargs)
96
+
97
def _parser_response(self) -> list[tuple]:
    """Return the fetched rows with every column that needs decoding parsed.

    Each column's ``ClauseInfo`` is looked up in ``self._select`` by its alias.
    Columns for which ``_is_parser_required`` is true are converted with
    ``parser_data``. All other values are passed through unchanged.

    Every column is considered independently. A column that needs no parsing
    must not stop the parsing of the columns after it.
    """
    if not self._response_values:
        return []

    # The ClauseInfo depends only on the column, so resolve it once rather than per row.
    clause_infos = [self._select[alias] for alias in self._columns]
    needs_parsing = [self._is_parser_required(info) for info in clause_infos]

    # Fast path: no column needs decoding, so keep the raw rows.
    if not any(needs_parsing):
        return list(self._response_values)

    new_response: list[tuple] = []
    for row in self._response_values:
        new_row = tuple(
            self.parser_data(info, data) if required else data
            for info, required, data in zip(clause_infos, needs_parsing, row)
        )
        new_response.append(new_row)
    return new_response
115
+
116
+ @staticmethod
117
+ def _is_parser_required[T: Table](clause_info: ClauseInfo[T]) -> bool:
118
+ if clause_info is None:
119
+ return False
120
+
121
+ return clause_info.dtype is shp.Point
122
+
123
+ @staticmethod
124
+ def parser_data[T: Table, TProp](clause_info: ClauseInfo[T], data: TProp):
125
+ if clause_info.dtype is shp.Point:
126
+ return shp.from_wkt(data)
127
+ return data
75
128
 
76
129
 
77
130
  class MySQLRepository(IRepositoryBase[MySQLConnection]):
@@ -80,8 +133,7 @@ class MySQLRepository(IRepositoryBase[MySQLConnection]):
80
133
  def wrapper(self: MySQLRepository, *args, **kwargs):
81
134
  with self._pool.get_connection() as cnx:
82
135
  try:
83
- foo = func(self, cnx._cnx, *args, **kwargs)
84
- return foo
136
+ return func(self, cnx._cnx, *args, **kwargs)
85
137
  except Exception as e:
86
138
  cnx._cnx.rollback()
87
139
  raise e
@@ -97,7 +149,13 @@ class MySQLRepository(IRepositoryBase[MySQLConnection]):
97
149
 
98
150
  @override
99
151
  @get_connection
100
- def read_sql[TFlavour](self, cnx: MySQLConnection, query: str, flavour: Type[TFlavour] = tuple, **kwargs) -> tuple[TFlavour]:
152
+ def read_sql[TFlavour](
153
+ self,
154
+ cnx: MySQLConnection,
155
+ query: str,
156
+ flavour: Type[TFlavour] = tuple,
157
+ **kwargs,
158
+ ) -> tuple[TFlavour]:
101
159
  """
102
160
  Return tuple of tuples by default.
103
161
 
@@ -107,11 +165,15 @@ class MySQLRepository(IRepositoryBase[MySQLConnection]):
107
165
  - flavour: Type[TFlavour]: Useful to return tuple of any Iterable type as dict,set,list...
108
166
  """
109
167
 
168
+ model: Table = kwargs.pop("model", None)
169
+ select: Select = kwargs.pop("select", None)
170
+ cast_to_tuple: bool = kwargs.pop("cast_to_tuple", True)
171
+
110
172
  with cnx.cursor(buffered=True) as cursor:
111
173
  cursor.execute(query)
112
174
  values: list[tuple] = cursor.fetchall()
113
175
  columns: tuple[str] = cursor.column_names
114
- return Response[TFlavour](response_values=values, columns=columns, flavour=flavour, **kwargs).response
176
+ return Response[TFlavour](model=model, response_values=values, columns=columns, flavour=flavour, select=select).response(_tuple=cast_to_tuple, **kwargs)
115
177
 
116
178
  # FIXME [ ]: this method does not comply with the implemented interface
117
179
  @get_connection
@@ -192,7 +254,6 @@ class MySQLRepository(IRepositoryBase[MySQLConnection]):
192
254
  def create_database(self, name: str, if_exists: TypeExists = "fail") -> None:
193
255
  return CreateDatabase(self).execute(name, if_exists)
194
256
 
195
-
196
257
  @property
197
258
  def database(self) -> Optional[str]:
198
259
  return self._data_config.get("database", None)
@@ -2,17 +2,20 @@ from __future__ import annotations
2
2
  from typing import Iterable, override, Type, TYPE_CHECKING, Any, Callable, Optional
3
3
  import inspect
4
4
  from mysql.connector import MySQLConnection, errors, errorcode
5
+ import functools
5
6
 
6
7
 
7
8
  if TYPE_CHECKING:
8
9
  from ormlambda import Table
9
10
  from ormlambda.components.where.abstract_where import AbstractWhere
10
- from ormlambda.common.interfaces.IStatements import OrderType
11
+ from ormlambda.common.interfaces.IStatements import OrderTypes
11
12
  from ormlambda.common.interfaces import IQuery, IRepositoryBase, IStatements_two_generic
12
13
  from ormlambda.common.interfaces.IRepositoryBase import TypeExists
13
14
  from ormlambda.common.interfaces import IAggregate
14
15
  from ormlambda.common.interfaces.IStatements import WhereTypes
15
16
 
17
+ from ormlambda.utils.foreign_key import ForeignKey
18
+
16
19
  from ormlambda import AbstractSQLStatements
17
20
  from .clauses import DeleteQuery
18
21
  from .clauses import InsertQuery
@@ -29,13 +32,25 @@ from .clauses import Count
29
32
  from .clauses import GroupBy
30
33
 
31
34
 
32
- from ormlambda.utils import ForeignKey, Table
35
+ from ormlambda.utils import Table
33
36
  from ormlambda.common.enums import JoinType
34
37
  from . import functions as func
35
38
 
36
39
 
37
- class MySQLStatements[T: Table](AbstractSQLStatements[T, MySQLConnection]):
38
- def __init__(self, model: T, repository: IRepositoryBase[MySQLConnection]) -> None:
40
# NOTE: resetting ``_query_list`` after each terminal operation is essential; otherwise
# clauses registered by an earlier call (e.g. from another test) would leak into the next query.
def clear_list(f: Callable[..., Any]):
    """Decorate a MySQLStatements method so ``self._query_list`` is cleared afterwards.

    The list is cleared in a ``finally`` clause, so it is reset even when the
    wrapped method raises. The exception still propagates.
    """

    @functools.wraps(f)
    def wrapper(self: MySQLStatements, *args, **kwargs):
        try:
            result = f(self, *args, **kwargs)
        finally:
            self._query_list.clear()
        return result

    return wrapper
50
+
51
+
52
+ class MySQLStatements[T: Table, *Ts](AbstractSQLStatements[T, *Ts, MySQLConnection]):
53
+ def __init__(self, model: tuple[T, *Ts], repository: IRepositoryBase[MySQLConnection]) -> None:
39
54
  super().__init__(model, repository=repository)
40
55
 
41
56
  @property
@@ -71,11 +86,11 @@ class MySQLStatements[T: Table](AbstractSQLStatements[T, MySQLConnection]):
71
86
  return self._repository.table_exists(self._model.__table_name__)
72
87
 
@override
@clear_list
def insert(self, instances: T | list[T]) -> None:
    """Insert one model instance or a list of them through ``InsertQuery``.

    ``clear_list`` resets the pending clauses once the insert has run (or failed).
    """
    query = InsertQuery(self._model, self._repository)
    query.insert(instances)
    query.execute()
80
95
 
81
96
  @override
@@ -95,36 +110,33 @@ class MySQLStatements[T: Table](AbstractSQLStatements[T, MySQLConnection]):
95
110
  return None
96
111
 
@override
@clear_list
def upsert(self, instances: T | list[T]) -> None:
    """Upsert one model instance or a list of them through ``UpsertQuery``.

    ``clear_list`` resets the pending clauses once the upsert has run (or failed).
    """
    query = UpsertQuery(self._model, self._repository)
    query.upsert(instances)
    query.execute()
104
119
 
@override
@clear_list
def update(self, dicc: dict[str, Any] | list[dict[str, Any]]) -> None:
    """Run an UPDATE with the values in ``dicc``, using the pending ``where`` clauses.

    ``clear_list`` resets the pending clauses once the update has run (or failed).
    """
    where_clauses = self._query_list["where"]
    query = UpdateQuery(self._model, self._repository, where_clauses)
    query.update(dicc)
    query.execute()
112
128
 
@override
def limit(self, number: int) -> IStatements_two_generic[T, MySQLConnection]:
    """Set the LIMIT clause and return ``self`` for chaining.

    SQL accepts a single LIMIT, so a later call replaces any earlier one.
    """
    self._query_list["limit"] = [LimitQuery(number)]
    return self
123
135
 
@override
def offset(self, number: int) -> IStatements_two_generic[T, MySQLConnection]:
    """Set the OFFSET clause and return ``self``; a later call replaces any earlier one."""
    self._query_list["offset"] = [OffsetQuery(number)]
    return self
129
141
 
130
142
  @override
@@ -132,24 +144,17 @@ class MySQLStatements[T: Table](AbstractSQLStatements[T, MySQLConnection]):
132
144
  self,
133
145
  selection: Callable[[T], tuple] = lambda x: "*",
134
146
  alias=True,
135
- alias_name=None,
147
+ alias_name="count",
136
148
  ) -> IQuery:
137
149
  return Count[T](self._model, selection, alias=alias, alias_name=alias_name)
138
150
 
139
- @override
140
- def join(self, table_left: Table, table_right: Table, *, by: str) -> IStatements_two_generic[T, MySQLConnection]:
141
- where = ForeignKey.MAPPED[table_left.__table_name__][table_right.__table_name__]
142
- join_query = JoinSelector[table_left, Table](table_left, table_right, JoinType(by), where=where)
143
- self._query_list["join"].append(join_query)
144
- return self
145
-
146
151
  @override
147
152
  def where(self, conditions: WhereTypes = lambda: None, **kwargs) -> IStatements_two_generic[T, MySQLConnection]:
148
153
  # FIXME [x]: I've wrapped self._model into tuple to pass it instance attr. Idk if it's correct
149
154
 
150
155
  if isinstance(conditions, Iterable):
151
156
  for x in conditions:
152
- self._query_list["where"].append(WhereCondition[T](function=x, instances=(self._model,), **kwargs))
157
+ self._query_list["where"].append(WhereCondition[T](function=x, instances=self._models, **kwargs))
153
158
  return self
154
159
 
155
160
  where_query = WhereCondition[T](function=conditions, instances=(self._model,), **kwargs)
@@ -157,7 +162,7 @@ class MySQLStatements[T: Table](AbstractSQLStatements[T, MySQLConnection]):
157
162
  return self
158
163
 
159
164
  @override
160
- def order[TValue](self, _lambda_col: Callable[[T], TValue], order_type: OrderType) -> IStatements_two_generic[T, MySQLConnection]:
165
+ def order[TValue](self, _lambda_col: Callable[[T], TValue], order_type: OrderTypes) -> IStatements_two_generic[T, MySQLConnection]:
161
166
  order = OrderQuery[T](self._model, _lambda_col, order_type)
162
167
  self._query_list["order"].append(order)
163
168
  return self
@@ -179,7 +184,19 @@ class MySQLStatements[T: Table](AbstractSQLStatements[T, MySQLConnection]):
179
184
  return func.Sum[T](self._model, column=column, alias=alias, alias_name=alias_name)
180
185
 
181
186
  @override
182
- def select[TValue, TFlavour, *Ts](self, selector: Optional[Callable[[T], tuple[TValue, *Ts]]] = lambda: None, *, flavour: Optional[Type[TFlavour]] = None, by: JoinType = JoinType.INNER_JOIN):
187
+ def join[*FKTables](self, joins) -> IStatements_two_generic[T, *FKTables, MySQLConnection]:
188
+ if not isinstance(joins[0], tuple):
189
+ joins = (joins,)
190
+ new_tables: list[Type[Table]] = [self._model]
191
+ for table, where in joins:
192
+ new_tables.append(table)
193
+ join_query = JoinSelector[T, type(table)](self._model, table, by=JoinType.INNER_JOIN, where=where)
194
+ self._query_list["join"].append(join_query)
195
+ self._models = new_tables
196
+ return self
197
+
198
+ @override
199
+ def select[TValue, TFlavour, *Ts](self, selector: Optional[Callable[[T, *Ts], tuple[TValue, *Ts]]] = lambda: None, *, flavour: Optional[Type[TFlavour]] = None, by: JoinType = JoinType.INNER_JOIN, **kwargs):
183
200
  if len(inspect.signature(selector).parameters) == 0:
184
201
  # COMMENT: if we do not specify any lambda function we assumed the user want to retreive only elements of the Model itself avoiding other models
185
202
  result = self.select(selector=lambda x: (x,), flavour=flavour, by=by)
@@ -188,19 +205,28 @@ class MySQLStatements[T: Table](AbstractSQLStatements[T, MySQLConnection]):
188
205
  if flavour:
189
206
  return result
190
207
  return () if not result else result[0]
191
- select = Select[T](self._model, lambda_query=selector, by=by, alias=False)
208
+
209
+ joins = self._query_list.pop("join", None)
210
+ select = Select[T, *Ts](
211
+ self._models,
212
+ lambda_query=selector,
213
+ by=by,
214
+ alias=False,
215
+ joins=joins,
216
+ )
192
217
  self._query_list["select"].append(select)
193
218
 
194
- query: str = self._build()
219
+ self._query: str = self._build()
220
+
195
221
  if flavour:
196
- result = self._return_flavour(query, flavour)
197
- if issubclass(flavour, tuple) and isinstance(selector(self._model), property):
222
+ result = self._return_flavour(self.query, flavour, select, **kwargs)
223
+ if issubclass(flavour, tuple) and isinstance(selector(*self._models), property):
198
224
  return tuple([x[0] for x in result])
199
225
  return result
200
- return self._return_model(select, query)
226
+ return self._return_model(select, self.query)
201
227
 
202
228
  @override
203
- def select_one[TValue, TFlavour, *Ts](self, selector: Optional[Callable[[T], tuple[TValue, *Ts]]] = lambda: None, *, flavour: Optional[Type[TFlavour]] = None, by: JoinType = JoinType.INNER_JOIN):
229
+ def select_one[TValue, TFlavour, *Ts](self, selector: Optional[Callable[[T, *Ts], tuple[TValue, *Ts]]] = lambda: None, *, flavour: Optional[Type[TFlavour]] = None, by: JoinType = JoinType.INNER_JOIN):
204
230
  self.limit(1)
205
231
  if len(inspect.signature(selector).parameters) == 0:
206
232
  response = self.select(selector=lambda x: (x,), flavour=flavour, by=by)
@@ -217,48 +243,42 @@ class MySQLStatements[T: Table](AbstractSQLStatements[T, MySQLConnection]):
217
243
  return tuple([res[0] for res in response])
218
244
 
219
245
  @override
220
- def group_by[*Ts](self, column: str | Callable[[T], Any]) -> IStatements_two_generic[T, MySQLConnection]:
246
+ def group_by(self, column: str | Callable[[T, *Ts], Any]):
221
247
  if isinstance(column, str):
222
- groupby = GroupBy[T, tuple[*Ts]](self._model, lambda x: column)
248
+ groupby = GroupBy[T, tuple[*Ts]](self._models, lambda x: column)
223
249
  else:
224
- groupby = GroupBy[T, tuple[*Ts]](self._model, column)
250
+ groupby = GroupBy[T, tuple[*Ts]](self._models, column)
225
251
  # Only can be one LIMIT SQL parameter. We only use the last LimitQuery
226
252
  self._query_list["group by"].append(groupby)
227
253
  return self
228
254
 
@override
@clear_list
def _build(self) -> str:
    """Assemble the SQL text from the pending clauses, in ``self.__order__`` order.

    Each clause list is popped from ``self._query_list`` as it is rendered.
    ``clear_list`` drops whatever remains once the query is built. Rendered clauses
    are joined with newlines.
    """
    query_list: list[str] = []
    for clause_name in self.__order__:
        # Nothing left to render, so no later clause can contribute.
        if not self._query_list:
            break

        sub_query: Optional[list[IQuery]] = self._query_list.pop(clause_name, None)
        # Skip absent clauses and empty lists. Empty lists are possible if
        # _query_list creates lists on lookup (e.g. update() reads ["where"]).
        if not sub_query:
            continue

        first = sub_query[0]
        if isinstance(first, WhereCondition):
            query_ = self.__build_where_clause(sub_query)

        elif isinstance(first, Select):
            # Filters may reference tables whose columns are not selected; the
            # joins they need are added to the Select before it is rendered.
            where_joins = self.__create_necessary_inner_join()
            if where_joins:
                first._joins.update(where_joins)
            query_ = first.query

        else:
            query_ = "\n".join(clause.query for clause in sub_query)

        query_list.append(query_)
    return "\n".join(query_list)
262
282
 
263
283
  def __build_where_clause(self, where_condition: list[AbstractWhere]) -> str:
264
284
  query: str = where_condition[0].query
@@ -269,18 +289,21 @@ class MySQLStatements[T: Table](AbstractSQLStatements[T, MySQLConnection]):
269
289
  query += f" {and_} ({clause})"
270
290
  return query
271
291
 
272
def __create_necessary_inner_join(self) -> Optional[set[JoinSelector]]:
    """Build the INNER JOINs needed by the pending WHERE clauses.

    When filters are applied to tables none of whose columns are selected, the
    query needs those joins added manually to return correct results. The join
    conditions come from the registered foreign-key relationships
    (``ForeignKey.MAPPED``).

    Returns:
        One ``JoinSelector`` per distinct (left, right) table pair involved in *any*
        WHERE clause, or ``None`` when there is no WHERE clause or none involves
        other tables.
    """
    if "where" not in self._query_list:
        return None

    # Collect pairs from every WHERE clause, not only the first one that involves
    # other tables. Keying by table pair avoids emitting the same join twice.
    relationships: dict[tuple[Type[Table], Type[Table]], Any] = {}
    for where in self._query_list["where"]:
        where: AbstractWhere
        for ltable, rtable in where.get_involved_tables() or ():
            if (ltable, rtable) in relationships:
                continue
            # FIXME [ ]: Refactor to avoid copy and paste the same code of the '_add_fk_relationship' method
            relationships[(ltable, rtable)] = ForeignKey.MAPPED[ltable.__table_name__].referenced_tables[rtable.__table_name__].relationship

    if not relationships:
        return None
    return {JoinSelector(ltable, rtable, JoinType.INNER_JOIN, where=relationship) for (ltable, rtable), relationship in relationships.items()}
ormlambda/model_base.py CHANGED
@@ -10,20 +10,20 @@ from .databases.my_sql import MySQLStatements, MySQLRepository
10
10
  # endregion
11
11
 
12
12
 
13
- class BaseModel[T: Type[Table]]:
13
+ class BaseModel[T: Type[Table], *Ts]:
14
14
  """
15
15
  Class to select the correct AbstractSQLStatements class depends on the repository.
16
16
 
17
17
  Contiene los metodos necesarios para hacer consultas a una tabla
18
18
  """
19
19
 
20
- statements_dicc: dict[Type[IRepositoryBase], Type[AbstractSQLStatements[T, IRepositoryBase]]] = {
20
+ statements_dicc: dict[Type[IRepositoryBase], Type[AbstractSQLStatements[T, *Ts, IRepositoryBase]]] = {
21
21
  MySQLRepository: MySQLStatements,
22
22
  }
23
23
 
24
24
  # region Constructor
25
25
 
26
- def __new__[TRepo](cls, model: T, repository: IRepositoryBase[TRepo]) -> IStatements_two_generic[T, TRepo]:
26
+ def __new__[TRepo](cls, model: tuple[T, *Ts], repository: IRepositoryBase[TRepo]) -> IStatements_two_generic[T, *Ts, TRepo]:
27
27
  if repository is None:
28
28
  raise ValueError("`None` cannot be passed to the `repository` attribute when calling the `BaseModel` class")
29
29
  cls: AbstractSQLStatements[T, TRepo] = cls.statements_dicc.get(type(repository), None)
ormlambda/utils/column.py CHANGED
@@ -1,7 +1,14 @@
1
- from typing import Type
1
+ from __future__ import annotations
2
+ from typing import Type, Optional, Callable, TYPE_CHECKING, Any
3
+ import shapely as sph
4
+
5
+ if TYPE_CHECKING:
6
+ from .table_constructor import Field
2
7
 
3
8
 
4
9
  class Column[T]:
10
+ CHAR: str = "%s"
11
+
5
12
  __slots__ = (
6
13
  "dtype",
7
14
  "column_name",
@@ -31,22 +38,54 @@ class Column[T]:
31
38
  self.is_auto_increment: bool = is_auto_increment
32
39
  self.is_unique: bool = is_unique
33
40
 
41
@property
def column_value_to_query(self) -> T:
    """Value to bind in a query for this column.

    ``shapely.Point`` values are serialised to WKT text with ``rounding_precision=-1``
    (full precision). Every other value is returned unchanged.
    """
    if self.dtype is not sph.Point:
        return self.column_value
    return sph.to_wkt(self.column_value, rounding_precision=-1)
49
+
50
@property
def placeholder(self) -> str:
    """SQL placeholder for this column's value, as produced by ``placeholder_resolutor`` for its dtype."""
    resolve = self.placeholder_resolutor
    return resolve(self.dtype)
53
+
54
@property
def placeholder_resolutor(self) -> Callable[[Type], str]:
    """Callable that maps a column dtype to its SQL placeholder string.

    It takes a single argument (the dtype), so it is annotated as ``Callable[[Type], str]``.
    """
    return self.__fetch_wrapped_method
57
+
58
# FIXME [ ]: this method ties the Column class to MySQL-specific SQL functions
@classmethod
def __fetch_wrapped_method(cls, type_: Type) -> str:
    """Return the query placeholder for a value of ``type_``.

    A type that needs converting on the database side gets ``cls.CHAR`` wrapped in
    the matching SQL function: ``shapely.Point`` becomes ``ST_GeomFromText(%s)``,
    which pairs with the WKT text produced by ``column_value_to_query``. Any other
    type gets the bare ``cls.CHAR``. The result is always a string.
    """
    caster: dict[Type[Any], Callable[[str], str]] = {
        sph.Point: lambda x: f"ST_GeomFromText({x})",
    }
    wrap = caster.get(type_)
    return cls.CHAR if wrap is None else wrap(cls.CHAR)
68
+
def __repr__(self) -> str:
    # The column is identified by its dtype.
    return "<Column: {}>".format(self.dtype)
36
71
 
37
def __to_string__(self, field: Field):
    """Render the source text of a ``Column[...](...)`` constructor call for ``field``.

    The output is meant to be embedded in generated code:

    - ``dtype`` and ``column_value`` are emitted as the names ``field.type_name``
      and ``field.name``.
    - ``column_name`` is emitted as a quoted literal.
    - Every other ``__init__`` argument set on this instance is emitted via its
      ``str()`` value.
    """
    renderers: dict[str, Callable[[Field], str]] = {
        "dtype": lambda fld: fld.type_name,
        "column_name": lambda fld: f"'{fld.name}'",
        # must be the same variable name as the instance variable name in Table's __init__ class
        "column_value": lambda fld: fld.name,
    }

    pieces: list[str] = []
    for attr in self.__init__.__annotations__:
        # Unset slots (and the "return" annotation key) are skipped.
        if not hasattr(self, attr):
            continue
        render = renderers.get(attr)
        value = render(field) if render is not None else getattr(self, attr)
        pieces.append(f" {attr}={value}, ")

    return f"{Column.__name__}[{field.type_name}](" + "".join(pieces) + ")"
50
89
 
51
90
  def __hash__(self) -> int:
52
91
  return hash(
ormlambda/utils/dtypes.py CHANGED
@@ -48,8 +48,10 @@ MySQL 8.0 does not support year in two-digit format.
48
48
  """
49
49
 
50
50
  from decimal import Decimal
51
- import datetime
52
51
  from typing import Any, Literal
52
+ import datetime
53
+
54
+ from shapely import Point
53
55
  import numpy as np
54
56
 
55
57
  from .column import Column
@@ -66,17 +68,7 @@ DATE = Literal["DATE", "DATETIME(fsp)", "TIMESTAMP(fsp)", "TIME(fsp)", "YEAR"]
66
68
  def transform_py_dtype_into_query_dtype(dtype: Any) -> str:
67
69
  # TODOL: must be found a better way to convert python data type into SQL clauses
68
70
  # float -> DECIMAL(5,2) is an error
69
- dicc: dict[Any, str] = {
70
- int: "INTEGER",
71
- float: "FLOAT(5,2)",
72
- Decimal: "FLOAT",
73
- datetime.datetime: "DATETIME",
74
- datetime.date: "DATE",
75
- bytes: "BLOB",
76
- bytearray: "BLOB",
77
- str: "VARCHAR(255)",
78
- np.uint64: "BIGINT UNSIGNED",
79
- }
71
+ dicc: dict[Any, str] = {int: "INTEGER", float: "FLOAT(5,2)", Decimal: "FLOAT", datetime.datetime: "DATETIME", datetime.date: "DATE", bytes: "BLOB", bytearray: "BLOB", str: "VARCHAR(255)", np.uint64: "BIGINT UNSIGNED", Point: "Point"}
80
72
 
81
73
  res = dicc.get(dtype, None)
82
74
  if res is None:
@@ -0,0 +1,60 @@
1
+ import typing as tp
2
+ from .column import Column
3
+
4
+ __all__ = ["get_fields"]
5
+
6
def MISSING() -> Column:  # noqa: N802 - keeps the existing public name
    """Return a brand-new, empty ``Column``.

    This is a factory rather than a shared constant on purpose: every class that
    omits a default must get its own Column, so the same object is never reused
    across different classes.
    """
    return Column()
7
+
8
+
9
+ class Field[TProp: tp.AnnotatedAny]:
10
+ def __init__(self, name: str, type_: tp.Type, default: Column[TProp]) -> None:
11
+ self.name: str = name
12
+ self.type_: tp.Type[TProp] = type_
13
+ self.default: Column[TProp] = default
14
+
15
+ def __repr__(self) -> str:
16
+ return f"{Field.__name__}(name = {self.name}, type_ = {self.type_}, default = {self.default})"
17
+
18
+ @property
19
+ def has_default(self) -> bool:
20
+ return self.default is not MISSING()
21
+
22
+ @property
23
+ def init_arg(self) -> str:
24
+ default = f"={self.default_name}" # if self.has_default else ""}"
25
+ return f"{self.name}: {self.type_name}{default}"
26
+
27
+ @property
28
+ def default_name(self) -> str:
29
+ return f"_dflt_{self.name}"
30
+
31
+ @property
32
+ def type_name(self) -> str:
33
+ return f"_type_{self.name}"
34
+
35
+ @property
36
+ def assginment(self) -> str:
37
+ return f"self._{self.name} = {self.default.__to_string__(self)}"
38
+
39
+
40
+ def get_fields[T, TProp](cls: tp.Type[T]) -> tp.Iterable[Field]:
41
+ # COMMENT: Used the 'get_type_hints' method to resolve typing when 'from __future__ import annotations' is in use
42
+ annotations = {key: val for key, val in tp.get_type_hints(cls).items() if not key.startswith("_")}
43
+
44
+ # delete_special_variables(annotations)
45
+ fields = []
46
+ for name, type_ in annotations.items():
47
+ if hasattr(type_, "__origin__") and type_.__origin__ is Column: # __origin__ to get type of Generic value
48
+ field_type = type_.__args__[0]
49
+ else:
50
+ # type_ must by Column object
51
+ field_type: TProp = type_
52
+
53
+ default: Column = getattr(cls, name, MISSING())
54
+
55
+ default.dtype = field_type # COMMENT: Useful for setting the dtype variable after instantiation.
56
+ fields.append(Field[TProp](name, field_type, default))
57
+
58
+ # Update __annotations__ to create Columns
59
+ cls.__annotations__[name] = default
60
+ return fields