sera-2 1.21.2__py3-none-any.whl → 1.24.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sera/libs/api_helper.py CHANGED
@@ -1,116 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
- import re
4
- from typing import Callable, Collection, Generic, Mapping, TypeVar, cast
3
+ from typing import Collection, Generic, TypeVar, cast
5
4
 
6
- from litestar import Request, status_codes
7
5
  from litestar.connection import ASGIConnection
8
6
  from litestar.dto import MsgspecDTO
9
7
  from litestar.dto._backend import DTOBackend
10
8
  from litestar.dto._codegen_backend import DTOCodegenBackend
11
9
  from litestar.enums import RequestEncodingType
12
- from litestar.exceptions import HTTPException
13
10
  from litestar.serialization import decode_json, decode_msgpack
14
11
  from litestar.typing import FieldDefinition
15
12
  from msgspec import Struct
16
13
 
17
- from sera.libs.base_service import Query, QueryOp
18
14
  from sera.libs.middlewares.uscp import SKIP_UPDATE_SYSTEM_CONTROLLED_PROPS_KEY
19
- from sera.typing import T
20
-
21
- # for parsing field names and operations from query string
22
- FIELD_REG = re.compile(r"(?P<name>[a-zA-Z_0-9]+)(?:\[(?P<op>[a-zA-Z_0-9]+)\])?")
23
- QUERY_OPS = {op.value for op in QueryOp}
24
- KEYWORDS = {"field", "limit", "offset", "unique", "sorted_by", "group_by"}
25
-
26
-
27
- class TypeConversion:
28
-
29
- to_int = int
30
- to_float = float
31
-
32
- @staticmethod
33
- def to_bool(value: str) -> bool:
34
- if value == "1":
35
- return True
36
- elif value == "0":
37
- return False
38
- raise ValueError(f"Invalid boolean value: {value}")
39
-
40
-
41
- def parse_query(
42
- request: Request,
43
- fields: Mapping[str, Callable[[str], str | int | bool | float]],
44
- debug: bool,
45
- ) -> Query:
46
- """Parse query for retrieving records that match a query.
47
-
48
- If a field name collides with a keyword, you can add `_` to the field name.
49
-
50
- To filter records, you can apply a condition on a column using <field>=<value> (equal condition). Or you can
51
- be explicit by using <field>[op]=<value>, where op is one of the operators defined in QueryOp.
52
- """
53
- query: Query = {}
54
-
55
- for k, v in request.query_params.items():
56
- if k in KEYWORDS:
57
- continue
58
- m = FIELD_REG.match(k)
59
- if m:
60
- field_name = m.group("name")
61
- operation = m.group("op") # This will be None if no operation is specified
62
-
63
- # If field name ends with '_' and it's to avoid keyword conflict, remove it
64
- if field_name.endswith("_") and field_name[:-1] in KEYWORDS:
65
- field_name = field_name[:-1]
66
-
67
- if field_name not in fields:
68
- # Invalid field name, skip
69
- if debug:
70
- raise HTTPException(
71
- status_code=status_codes.HTTP_400_BAD_REQUEST,
72
- detail=f"Invalid field name: {field_name}",
73
- )
74
- continue
75
-
76
- # Process based on operation or default to equality check
77
- # TODO: validate if the operation is allowed for the field
78
- if not operation:
79
- operation = QueryOp.eq
80
- else:
81
- if operation not in QUERY_OPS:
82
- raise HTTPException(
83
- status_code=status_codes.HTTP_400_BAD_REQUEST,
84
- detail=f"Invalid operation: {operation}",
85
- )
86
- operation = QueryOp(operation)
87
-
88
- try:
89
- norm_func = fields[field_name]
90
- if isinstance(v, list):
91
- v = [norm_func(x) for x in v]
92
- else:
93
- v = norm_func(v)
94
- except (ValueError, KeyError):
95
- if debug:
96
- raise HTTPException(
97
- status_code=status_codes.HTTP_400_BAD_REQUEST,
98
- detail=f"Invalid value for field {field_name}: {v}",
99
- )
100
- continue
101
-
102
- query[field_name] = {operation: v}
103
- else:
104
- # Invalid field name format
105
- if debug:
106
- raise HTTPException(
107
- status_code=status_codes.HTTP_400_BAD_REQUEST,
108
- detail=f"Invalid field name: {k}",
109
- )
110
- continue
111
-
112
- return query
113
-
114
15
 
115
16
  S = TypeVar("S", bound=Struct)
116
17
 
@@ -34,7 +34,7 @@ def test_get_by_id(
34
34
  assert (
35
35
  resp.status_code == 200
36
36
  ), f"Record {record} should exist but got {resp.status_code}"
37
- assert resp.json() == data
37
+ assert resp.json() == data, (resp.json(), data)
38
38
 
39
39
  for record in non_exist_records:
40
40
  resp = client.get(f"{base_url}/{record}")
sera/libs/base_service.py CHANGED
@@ -1,39 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
- from enum import Enum
4
- from math import dist
5
- from typing import Annotated, Any, Generic, NamedTuple, Optional, Sequence, TypeVar
3
+ from typing import Generic, NamedTuple, Optional, Sequence, TypeVar
6
4
 
7
5
  from litestar.exceptions import HTTPException
8
6
  from sqlalchemy import Result, Select, delete, exists, func, select
9
7
  from sqlalchemy.exc import IntegrityError
10
8
  from sqlalchemy.ext.asyncio import AsyncSession
11
- from sqlalchemy.orm import load_only
9
+ from sqlalchemy.orm import contains_eager, load_only
12
10
 
13
11
  from sera.libs.base_orm import BaseORM
14
- from sera.misc import assert_not_null
15
- from sera.models import Class
16
- from sera.typing import FieldName, T, doc
17
-
18
-
19
- class QueryOp(str, Enum):
20
- lt = "lt"
21
- lte = "lte"
22
- gt = "gt"
23
- gte = "gte"
24
- eq = "eq"
25
- ne = "ne"
26
- # select records where values are in the given list
27
- in_ = "in"
28
- not_in = "not_in"
29
- # for full text search
30
- fuzzy = "fuzzy"
31
-
32
-
33
- Query = Annotated[
34
- dict[FieldName, dict[QueryOp, Annotated[Any, doc("query value")]]],
35
- doc("query operations"),
36
- ]
12
+ from sera.libs.search_helper import Query, QueryOp
13
+ from sera.misc import assert_not_null, to_snake_case
14
+ from sera.models import Cardinality, Class, DataProperty, ObjectProperty
15
+
37
16
  R = TypeVar("R", bound=BaseORM)
38
17
  ID = TypeVar("ID") # ID of a class
39
18
  SqlResult = TypeVar("SqlResult", bound=Result)
@@ -41,21 +20,100 @@ SqlResult = TypeVar("SqlResult", bound=Result)
41
20
 
42
21
  class QueryResult(NamedTuple, Generic[R]):
43
22
  records: Sequence[R]
44
- total: int
23
+ total: Optional[int]
45
24
 
46
25
 
47
26
  class BaseAsyncService(Generic[ID, R]):
48
27
 
49
28
  instance = None
50
29
 
51
- def __init__(self, cls: Class, orm_cls: type[R]):
30
+ def __init__(self, cls: Class, orm_classes: dict[str, type[R]]):
31
+ # schema of the class
52
32
  self.cls = cls
53
- self.orm_cls = orm_cls
33
+ self.orm_cls = orm_classes[cls.name]
54
34
  self.id_prop = assert_not_null(cls.get_id_property())
55
35
 
56
36
  self._cls_id_prop = getattr(self.orm_cls, self.id_prop.name)
57
37
  self.is_id_auto_increment = assert_not_null(self.id_prop.db).is_auto_increment
58
38
 
39
+ self.prop2orm: dict[str, type] = {
40
+ prop.name: orm_classes[prop.target.name]
41
+ for prop in cls.properties.values()
42
+ if isinstance(prop, ObjectProperty) and prop.target.db is not None
43
+ }
44
+
45
+ # figure out the join clauses so we can join the tables
46
+ # for example, to join between User, UserGroup, and Group
47
+ # the query can look like this:
48
+ # select(User)
49
+ # .join(UserGroup, UserGroup.user_id == User.id)
50
+ # .join(Group, Group.id == UserGroup.group_id)
51
+ # .options(contains_eager(User.group).contains_eager(UserGroup.group))
52
+ self.join_clauses: dict[str, list[dict]] = {}
53
+ for prop in cls.properties.values():
54
+ if (
55
+ isinstance(prop, DataProperty)
56
+ and prop.db is not None
57
+ and prop.db.foreign_key is not None
58
+ ):
59
+ target_tbl = orm_classes[prop.db.foreign_key.name]
60
+ target_cls = prop.db.foreign_key
61
+ source_fk = prop.name
62
+ # the property storing the SQLAlchemy relationship of the foreign key
63
+ source_relprop = prop.name + "_relobj"
64
+ cardinality = Cardinality.ONE_TO_ONE
65
+ elif isinstance(prop, ObjectProperty) and prop.target.db is not None:
66
+ target_tbl = orm_classes[prop.target.name]
67
+ target_cls = prop.target
68
+ source_fk = prop.name + "_id"
69
+ source_relprop = prop.name
70
+ cardinality = prop.cardinality
71
+ else:
72
+ continue
73
+
74
+ if cardinality == Cardinality.MANY_TO_MANY:
75
+ # for many-to-many, we need to import the association tables
76
+ assoc_tbl = orm_classes[f"{cls.name}{target_cls.name}"]
77
+ assoc_tbl_source_fk = to_snake_case(cls.name) + "_id"
78
+ assoc_tbl_target_fk = to_snake_case(target_cls.name) + "_id"
79
+ self.join_clauses[prop.name] = [
80
+ {
81
+ "class": assoc_tbl,
82
+ "condition": getattr(assoc_tbl, assoc_tbl_source_fk)
83
+ == getattr(self.orm_cls, self.id_prop.name),
84
+ "contains_eager": getattr(self.orm_cls, source_relprop),
85
+ },
86
+ {
87
+ "class": target_tbl,
88
+ "condition": getattr(assoc_tbl, assoc_tbl_target_fk)
89
+ == getattr(
90
+ target_tbl,
91
+ assert_not_null(target_cls.get_id_property()).name,
92
+ ),
93
+ "contains_eager": getattr(
94
+ assoc_tbl, to_snake_case(target_cls.name)
95
+ ),
96
+ },
97
+ ]
98
+ elif cardinality == Cardinality.ONE_TO_MANY:
99
+ # A -> B is 1:N, A.id is stored in B; this is not supported in SERA yet, so we do not need
100
+ # to implement it
101
+ raise NotImplementedError()
102
+ else:
103
+ # A -> B is either 1:1 or N:1; the foreign key is stored in A
104
+ # .join(B, A.<foreign_key> == B.id)
105
+ self.join_clauses[prop.name] = [
106
+ {
107
+ "class": target_tbl,
108
+ "condition": getattr(
109
+ target_tbl,
110
+ assert_not_null(target_cls.get_id_property()).name,
111
+ )
112
+ == getattr(self.orm_cls, source_fk),
113
+ "contains_eager": getattr(self.orm_cls, source_relprop),
114
+ },
115
+ ]
116
+
59
117
  @classmethod
60
118
  def get_instance(cls):
61
119
  """Get the singleton instance of the service."""
@@ -65,73 +123,103 @@ class BaseAsyncService(Generic[ID, R]):
65
123
  cls.instance = cls() # type: ignore[call-arg]
66
124
  return cls.instance
67
125
 
68
- async def get(
126
+ async def search(
69
127
  self,
70
128
  query: Query,
71
- limit: int,
72
- offset: int,
73
- unique: bool,
74
- sorted_by: list[str],
75
- group_by: list[str],
76
- fields: list[str],
77
129
  session: AsyncSession,
78
130
  ) -> QueryResult[R]:
79
131
  """Retrieving records matched a query.
80
132
 
81
133
  Args:
82
- query: The query to filter the records
83
- limit: The maximum number of records to return
84
- offset: The number of records to skip before returning results
85
- unique: Whether to return unique results only
86
- sorted_by: list of field names to sort by, prefix a field with '-' to sort that field in descending order
87
- group_by: list of field names to group by
88
- fields: list of field names to include in the results -- empty means all fields
134
+ query: The search query
135
+ session: The database session
89
136
  """
90
137
  q = self._select()
91
- if fields:
138
+
139
+ if len(query.fields) > 0:
92
140
  q = q.options(
93
- load_only(*[getattr(self.orm_cls, field) for field in fields])
141
+ load_only(*[getattr(self.orm_cls, field) for field in query.fields])
94
142
  )
95
- if unique:
143
+
144
+ if query.unique:
96
145
  q = q.distinct()
97
- for field in sorted_by:
98
- if field.startswith("-"):
99
- q = q.order_by(getattr(self.orm_cls, field[1:]).desc())
146
+
147
+ if len(query.sorted_by) > 0:
148
+ q = q.order_by(
149
+ *[
150
+ (
151
+ (
152
+ getattr(self.orm_cls, field.field).desc()
153
+ if field.order == "desc"
154
+ else getattr(self.orm_cls, field.field)
155
+ )
156
+ if field.prop is None
157
+ else (
158
+ getattr(self.prop2orm[field.prop], field.field).desc()
159
+ if field.order == "desc"
160
+ else getattr(self.prop2orm[field.prop], field.field)
161
+ )
162
+ )
163
+ for field in query.sorted_by
164
+ ]
165
+ )
166
+
167
+ if len(query.group_by) > 0:
168
+ q = q.group_by(
169
+ *[
170
+ (
171
+ getattr(self.orm_cls, field.field)
172
+ if field.prop is None
173
+ else getattr(self.prop2orm[field.prop], field.field)
174
+ )
175
+ for field in query.group_by
176
+ ]
177
+ )
178
+
179
+ for clause in query.conditions:
180
+ if clause.op == QueryOp.eq:
181
+ q = q.where(getattr(self.orm_cls, clause.field) == clause.value)
182
+ elif clause.op == QueryOp.ne:
183
+ q = q.where(getattr(self.orm_cls, clause.field) != clause.value)
184
+ elif clause.op == QueryOp.lt:
185
+ q = q.where(getattr(self.orm_cls, clause.field) < clause.value)
186
+ elif clause.op == QueryOp.lte:
187
+ q = q.where(getattr(self.orm_cls, clause.field) <= clause.value)
188
+ elif clause.op == QueryOp.gt:
189
+ q = q.where(getattr(self.orm_cls, clause.field) > clause.value)
190
+ elif clause.op == QueryOp.gte:
191
+ q = q.where(getattr(self.orm_cls, clause.field) >= clause.value)
192
+ elif clause.op == QueryOp.in_:
193
+ q = q.where(getattr(self.orm_cls, clause.field).in_(clause.value))
194
+ elif clause.op == QueryOp.not_in:
195
+ q = q.where(~getattr(self.orm_cls, clause.field).in_(clause.value))
100
196
  else:
101
- q = q.order_by(getattr(self.orm_cls, field))
102
- for field in group_by:
103
- q = q.group_by(getattr(self.orm_cls, field))
104
-
105
- for field, conditions in query.items():
106
- for op, value in conditions.items():
107
- # TODO: check if the operation is valid for the field.
108
- if op == QueryOp.eq:
109
- q = q.where(getattr(self.orm_cls, field) == value)
110
- elif op == QueryOp.ne:
111
- q = q.where(getattr(self.orm_cls, field) != value)
112
- elif op == QueryOp.lt:
113
- q = q.where(getattr(self.orm_cls, field) < value)
114
- elif op == QueryOp.lte:
115
- q = q.where(getattr(self.orm_cls, field) <= value)
116
- elif op == QueryOp.gt:
117
- q = q.where(getattr(self.orm_cls, field) > value)
118
- elif op == QueryOp.gte:
119
- q = q.where(getattr(self.orm_cls, field) >= value)
120
- elif op == QueryOp.in_:
121
- q = q.where(getattr(self.orm_cls, field).in_(value))
122
- elif op == QueryOp.not_in:
123
- q = q.where(~getattr(self.orm_cls, field).in_(value))
124
- else:
125
- assert op == QueryOp.fuzzy
126
- # Assuming fuzzy search is implemented as a full-text search
127
- q = q.where(
128
- func.to_tsvector(getattr(self.orm_cls, field)).match(value)
197
+ assert clause.op == QueryOp.fuzzy
198
+ # Assuming fuzzy search is implemented as a full-text search
199
+ q = q.where(
200
+ func.to_tsvector(getattr(self.orm_cls, clause.field)).match(
201
+ clause.value
129
202
  )
203
+ )
204
+
205
+ for join_condition in query.join_conditions:
206
+ for join_clause in self.join_clauses[join_condition.prop]:
207
+ q = q.join(
208
+ join_clause["class"],
209
+ join_clause["condition"],
210
+ isouter=join_condition.join_type == "left",
211
+ full=join_condition.join_type == "full",
212
+ ).options(contains_eager(join_clause["contains_eager"]))
213
+
214
+ print(">>>", join_clause)
130
215
 
131
216
  cq = select(func.count()).select_from(q.subquery())
132
- rq = q.limit(limit).offset(offset)
217
+ rq = q.limit(query.limit).offset(query.offset)
133
218
  records = self._process_result(await session.execute(rq)).scalars().all()
134
- total = (await session.execute(cq)).scalar_one()
219
+ if query.return_total:
220
+ total = (await session.execute(cq)).scalar_one()
221
+ else:
222
+ total = None
135
223
  return QueryResult(records, total)
136
224
 
137
225
  async def get_by_id(self, id: ID, session: AsyncSession) -> Optional[R]: