corvic-engine 0.3.0rc82__cp38-abi3-win_amd64.whl → 0.3.0rc83__cp38-abi3-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- corvic/{model → emodel}/__init__.py +40 -37
- corvic/emodel/_base_model.py +161 -0
- corvic/{model → emodel}/_completion_model.py +10 -8
- corvic/{model → emodel}/_feature_type.py +1 -1
- corvic/{model → emodel}/_feature_view.py +9 -7
- corvic/{model → emodel}/_pipeline.py +5 -5
- corvic/{model → emodel}/_proto_orm_convert.py +56 -54
- corvic/{model → emodel}/_resource.py +4 -4
- corvic/{model → emodel}/_room.py +4 -4
- corvic/{model → emodel}/_source.py +7 -7
- corvic/{model → emodel}/_space.py +9 -9
- corvic/engine/_native.pyd +0 -0
- corvic/system/__init__.py +2 -0
- corvic/system/_embedder.py +3 -0
- corvic/system/_image_embedder.py +50 -20
- corvic/transfer/__init__.py +43 -0
- corvic/transfer/_common_transformations.py +37 -0
- corvic/{model/_base_model.py → transfer/_orm_backed_proto.py} +116 -109
- corvic/transfer/py.typed +0 -0
- {corvic_engine-0.3.0rc82.dist-info → corvic_engine-0.3.0rc83.dist-info}/METADATA +1 -2
- {corvic_engine-0.3.0rc82.dist-info → corvic_engine-0.3.0rc83.dist-info}/RECORD +28 -24
- {corvic_engine-0.3.0rc82.dist-info → corvic_engine-0.3.0rc83.dist-info}/WHEEL +1 -1
- corvic_generated/orm/v1/agent_pb2.py +8 -8
- corvic_generated/orm/v1/agent_pb2.pyi +8 -4
- /corvic/{model → emodel}/_defaults.py +0 -0
- /corvic/{model → emodel}/_errors.py +0 -0
- /corvic/{model → emodel}/py.typed +0 -0
- {corvic_engine-0.3.0rc82.dist-info → corvic_engine-0.3.0rc83.dist-info}/licenses/LICENSE +0 -0
@@ -14,14 +14,14 @@ import sqlalchemy.orm as sa_orm
|
|
14
14
|
from sqlalchemy.orm.interfaces import LoaderOption
|
15
15
|
|
16
16
|
from corvic import eorm, op_graph, system
|
17
|
-
from corvic.
|
18
|
-
from corvic.
|
19
|
-
from corvic.
|
17
|
+
from corvic.emodel._base_model import StandardModel
|
18
|
+
from corvic.emodel._defaults import Defaults
|
19
|
+
from corvic.emodel._proto_orm_convert import (
|
20
20
|
source_delete_orms,
|
21
21
|
source_orm_to_proto,
|
22
22
|
source_proto_to_orm,
|
23
23
|
)
|
24
|
-
from corvic.
|
24
|
+
from corvic.emodel._resource import Resource, ResourceID
|
25
25
|
from corvic.result import InvalidArgumentError, NotFoundError, Ok
|
26
26
|
from corvic.table import Table
|
27
27
|
from corvic_generated.model.v1alpha import models_pb2
|
@@ -45,7 +45,7 @@ def foreign_key(
|
|
45
45
|
)
|
46
46
|
|
47
47
|
|
48
|
-
class Source(
|
48
|
+
class Source(StandardModel[SourceID, models_pb2.Source, eorm.Source]):
|
49
49
|
"""Sources describe how resources should be treated.
|
50
50
|
|
51
51
|
Example:
|
@@ -261,8 +261,8 @@ class Source(BelongsToRoomModel[SourceID, models_pb2.Source, eorm.Source]):
|
|
261
261
|
Example:
|
262
262
|
>>> with_feature_types(
|
263
263
|
>>> {
|
264
|
-
>>> "id": corvic.
|
265
|
-
>>> "customer_id": corvic.
|
264
|
+
>>> "id": corvic.emodel.feature_type.primary_key(),
|
265
|
+
>>> "customer_id": corvic.emodel.feature_type.foreign_key(
|
266
266
|
>>> customer_source.id
|
267
267
|
>>> ),
|
268
268
|
>>> },
|
@@ -14,10 +14,10 @@ import sqlalchemy as sa
|
|
14
14
|
from sqlalchemy import orm as sa_orm
|
15
15
|
|
16
16
|
from corvic import eorm, op_graph, system
|
17
|
-
from corvic.
|
18
|
-
from corvic.
|
19
|
-
from corvic.
|
20
|
-
from corvic.
|
17
|
+
from corvic.emodel._base_model import StandardModel
|
18
|
+
from corvic.emodel._defaults import Defaults
|
19
|
+
from corvic.emodel._feature_view import FeatureView, FeatureViewEdgeTableMetadata
|
20
|
+
from corvic.emodel._proto_orm_convert import (
|
21
21
|
space_delete_orms,
|
22
22
|
space_orm_to_proto,
|
23
23
|
space_proto_to_orm,
|
@@ -53,13 +53,13 @@ name_to_proto_embedding_model = {
|
|
53
53
|
def image_model_proto_to_name(image_model: embedding_models_pb2.ImageModel):
|
54
54
|
match image_model:
|
55
55
|
case embedding_models_pb2.IMAGE_MODEL_CUSTOM:
|
56
|
-
return Ok(
|
56
|
+
return Ok(system.RandomImageEmbedder.model_name())
|
57
57
|
case embedding_models_pb2.IMAGE_MODEL_CLIP:
|
58
|
-
return Ok(
|
58
|
+
return Ok(system.Clip.model_name())
|
59
59
|
case embedding_models_pb2.IMAGE_MODEL_IDENTITY:
|
60
|
-
return Ok(
|
60
|
+
return Ok(system.IdentityImageEmbedder.model_name())
|
61
61
|
case embedding_models_pb2.IMAGE_MODEL_SIGLIP2:
|
62
|
-
return Ok(
|
62
|
+
return Ok(system.SigLIP2.model_name())
|
63
63
|
case embedding_models_pb2.IMAGE_MODEL_UNSPECIFIED:
|
64
64
|
return Ok("")
|
65
65
|
case _:
|
@@ -114,7 +114,7 @@ name_to_proto_image_model = {
|
|
114
114
|
}
|
115
115
|
|
116
116
|
|
117
|
-
class Space(
|
117
|
+
class Space(StandardModel[SpaceID, models_pb2.Space, eorm.Space]):
|
118
118
|
"""Spaces apply embedding methods to FeatureViews.
|
119
119
|
|
120
120
|
Example:
|
corvic/engine/_native.pyd
CHANGED
Binary file
|
corvic/system/__init__.py
CHANGED
@@ -22,6 +22,7 @@ from corvic.system._image_embedder import (
|
|
22
22
|
CombinedImageEmbedder,
|
23
23
|
IdentityImageEmbedder,
|
24
24
|
RandomImageEmbedder,
|
25
|
+
SigLIP2,
|
25
26
|
image_from_bytes,
|
26
27
|
)
|
27
28
|
from corvic.system._planner import OpGraphPlanner, ValidateFirstExecutor
|
@@ -88,6 +89,7 @@ __all__ = [
|
|
88
89
|
"OpGraphPlanner",
|
89
90
|
"RandomImageEmbedder",
|
90
91
|
"RandomTextEmbedder",
|
92
|
+
"SigLIP2",
|
91
93
|
"SigLIP2Text",
|
92
94
|
"StagingDB",
|
93
95
|
"StorageManager",
|
corvic/system/_embedder.py
CHANGED
@@ -71,6 +71,9 @@ class EmbedImageResult:
|
|
71
71
|
class ImageEmbedder(Protocol):
|
72
72
|
"""Use a model to embed text."""
|
73
73
|
|
74
|
+
@classmethod
|
75
|
+
def model_name(cls) -> str: ...
|
76
|
+
|
74
77
|
def embed(
|
75
78
|
self, context: EmbedImageContext
|
76
79
|
) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError: ...
|
corvic/system/_image_embedder.py
CHANGED
@@ -27,6 +27,10 @@ class RandomImageEmbedder(ImageEmbedder):
|
|
27
27
|
Useful for testing.
|
28
28
|
"""
|
29
29
|
|
30
|
+
@classmethod
|
31
|
+
def model_name(cls) -> str:
|
32
|
+
return "random"
|
33
|
+
|
30
34
|
def embed(
|
31
35
|
self, context: EmbedImageContext
|
32
36
|
) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError:
|
@@ -82,6 +86,10 @@ class LoadedModels:
|
|
82
86
|
class HFModelImageEmbedder(ImageEmbedder):
|
83
87
|
"""Generic image embedder from hugging face models."""
|
84
88
|
|
89
|
+
@classmethod
|
90
|
+
@abc.abstractmethod
|
91
|
+
def model_revision(cls) -> str: ...
|
92
|
+
|
85
93
|
@abc.abstractmethod
|
86
94
|
def _load_models(self) -> LoadedModels: ...
|
87
95
|
|
@@ -165,6 +173,14 @@ class Clip(HFModelImageEmbedder):
|
|
165
173
|
overcoming several major challenges in computer vision.
|
166
174
|
"""
|
167
175
|
|
176
|
+
@classmethod
|
177
|
+
def model_name(cls) -> str:
|
178
|
+
return "openai/clip-vit-base-patch32"
|
179
|
+
|
180
|
+
@classmethod
|
181
|
+
def model_revision(cls) -> str:
|
182
|
+
return "5812e510083bb2d23fa43778a39ac065d205ed4d"
|
183
|
+
|
168
184
|
def _load_models(self) -> LoadedModels:
|
169
185
|
from transformers.models.clip import (
|
170
186
|
CLIPModel,
|
@@ -174,15 +190,15 @@ class Clip(HFModelImageEmbedder):
|
|
174
190
|
model = cast(
|
175
191
|
AutoModel,
|
176
192
|
CLIPModel.from_pretrained( # pyright: ignore[reportUnknownMemberType]
|
177
|
-
pretrained_model_name_or_path=
|
178
|
-
revision=
|
193
|
+
pretrained_model_name_or_path=self.model_name(),
|
194
|
+
revision=self.model_revision(),
|
179
195
|
),
|
180
196
|
)
|
181
197
|
processor = cast(
|
182
198
|
AutoProcessor,
|
183
199
|
CLIPProcessor.from_pretrained( # pyright: ignore[reportUnknownMemberType]
|
184
|
-
pretrained_model_name_or_path=
|
185
|
-
revision=
|
200
|
+
pretrained_model_name_or_path=self.model_name(),
|
201
|
+
revision=self.model_revision(),
|
186
202
|
use_fast=False,
|
187
203
|
),
|
188
204
|
)
|
@@ -192,6 +208,14 @@ class Clip(HFModelImageEmbedder):
|
|
192
208
|
class SigLIP2(HFModelImageEmbedder):
|
193
209
|
"""SigLIP2 image embedder."""
|
194
210
|
|
211
|
+
@classmethod
|
212
|
+
def model_name(cls) -> str:
|
213
|
+
return "google/siglip2-base-patch16-512"
|
214
|
+
|
215
|
+
@classmethod
|
216
|
+
def model_revision(cls) -> str:
|
217
|
+
return "a89f5c5093f902bf39d3cd4d81d2c09867f0724b"
|
218
|
+
|
195
219
|
def _load_models(self):
|
196
220
|
from transformers.models.auto.modeling_auto import AutoModel
|
197
221
|
from transformers.models.auto.processing_auto import AutoProcessor
|
@@ -199,16 +223,16 @@ class SigLIP2(HFModelImageEmbedder):
|
|
199
223
|
model = cast(
|
200
224
|
AutoModel,
|
201
225
|
AutoModel.from_pretrained( # pyright: ignore[reportUnknownMemberType]
|
202
|
-
pretrained_model_name_or_path=
|
203
|
-
revision=
|
226
|
+
pretrained_model_name_or_path=self.model_name(),
|
227
|
+
revision=self.model_revision(),
|
204
228
|
device_map="auto",
|
205
229
|
),
|
206
230
|
)
|
207
231
|
processor = cast(
|
208
232
|
AutoProcessor,
|
209
233
|
AutoProcessor.from_pretrained( # pyright: ignore[reportUnknownMemberType]
|
210
|
-
pretrained_model_name_or_path=
|
211
|
-
revision=
|
234
|
+
pretrained_model_name_or_path=self.model_name(),
|
235
|
+
revision=self.model_revision(),
|
212
236
|
use_fast=True,
|
213
237
|
),
|
214
238
|
)
|
@@ -216,23 +240,25 @@ class SigLIP2(HFModelImageEmbedder):
|
|
216
240
|
|
217
241
|
|
218
242
|
class CombinedImageEmbedder(ImageEmbedder):
|
243
|
+
@classmethod
|
244
|
+
def model_name(cls) -> str:
|
245
|
+
raise InvalidArgumentError(
|
246
|
+
"CombinedImageEmbedder does not have a specific model name"
|
247
|
+
)
|
248
|
+
|
219
249
|
def __init__(self):
|
220
|
-
self.
|
221
|
-
|
222
|
-
|
250
|
+
self._embedders = {
|
251
|
+
emb.model_name(): emb()
|
252
|
+
for emb in [Clip, SigLIP2, RandomImageEmbedder, IdentityImageEmbedder]
|
253
|
+
}
|
223
254
|
|
224
255
|
def embed(
|
225
256
|
self, context: EmbedImageContext
|
226
257
|
) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError:
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
return self._clip_embedder.embed(context)
|
232
|
-
case "siglip2":
|
233
|
-
return self._siglip2_embedder.embed(context)
|
234
|
-
case _:
|
235
|
-
return InvalidArgumentError(f"Unknown model name {context.model_name}")
|
258
|
+
embedder = self._embedders.get(context.model_name, None)
|
259
|
+
if not embedder:
|
260
|
+
return InvalidArgumentError(f"Unknown model name {context.model_name}")
|
261
|
+
return embedder.embed(context)
|
236
262
|
|
237
263
|
async def aembed(
|
238
264
|
self,
|
@@ -254,6 +280,10 @@ class IdentityImageEmbedder(ImageEmbedder):
|
|
254
280
|
- The resulting list is truncated or padded to match the expected vector length.
|
255
281
|
"""
|
256
282
|
|
283
|
+
@classmethod
|
284
|
+
def model_name(cls) -> str:
|
285
|
+
return "identity"
|
286
|
+
|
257
287
|
def _image_to_embedding(
|
258
288
|
self, image: "Image.Image", vector_length: int, *, normalization: bool = False
|
259
289
|
) -> list[float]:
|
@@ -0,0 +1,43 @@
|
|
1
|
+
"""Common machinery for using protocol buffers as transfer objects."""
|
2
|
+
|
3
|
+
from corvic.transfer._common_transformations import (
|
4
|
+
UNCOMMITTED_ID_PREFIX,
|
5
|
+
OrmIdT,
|
6
|
+
generate_uncommitted_id_str,
|
7
|
+
non_empty_timestamp_to_datetime,
|
8
|
+
translate_orm_id,
|
9
|
+
)
|
10
|
+
from corvic.transfer._orm_backed_proto import (
|
11
|
+
HasIdOrmBackedProto,
|
12
|
+
HasProtoSelf,
|
13
|
+
OrmBackedProto,
|
14
|
+
OrmHasIdModel,
|
15
|
+
OrmHasIdT,
|
16
|
+
OrmModel,
|
17
|
+
OrmT,
|
18
|
+
ProtoHasIdModel,
|
19
|
+
ProtoHasIdT,
|
20
|
+
ProtoModel,
|
21
|
+
ProtoT,
|
22
|
+
UsesOrmID,
|
23
|
+
)
|
24
|
+
|
25
|
+
__all__ = [
|
26
|
+
"UNCOMMITTED_ID_PREFIX",
|
27
|
+
"generate_uncommitted_id_str",
|
28
|
+
"OrmIdT",
|
29
|
+
"OrmModel",
|
30
|
+
"UsesOrmID",
|
31
|
+
"OrmT",
|
32
|
+
"ProtoT",
|
33
|
+
"HasProtoSelf",
|
34
|
+
"ProtoModel",
|
35
|
+
"ProtoHasIdT",
|
36
|
+
"OrmBackedProto",
|
37
|
+
"ProtoHasIdModel",
|
38
|
+
"OrmHasIdT",
|
39
|
+
"OrmHasIdModel",
|
40
|
+
"HasIdOrmBackedProto",
|
41
|
+
"translate_orm_id",
|
42
|
+
"non_empty_timestamp_to_datetime",
|
43
|
+
]
|
@@ -0,0 +1,37 @@
|
|
1
|
+
import datetime
|
2
|
+
import uuid
|
3
|
+
from typing import Any, TypeVar
|
4
|
+
|
5
|
+
from google.protobuf import timestamp_pb2
|
6
|
+
|
7
|
+
from corvic import orm
|
8
|
+
from corvic.result import Ok
|
9
|
+
|
10
|
+
OrmIdT = TypeVar("OrmIdT", bound=orm.BaseID[Any])
|
11
|
+
|
12
|
+
UNCOMMITTED_ID_PREFIX = "__uncommitted_object-"
|
13
|
+
|
14
|
+
|
15
|
+
def generate_uncommitted_id_str():
|
16
|
+
return f"{UNCOMMITTED_ID_PREFIX}{uuid.uuid4()}"
|
17
|
+
|
18
|
+
|
19
|
+
def translate_orm_id(
|
20
|
+
obj_id: str, id_class: type[OrmIdT]
|
21
|
+
) -> Ok[OrmIdT | None] | orm.InvalidORMIdentifierError:
|
22
|
+
if obj_id.startswith(UNCOMMITTED_ID_PREFIX):
|
23
|
+
return Ok(None)
|
24
|
+
parsed_obj_id = id_class(obj_id)
|
25
|
+
match parsed_obj_id.to_db():
|
26
|
+
case orm.InvalidORMIdentifierError() as err:
|
27
|
+
return err
|
28
|
+
case Ok():
|
29
|
+
return Ok(parsed_obj_id)
|
30
|
+
|
31
|
+
|
32
|
+
def non_empty_timestamp_to_datetime(
|
33
|
+
timestamp: timestamp_pb2.Timestamp,
|
34
|
+
) -> datetime.datetime | None:
|
35
|
+
if timestamp != timestamp_pb2.Timestamp():
|
36
|
+
return timestamp.ToDatetime(tzinfo=datetime.UTC)
|
37
|
+
return None
|
@@ -3,66 +3,59 @@ import contextlib
|
|
3
3
|
import copy
|
4
4
|
import datetime
|
5
5
|
import functools
|
6
|
-
import uuid
|
7
6
|
from collections.abc import Callable, Iterable, Iterator, Sequence
|
8
|
-
from typing import Final, Generic, Self
|
7
|
+
from typing import Any, Final, Generic, Protocol, Self, TypeVar
|
9
8
|
|
10
9
|
import sqlalchemy as sa
|
11
10
|
import sqlalchemy.orm as sa_orm
|
12
11
|
import structlog
|
13
12
|
from google.protobuf import timestamp_pb2
|
14
13
|
|
15
|
-
from corvic import
|
16
|
-
from corvic.
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
OrmObj,
|
22
|
-
ProtoBelongsToOrgObj,
|
23
|
-
ProtoBelongsToRoomObj,
|
24
|
-
ProtoObj,
|
25
|
-
)
|
26
|
-
from corvic.result import (
|
27
|
-
InvalidArgumentError,
|
28
|
-
NotFoundError,
|
29
|
-
Ok,
|
30
|
-
UnavailableError,
|
14
|
+
from corvic import orm, system
|
15
|
+
from corvic.result import InvalidArgumentError, NotFoundError, Ok, UnavailableError
|
16
|
+
from corvic.transfer._common_transformations import (
|
17
|
+
OrmIdT,
|
18
|
+
generate_uncommitted_id_str,
|
19
|
+
non_empty_timestamp_to_datetime,
|
31
20
|
)
|
32
21
|
|
33
22
|
_logger = structlog.get_logger()
|
34
23
|
|
35
|
-
_EMPTY_PROTO_TIMESTAMP = timestamp_pb2.Timestamp(seconds=0, nanos=0)
|
36
24
|
|
25
|
+
class OrmModel(Protocol):
|
26
|
+
@sa.ext.hybrid.hybrid_property
|
27
|
+
def created_at(self) -> datetime.datetime | None: ...
|
28
|
+
|
29
|
+
@created_at.inplace.expression
|
30
|
+
@classmethod
|
31
|
+
def _created_at_expression(cls): ...
|
32
|
+
|
33
|
+
|
34
|
+
class OrmHasIdModel(OrmModel, Protocol[OrmIdT]):
|
35
|
+
id: sa_orm.Mapped[OrmIdT | None]
|
37
36
|
|
38
|
-
def non_empty_timestamp_to_datetime(
|
39
|
-
timestamp: timestamp_pb2.Timestamp,
|
40
|
-
) -> datetime.datetime | None:
|
41
|
-
if timestamp != _EMPTY_PROTO_TIMESTAMP:
|
42
|
-
return timestamp.ToDatetime(tzinfo=datetime.UTC)
|
43
|
-
return None
|
44
37
|
|
38
|
+
OrmT = TypeVar("OrmT", bound=OrmModel)
|
39
|
+
OrmHasIdT = TypeVar("OrmHasIdT", bound=OrmHasIdModel[Any])
|
45
40
|
|
46
|
-
def _generate_uncommitted_id_str():
|
47
|
-
return f"{UNCOMMITTED_ID_PREFIX}{uuid.uuid4()}"
|
48
41
|
|
42
|
+
class ProtoModel(Protocol):
|
43
|
+
created_at: timestamp_pb2.Timestamp
|
49
44
|
|
50
|
-
@contextlib.contextmanager
|
51
|
-
def _create_or_join_session(
|
52
|
-
client: system.Client, existing_session: sa_orm.Session | None
|
53
|
-
) -> Iterator[sa_orm.Session]:
|
54
|
-
if existing_session:
|
55
|
-
yield existing_session
|
56
|
-
else:
|
57
|
-
with eorm.Session(client.sa_engine) as session:
|
58
|
-
yield session
|
59
45
|
|
46
|
+
class ProtoHasIdModel(ProtoModel, Protocol):
|
47
|
+
id: str
|
60
48
|
|
61
|
-
|
49
|
+
|
50
|
+
ProtoT = TypeVar("ProtoT", bound=ProtoModel)
|
51
|
+
ProtoHasIdT = TypeVar("ProtoHasIdT", bound=ProtoHasIdModel)
|
52
|
+
|
53
|
+
|
54
|
+
class HasProtoSelf(Generic[ProtoT], abc.ABC):
|
62
55
|
client: Final[system.Client]
|
63
|
-
proto_self: Final[
|
56
|
+
proto_self: Final[ProtoT]
|
64
57
|
|
65
|
-
def __init__(self, client: system.Client, proto_self:
|
58
|
+
def __init__(self, client: system.Client, proto_self: ProtoT):
|
66
59
|
self.proto_self = proto_self
|
67
60
|
self.client = client
|
68
61
|
|
@@ -71,22 +64,22 @@ class HasProtoSelf(Generic[ProtoObj], abc.ABC):
|
|
71
64
|
return non_empty_timestamp_to_datetime(self.proto_self.created_at)
|
72
65
|
|
73
66
|
|
74
|
-
class UsesOrmID(Generic[
|
75
|
-
def __init__(self, client: system.Client, proto_self:
|
67
|
+
class UsesOrmID(Generic[OrmIdT, ProtoHasIdT], HasProtoSelf[ProtoHasIdT]):
|
68
|
+
def __init__(self, client: system.Client, proto_self: ProtoHasIdT):
|
76
69
|
if not proto_self.id:
|
77
|
-
proto_self.id =
|
70
|
+
proto_self.id = generate_uncommitted_id_str()
|
78
71
|
super().__init__(client, proto_self)
|
79
72
|
|
80
73
|
@classmethod
|
81
74
|
@abc.abstractmethod
|
82
|
-
def id_class(cls) -> type[
|
75
|
+
def id_class(cls) -> type[OrmIdT]: ...
|
83
76
|
|
84
77
|
@functools.cached_property
|
85
|
-
def id(self) ->
|
78
|
+
def id(self) -> OrmIdT:
|
86
79
|
return self.id_class().from_str(self.proto_self.id)
|
87
80
|
|
88
81
|
|
89
|
-
class
|
82
|
+
class OrmBackedProto(Generic[ProtoT, OrmT], HasProtoSelf[ProtoT]):
|
90
83
|
"""Base for orm wrappers providing a unified update mechanism."""
|
91
84
|
|
92
85
|
@property
|
@@ -95,51 +88,29 @@ class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
|
|
95
88
|
|
96
89
|
@classmethod
|
97
90
|
@abc.abstractmethod
|
98
|
-
def orm_class(cls) -> type[
|
91
|
+
def orm_class(cls) -> type[OrmT]: ...
|
99
92
|
|
100
93
|
@classmethod
|
101
94
|
@abc.abstractmethod
|
102
|
-
def orm_to_proto(cls, orm_obj:
|
95
|
+
def orm_to_proto(cls, orm_obj: OrmT) -> ProtoT: ...
|
103
96
|
|
104
97
|
@classmethod
|
105
98
|
@abc.abstractmethod
|
106
99
|
def proto_to_orm(
|
107
|
-
cls, proto_obj:
|
108
|
-
) -> Ok[
|
109
|
-
|
110
|
-
@classmethod
|
111
|
-
@abc.abstractmethod
|
112
|
-
def delete_by_ids(
|
113
|
-
cls, ids: Sequence[IdType], session: eorm.Session
|
114
|
-
) -> Ok[None] | InvalidArgumentError: ...
|
115
|
-
|
116
|
-
@classmethod
|
117
|
-
def load_proto_for(
|
118
|
-
cls,
|
119
|
-
obj_id: IdType,
|
120
|
-
client: system.Client,
|
121
|
-
existing_session: sa_orm.Session | None = None,
|
122
|
-
) -> Ok[ProtoObj] | NotFoundError:
|
123
|
-
"""Create a model object by loading it from the database."""
|
124
|
-
with _create_or_join_session(client, existing_session) as session:
|
125
|
-
orm_self = session.get(cls.orm_class(), obj_id)
|
126
|
-
if orm_self is None:
|
127
|
-
return NotFoundError("object with given id does not exist", id=obj_id)
|
128
|
-
proto_self = cls.orm_to_proto(orm_self)
|
129
|
-
return Ok(proto_self)
|
100
|
+
cls, proto_obj: ProtoT, session: orm.Session
|
101
|
+
) -> Ok[OrmT] | InvalidArgumentError: ...
|
130
102
|
|
131
103
|
@classmethod
|
132
104
|
def _generate_query_results(
|
133
|
-
cls, query: sa.Select[tuple[
|
134
|
-
) -> Iterator[
|
105
|
+
cls, query: sa.Select[tuple[OrmT]], session: sa_orm.Session
|
106
|
+
) -> Iterator[OrmT]:
|
135
107
|
it = iter(session.scalars(query))
|
136
108
|
while True:
|
137
109
|
try:
|
138
110
|
yield from it
|
139
111
|
except Exception:
|
140
112
|
_logger.exception(
|
141
|
-
"omitting
|
142
|
-
+ "failed to parse source from database entry",
|
113
|
+
"omitting model from list: " + "failed to parse database entry",
|
143
114
|
)
|
144
115
|
else:
|
145
116
|
break
|
@@ -155,31 +126,27 @@ class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
|
|
155
126
|
client: system.Client,
|
156
127
|
*,
|
157
128
|
limit: int | None = None,
|
158
|
-
room_id: eorm.RoomID | None = None,
|
159
129
|
created_before: datetime.datetime | None = None,
|
160
|
-
ids: Iterable[IdType] | None = None,
|
161
130
|
additional_query_transform: Callable[
|
162
|
-
[sa.Select[tuple[
|
131
|
+
[sa.Select[tuple[OrmT]]], sa.Select[tuple[OrmT]]
|
163
132
|
]
|
164
133
|
| None = None,
|
165
134
|
existing_session: sa_orm.Session | None = None,
|
166
|
-
) -> Ok[list[
|
135
|
+
) -> Ok[list[ProtoT]] | NotFoundError | InvalidArgumentError:
|
167
136
|
"""List sources that exist in storage."""
|
168
137
|
orm_class = cls.orm_class()
|
169
|
-
with
|
138
|
+
with (
|
139
|
+
contextlib.nullcontext(existing_session)
|
140
|
+
if existing_session
|
141
|
+
else orm.Session(client.sa_engine) as session
|
142
|
+
):
|
170
143
|
query = sa.select(orm_class).order_by(sa.desc(orm_class.created_at))
|
171
144
|
if limit is not None:
|
172
145
|
if limit < 0:
|
173
146
|
return InvalidArgumentError("limit cannot be negative")
|
174
147
|
query = query.limit(limit)
|
175
|
-
if room_id:
|
176
|
-
if session.get(eorm.Room, room_id) is None:
|
177
|
-
return NotFoundError("room not found", room_id=room_id)
|
178
|
-
query = query.filter_by(room_id=room_id)
|
179
148
|
if created_before:
|
180
149
|
query = query.filter(orm_class.created_at < created_before)
|
181
|
-
if ids is not None:
|
182
|
-
query = query.filter(orm_class.id.in_(ids))
|
183
150
|
if additional_query_transform:
|
184
151
|
query = additional_query_transform(query)
|
185
152
|
extra_orm_loaders = cls.orm_load_options()
|
@@ -198,7 +165,7 @@ class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
|
|
198
165
|
This overwrites the entry at id in the database so that future readers will see
|
199
166
|
this object. One of `id` or `derived_from_id` cannot be empty or None.
|
200
167
|
"""
|
201
|
-
with
|
168
|
+
with orm.Session(self.client.sa_engine) as session:
|
202
169
|
try:
|
203
170
|
new_orm_self = self.proto_to_orm(
|
204
171
|
self.proto_self, session
|
@@ -230,7 +197,7 @@ class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
|
|
230
197
|
return InvalidArgumentError.from_(err)
|
231
198
|
|
232
199
|
def add_to_session(
|
233
|
-
self, session:
|
200
|
+
self, session: orm.Session
|
234
201
|
) -> Ok[None] | InvalidArgumentError | UnavailableError:
|
235
202
|
"""Like commit, but just calls session.flush to check for database errors.
|
236
203
|
|
@@ -246,8 +213,39 @@ class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
|
|
246
213
|
return self._dbapi_error_to_result(err)
|
247
214
|
return Ok(None)
|
248
215
|
|
216
|
+
|
217
|
+
class HasIdOrmBackedProto(
|
218
|
+
Generic[OrmIdT, ProtoHasIdT, OrmHasIdT],
|
219
|
+
UsesOrmID[OrmIdT, ProtoHasIdT],
|
220
|
+
OrmBackedProto[ProtoHasIdT, OrmHasIdT],
|
221
|
+
):
|
222
|
+
@classmethod
|
223
|
+
@abc.abstractmethod
|
224
|
+
def delete_by_ids(
|
225
|
+
cls, ids: Sequence[OrmIdT], session: orm.Session
|
226
|
+
) -> Ok[None] | InvalidArgumentError: ...
|
227
|
+
|
228
|
+
@classmethod
|
229
|
+
def load_proto_for(
|
230
|
+
cls,
|
231
|
+
obj_id: OrmIdT,
|
232
|
+
client: system.Client,
|
233
|
+
existing_session: sa_orm.Session | None = None,
|
234
|
+
) -> Ok[ProtoHasIdT] | NotFoundError:
|
235
|
+
"""Create a model object by loading it from the database."""
|
236
|
+
with (
|
237
|
+
contextlib.nullcontext(existing_session)
|
238
|
+
if existing_session
|
239
|
+
else orm.Session(client.sa_engine) as session
|
240
|
+
):
|
241
|
+
orm_self = session.get(cls.orm_class(), obj_id)
|
242
|
+
if orm_self is None:
|
243
|
+
return NotFoundError("object with given id does not exist", id=obj_id)
|
244
|
+
proto_self = cls.orm_to_proto(orm_self)
|
245
|
+
return Ok(proto_self)
|
246
|
+
|
249
247
|
def delete(self) -> Ok[Self] | NotFoundError | InvalidArgumentError:
|
250
|
-
with
|
248
|
+
with orm.Session(
|
251
249
|
self.client.sa_engine, expire_on_commit=False, autoflush=False
|
252
250
|
) as session:
|
253
251
|
try:
|
@@ -270,24 +268,33 @@ class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
|
|
270
268
|
)
|
271
269
|
)
|
272
270
|
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
271
|
+
@classmethod
|
272
|
+
def list_as_proto(
|
273
|
+
cls,
|
274
|
+
client: system.Client,
|
275
|
+
*,
|
276
|
+
limit: int | None = None,
|
277
|
+
created_before: datetime.datetime | None = None,
|
278
|
+
ids: Iterable[OrmIdT] | None = None,
|
279
|
+
additional_query_transform: Callable[
|
280
|
+
[sa.Select[tuple[OrmHasIdT]]], sa.Select[tuple[OrmHasIdT]]
|
281
|
+
]
|
282
|
+
| None = None,
|
283
|
+
existing_session: sa_orm.Session | None = None,
|
284
|
+
) -> Ok[list[ProtoHasIdT]] | NotFoundError | InvalidArgumentError:
|
285
|
+
def query_transform(
|
286
|
+
query: sa.Select[tuple[OrmHasIdT]],
|
287
|
+
) -> sa.Select[tuple[OrmHasIdT]]:
|
288
|
+
if ids:
|
289
|
+
query = query.where(cls.orm_class().id.in_(ids))
|
290
|
+
if additional_query_transform:
|
291
|
+
query = additional_query_transform(query)
|
292
|
+
return query
|
293
|
+
|
294
|
+
return super().list_as_proto(
|
295
|
+
client,
|
296
|
+
limit=limit,
|
297
|
+
created_before=created_before,
|
298
|
+
additional_query_transform=query_transform,
|
299
|
+
existing_session=existing_session,
|
300
|
+
)
|
corvic/transfer/py.typed
ADDED
File without changes
|