corvic-engine 0.3.0rc82__cp38-abi3-win_amd64.whl → 0.3.0rc84__cp38-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,14 +14,14 @@ import sqlalchemy.orm as sa_orm
14
14
  from sqlalchemy.orm.interfaces import LoaderOption
15
15
 
16
16
  from corvic import eorm, op_graph, system
17
- from corvic.model._base_model import BelongsToRoomModel
18
- from corvic.model._defaults import Defaults
19
- from corvic.model._proto_orm_convert import (
17
+ from corvic.emodel._base_model import StandardModel
18
+ from corvic.emodel._defaults import Defaults
19
+ from corvic.emodel._proto_orm_convert import (
20
20
  source_delete_orms,
21
21
  source_orm_to_proto,
22
22
  source_proto_to_orm,
23
23
  )
24
- from corvic.model._resource import Resource, ResourceID
24
+ from corvic.emodel._resource import Resource, ResourceID
25
25
  from corvic.result import InvalidArgumentError, NotFoundError, Ok
26
26
  from corvic.table import Table
27
27
  from corvic_generated.model.v1alpha import models_pb2
@@ -45,7 +45,7 @@ def foreign_key(
45
45
  )
46
46
 
47
47
 
48
- class Source(BelongsToRoomModel[SourceID, models_pb2.Source, eorm.Source]):
48
+ class Source(StandardModel[SourceID, models_pb2.Source, eorm.Source]):
49
49
  """Sources describe how resources should be treated.
50
50
 
51
51
  Example:
@@ -261,8 +261,8 @@ class Source(BelongsToRoomModel[SourceID, models_pb2.Source, eorm.Source]):
261
261
  Example:
262
262
  >>> with_feature_types(
263
263
  >>> {
264
- >>> "id": corvic.model.feature_type.primary_key(),
265
- >>> "customer_id": corvic.model.feature_type.foreign_key(
264
+ >>> "id": corvic.emodel.feature_type.primary_key(),
265
+ >>> "customer_id": corvic.emodel.feature_type.foreign_key(
266
266
  >>> customer_source.id
267
267
  >>> ),
268
268
  >>> },
@@ -14,10 +14,10 @@ import sqlalchemy as sa
14
14
  from sqlalchemy import orm as sa_orm
15
15
 
16
16
  from corvic import eorm, op_graph, system
17
- from corvic.model._base_model import BelongsToRoomModel
18
- from corvic.model._defaults import Defaults
19
- from corvic.model._feature_view import FeatureView, FeatureViewEdgeTableMetadata
20
- from corvic.model._proto_orm_convert import (
17
+ from corvic.emodel._base_model import StandardModel
18
+ from corvic.emodel._defaults import Defaults
19
+ from corvic.emodel._feature_view import FeatureView, FeatureViewEdgeTableMetadata
20
+ from corvic.emodel._proto_orm_convert import (
21
21
  space_delete_orms,
22
22
  space_orm_to_proto,
23
23
  space_proto_to_orm,
@@ -53,13 +53,13 @@ name_to_proto_embedding_model = {
53
53
  def image_model_proto_to_name(image_model: embedding_models_pb2.ImageModel):
54
54
  match image_model:
55
55
  case embedding_models_pb2.IMAGE_MODEL_CUSTOM:
56
- return Ok("random")
56
+ return Ok(system.RandomImageEmbedder.model_name())
57
57
  case embedding_models_pb2.IMAGE_MODEL_CLIP:
58
- return Ok("openai/clip-vit-base-patch32")
58
+ return Ok(system.Clip.model_name())
59
59
  case embedding_models_pb2.IMAGE_MODEL_IDENTITY:
60
- return Ok("identity")
60
+ return Ok(system.IdentityImageEmbedder.model_name())
61
61
  case embedding_models_pb2.IMAGE_MODEL_SIGLIP2:
62
- return Ok("google/siglip2-base-patch16-512")
62
+ return Ok(system.SigLIP2.model_name())
63
63
  case embedding_models_pb2.IMAGE_MODEL_UNSPECIFIED:
64
64
  return Ok("")
65
65
  case _:
@@ -114,7 +114,7 @@ name_to_proto_image_model = {
114
114
  }
115
115
 
116
116
 
117
- class Space(BelongsToRoomModel[SpaceID, models_pb2.Space, eorm.Space]):
117
+ class Space(StandardModel[SpaceID, models_pb2.Space, eorm.Space]):
118
118
  """Spaces apply embedding methods to FeatureViews.
119
119
 
120
120
  Example:
corvic/engine/_native.pyd CHANGED
Binary file
corvic/system/__init__.py CHANGED
@@ -22,6 +22,7 @@ from corvic.system._image_embedder import (
22
22
  CombinedImageEmbedder,
23
23
  IdentityImageEmbedder,
24
24
  RandomImageEmbedder,
25
+ SigLIP2,
25
26
  image_from_bytes,
26
27
  )
27
28
  from corvic.system._planner import OpGraphPlanner, ValidateFirstExecutor
@@ -88,6 +89,7 @@ __all__ = [
88
89
  "OpGraphPlanner",
89
90
  "RandomImageEmbedder",
90
91
  "RandomTextEmbedder",
92
+ "SigLIP2",
91
93
  "SigLIP2Text",
92
94
  "StagingDB",
93
95
  "StorageManager",
@@ -71,6 +71,9 @@ class EmbedImageResult:
71
71
  class ImageEmbedder(Protocol):
72
72
  """Use a model to embed text."""
73
73
 
74
+ @classmethod
75
+ def model_name(cls) -> str: ...
76
+
74
77
  def embed(
75
78
  self, context: EmbedImageContext
76
79
  ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError: ...
@@ -27,6 +27,10 @@ class RandomImageEmbedder(ImageEmbedder):
27
27
  Useful for testing.
28
28
  """
29
29
 
30
+ @classmethod
31
+ def model_name(cls) -> str:
32
+ return "random"
33
+
30
34
  def embed(
31
35
  self, context: EmbedImageContext
32
36
  ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError:
@@ -82,6 +86,10 @@ class LoadedModels:
82
86
  class HFModelImageEmbedder(ImageEmbedder):
83
87
  """Generic image embedder from hugging face models."""
84
88
 
89
+ @classmethod
90
+ @abc.abstractmethod
91
+ def model_revision(cls) -> str: ...
92
+
85
93
  @abc.abstractmethod
86
94
  def _load_models(self) -> LoadedModels: ...
87
95
 
@@ -165,7 +173,17 @@ class Clip(HFModelImageEmbedder):
165
173
  overcoming several major challenges in computer vision.
166
174
  """
167
175
 
176
+ @classmethod
177
+ def model_name(cls) -> str:
178
+ return "openai/clip-vit-base-patch32"
179
+
180
+ @classmethod
181
+ def model_revision(cls) -> str:
182
+ return "5812e510083bb2d23fa43778a39ac065d205ed4d"
183
+
168
184
  def _load_models(self) -> LoadedModels:
185
+ from transformers.models.auto.modeling_auto import AutoModel
186
+ from transformers.models.auto.processing_auto import AutoProcessor
169
187
  from transformers.models.clip import (
170
188
  CLIPModel,
171
189
  CLIPProcessor,
@@ -174,15 +192,15 @@ class Clip(HFModelImageEmbedder):
174
192
  model = cast(
175
193
  AutoModel,
176
194
  CLIPModel.from_pretrained( # pyright: ignore[reportUnknownMemberType]
177
- pretrained_model_name_or_path="openai/clip-vit-base-patch32",
178
- revision="5812e510083bb2d23fa43778a39ac065d205ed4d",
195
+ pretrained_model_name_or_path=self.model_name(),
196
+ revision=self.model_revision(),
179
197
  ),
180
198
  )
181
199
  processor = cast(
182
200
  AutoProcessor,
183
201
  CLIPProcessor.from_pretrained( # pyright: ignore[reportUnknownMemberType]
184
- pretrained_model_name_or_path="openai/clip-vit-base-patch32",
185
- revision="5812e510083bb2d23fa43778a39ac065d205ed4d",
202
+ pretrained_model_name_or_path=self.model_name(),
203
+ revision=self.model_revision(),
186
204
  use_fast=False,
187
205
  ),
188
206
  )
@@ -192,6 +210,14 @@ class Clip(HFModelImageEmbedder):
192
210
  class SigLIP2(HFModelImageEmbedder):
193
211
  """SigLIP2 image embedder."""
194
212
 
213
+ @classmethod
214
+ def model_name(cls) -> str:
215
+ return "google/siglip2-base-patch16-512"
216
+
217
+ @classmethod
218
+ def model_revision(cls) -> str:
219
+ return "a89f5c5093f902bf39d3cd4d81d2c09867f0724b"
220
+
195
221
  def _load_models(self):
196
222
  from transformers.models.auto.modeling_auto import AutoModel
197
223
  from transformers.models.auto.processing_auto import AutoProcessor
@@ -199,16 +225,16 @@ class SigLIP2(HFModelImageEmbedder):
199
225
  model = cast(
200
226
  AutoModel,
201
227
  AutoModel.from_pretrained( # pyright: ignore[reportUnknownMemberType]
202
- pretrained_model_name_or_path="google/siglip2-base-patch16-512",
203
- revision="a89f5c5093f902bf39d3cd4d81d2c09867f0724b",
228
+ pretrained_model_name_or_path=self.model_name(),
229
+ revision=self.model_revision(),
204
230
  device_map="auto",
205
231
  ),
206
232
  )
207
233
  processor = cast(
208
234
  AutoProcessor,
209
235
  AutoProcessor.from_pretrained( # pyright: ignore[reportUnknownMemberType]
210
- pretrained_model_name_or_path="google/siglip2-base-patch16-512",
211
- revision="a89f5c5093f902bf39d3cd4d81d2c09867f0724b",
236
+ pretrained_model_name_or_path=self.model_name(),
237
+ revision=self.model_revision(),
212
238
  use_fast=True,
213
239
  ),
214
240
  )
@@ -216,23 +242,25 @@ class SigLIP2(HFModelImageEmbedder):
216
242
 
217
243
 
218
244
  class CombinedImageEmbedder(ImageEmbedder):
245
+ @classmethod
246
+ def model_name(cls) -> str:
247
+ raise InvalidArgumentError(
248
+ "CombinedImageEmbedder does not have a specific model name"
249
+ )
250
+
219
251
  def __init__(self):
220
- self._clip_embedder = Clip()
221
- self._siglip2_embedder = SigLIP2()
222
- self._random_embedder = RandomImageEmbedder()
252
+ self._embedders = {
253
+ emb.model_name(): emb()
254
+ for emb in [Clip, SigLIP2, RandomImageEmbedder, IdentityImageEmbedder]
255
+ }
223
256
 
224
257
  def embed(
225
258
  self, context: EmbedImageContext
226
259
  ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError:
227
- match context.model_name:
228
- case "random":
229
- return self._random_embedder.embed(context)
230
- case "clip":
231
- return self._clip_embedder.embed(context)
232
- case "siglip2":
233
- return self._siglip2_embedder.embed(context)
234
- case _:
235
- return InvalidArgumentError(f"Unknown model name {context.model_name}")
260
+ embedder = self._embedders.get(context.model_name, None)
261
+ if not embedder:
262
+ return InvalidArgumentError(f"Unknown model name {context.model_name}")
263
+ return embedder.embed(context)
236
264
 
237
265
  async def aembed(
238
266
  self,
@@ -254,6 +282,10 @@ class IdentityImageEmbedder(ImageEmbedder):
254
282
  - The resulting list is truncated or padded to match the expected vector length.
255
283
  """
256
284
 
285
+ @classmethod
286
+ def model_name(cls) -> str:
287
+ return "identity"
288
+
257
289
  def _image_to_embedding(
258
290
  self, image: "Image.Image", vector_length: int, *, normalization: bool = False
259
291
  ) -> list[float]:
@@ -0,0 +1,43 @@
1
+ """Common machinery for using protocol buffers as transfer objects."""
2
+
3
+ from corvic.transfer._common_transformations import (
4
+ UNCOMMITTED_ID_PREFIX,
5
+ OrmIdT,
6
+ generate_uncommitted_id_str,
7
+ non_empty_timestamp_to_datetime,
8
+ translate_orm_id,
9
+ )
10
+ from corvic.transfer._orm_backed_proto import (
11
+ HasIdOrmBackedProto,
12
+ HasProtoSelf,
13
+ OrmBackedProto,
14
+ OrmHasIdModel,
15
+ OrmHasIdT,
16
+ OrmModel,
17
+ OrmT,
18
+ ProtoHasIdModel,
19
+ ProtoHasIdT,
20
+ ProtoModel,
21
+ ProtoT,
22
+ UsesOrmID,
23
+ )
24
+
25
+ __all__ = [
26
+ "UNCOMMITTED_ID_PREFIX",
27
+ "generate_uncommitted_id_str",
28
+ "OrmIdT",
29
+ "OrmModel",
30
+ "UsesOrmID",
31
+ "OrmT",
32
+ "ProtoT",
33
+ "HasProtoSelf",
34
+ "ProtoModel",
35
+ "ProtoHasIdT",
36
+ "OrmBackedProto",
37
+ "ProtoHasIdModel",
38
+ "OrmHasIdT",
39
+ "OrmHasIdModel",
40
+ "HasIdOrmBackedProto",
41
+ "translate_orm_id",
42
+ "non_empty_timestamp_to_datetime",
43
+ ]
@@ -0,0 +1,37 @@
1
+ import datetime
2
+ import uuid
3
+ from typing import Any, TypeVar
4
+
5
+ from google.protobuf import timestamp_pb2
6
+
7
+ from corvic import orm
8
+ from corvic.result import Ok
9
+
10
+ OrmIdT = TypeVar("OrmIdT", bound=orm.BaseID[Any])
11
+
12
+ UNCOMMITTED_ID_PREFIX = "__uncommitted_object-"
13
+
14
+
15
+ def generate_uncommitted_id_str():
16
+ return f"{UNCOMMITTED_ID_PREFIX}{uuid.uuid4()}"
17
+
18
+
19
+ def translate_orm_id(
20
+ obj_id: str, id_class: type[OrmIdT]
21
+ ) -> Ok[OrmIdT | None] | orm.InvalidORMIdentifierError:
22
+ if obj_id.startswith(UNCOMMITTED_ID_PREFIX):
23
+ return Ok(None)
24
+ parsed_obj_id = id_class(obj_id)
25
+ match parsed_obj_id.to_db():
26
+ case orm.InvalidORMIdentifierError() as err:
27
+ return err
28
+ case Ok():
29
+ return Ok(parsed_obj_id)
30
+
31
+
32
+ def non_empty_timestamp_to_datetime(
33
+ timestamp: timestamp_pb2.Timestamp,
34
+ ) -> datetime.datetime | None:
35
+ if timestamp != timestamp_pb2.Timestamp():
36
+ return timestamp.ToDatetime(tzinfo=datetime.UTC)
37
+ return None
@@ -3,66 +3,59 @@ import contextlib
3
3
  import copy
4
4
  import datetime
5
5
  import functools
6
- import uuid
7
6
  from collections.abc import Callable, Iterable, Iterator, Sequence
8
- from typing import Final, Generic, Self
7
+ from typing import Any, Final, Generic, Protocol, Self, TypeVar
9
8
 
10
9
  import sqlalchemy as sa
11
10
  import sqlalchemy.orm as sa_orm
12
11
  import structlog
13
12
  from google.protobuf import timestamp_pb2
14
13
 
15
- from corvic import eorm, system
16
- from corvic.model._proto_orm_convert import (
17
- UNCOMMITTED_ID_PREFIX,
18
- IdType,
19
- OrmBelongsToOrgObj,
20
- OrmBelongsToRoomObj,
21
- OrmObj,
22
- ProtoBelongsToOrgObj,
23
- ProtoBelongsToRoomObj,
24
- ProtoObj,
25
- )
26
- from corvic.result import (
27
- InvalidArgumentError,
28
- NotFoundError,
29
- Ok,
30
- UnavailableError,
14
+ from corvic import orm, system
15
+ from corvic.result import InvalidArgumentError, NotFoundError, Ok, UnavailableError
16
+ from corvic.transfer._common_transformations import (
17
+ OrmIdT,
18
+ generate_uncommitted_id_str,
19
+ non_empty_timestamp_to_datetime,
31
20
  )
32
21
 
33
22
  _logger = structlog.get_logger()
34
23
 
35
- _EMPTY_PROTO_TIMESTAMP = timestamp_pb2.Timestamp(seconds=0, nanos=0)
36
24
 
25
+ class OrmModel(Protocol):
26
+ @sa.ext.hybrid.hybrid_property
27
+ def created_at(self) -> datetime.datetime | None: ...
28
+
29
+ @created_at.inplace.expression
30
+ @classmethod
31
+ def _created_at_expression(cls): ...
32
+
33
+
34
+ class OrmHasIdModel(OrmModel, Protocol[OrmIdT]):
35
+ id: sa_orm.Mapped[OrmIdT | None]
37
36
 
38
- def non_empty_timestamp_to_datetime(
39
- timestamp: timestamp_pb2.Timestamp,
40
- ) -> datetime.datetime | None:
41
- if timestamp != _EMPTY_PROTO_TIMESTAMP:
42
- return timestamp.ToDatetime(tzinfo=datetime.UTC)
43
- return None
44
37
 
38
+ OrmT = TypeVar("OrmT", bound=OrmModel)
39
+ OrmHasIdT = TypeVar("OrmHasIdT", bound=OrmHasIdModel[Any])
45
40
 
46
- def _generate_uncommitted_id_str():
47
- return f"{UNCOMMITTED_ID_PREFIX}{uuid.uuid4()}"
48
41
 
42
+ class ProtoModel(Protocol):
43
+ created_at: timestamp_pb2.Timestamp
49
44
 
50
- @contextlib.contextmanager
51
- def _create_or_join_session(
52
- client: system.Client, existing_session: sa_orm.Session | None
53
- ) -> Iterator[sa_orm.Session]:
54
- if existing_session:
55
- yield existing_session
56
- else:
57
- with eorm.Session(client.sa_engine) as session:
58
- yield session
59
45
 
46
+ class ProtoHasIdModel(ProtoModel, Protocol):
47
+ id: str
60
48
 
61
- class HasProtoSelf(Generic[ProtoObj], abc.ABC):
49
+
50
+ ProtoT = TypeVar("ProtoT", bound=ProtoModel)
51
+ ProtoHasIdT = TypeVar("ProtoHasIdT", bound=ProtoHasIdModel)
52
+
53
+
54
+ class HasProtoSelf(Generic[ProtoT], abc.ABC):
62
55
  client: Final[system.Client]
63
- proto_self: Final[ProtoObj]
56
+ proto_self: Final[ProtoT]
64
57
 
65
- def __init__(self, client: system.Client, proto_self: ProtoObj):
58
+ def __init__(self, client: system.Client, proto_self: ProtoT):
66
59
  self.proto_self = proto_self
67
60
  self.client = client
68
61
 
@@ -71,22 +64,22 @@ class HasProtoSelf(Generic[ProtoObj], abc.ABC):
71
64
  return non_empty_timestamp_to_datetime(self.proto_self.created_at)
72
65
 
73
66
 
74
- class UsesOrmID(Generic[IdType, ProtoObj], HasProtoSelf[ProtoObj]):
75
- def __init__(self, client: system.Client, proto_self: ProtoObj):
67
+ class UsesOrmID(Generic[OrmIdT, ProtoHasIdT], HasProtoSelf[ProtoHasIdT]):
68
+ def __init__(self, client: system.Client, proto_self: ProtoHasIdT):
76
69
  if not proto_self.id:
77
- proto_self.id = _generate_uncommitted_id_str()
70
+ proto_self.id = generate_uncommitted_id_str()
78
71
  super().__init__(client, proto_self)
79
72
 
80
73
  @classmethod
81
74
  @abc.abstractmethod
82
- def id_class(cls) -> type[IdType]: ...
75
+ def id_class(cls) -> type[OrmIdT]: ...
83
76
 
84
77
  @functools.cached_property
85
- def id(self) -> IdType:
78
+ def id(self) -> OrmIdT:
86
79
  return self.id_class().from_str(self.proto_self.id)
87
80
 
88
81
 
89
- class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
82
+ class OrmBackedProto(Generic[ProtoT, OrmT], HasProtoSelf[ProtoT]):
90
83
  """Base for orm wrappers providing a unified update mechanism."""
91
84
 
92
85
  @property
@@ -95,51 +88,29 @@ class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
95
88
 
96
89
  @classmethod
97
90
  @abc.abstractmethod
98
- def orm_class(cls) -> type[OrmObj]: ...
91
+ def orm_class(cls) -> type[OrmT]: ...
99
92
 
100
93
  @classmethod
101
94
  @abc.abstractmethod
102
- def orm_to_proto(cls, orm_obj: OrmObj) -> ProtoObj: ...
95
+ def orm_to_proto(cls, orm_obj: OrmT) -> ProtoT: ...
103
96
 
104
97
  @classmethod
105
98
  @abc.abstractmethod
106
99
  def proto_to_orm(
107
- cls, proto_obj: ProtoObj, session: eorm.Session
108
- ) -> Ok[OrmObj] | InvalidArgumentError: ...
109
-
110
- @classmethod
111
- @abc.abstractmethod
112
- def delete_by_ids(
113
- cls, ids: Sequence[IdType], session: eorm.Session
114
- ) -> Ok[None] | InvalidArgumentError: ...
115
-
116
- @classmethod
117
- def load_proto_for(
118
- cls,
119
- obj_id: IdType,
120
- client: system.Client,
121
- existing_session: sa_orm.Session | None = None,
122
- ) -> Ok[ProtoObj] | NotFoundError:
123
- """Create a model object by loading it from the database."""
124
- with _create_or_join_session(client, existing_session) as session:
125
- orm_self = session.get(cls.orm_class(), obj_id)
126
- if orm_self is None:
127
- return NotFoundError("object with given id does not exist", id=obj_id)
128
- proto_self = cls.orm_to_proto(orm_self)
129
- return Ok(proto_self)
100
+ cls, proto_obj: ProtoT, session: orm.Session
101
+ ) -> Ok[OrmT] | InvalidArgumentError: ...
130
102
 
131
103
  @classmethod
132
104
  def _generate_query_results(
133
- cls, query: sa.Select[tuple[OrmObj]], session: sa_orm.Session
134
- ) -> Iterator[OrmObj]:
105
+ cls, query: sa.Select[tuple[OrmT]], session: sa_orm.Session
106
+ ) -> Iterator[OrmT]:
135
107
  it = iter(session.scalars(query))
136
108
  while True:
137
109
  try:
138
110
  yield from it
139
111
  except Exception:
140
112
  _logger.exception(
141
- "omitting source from list: "
142
- + "failed to parse source from database entry",
113
+ "omitting model from list: " + "failed to parse database entry",
143
114
  )
144
115
  else:
145
116
  break
@@ -155,31 +126,27 @@ class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
155
126
  client: system.Client,
156
127
  *,
157
128
  limit: int | None = None,
158
- room_id: eorm.RoomID | None = None,
159
129
  created_before: datetime.datetime | None = None,
160
- ids: Iterable[IdType] | None = None,
161
130
  additional_query_transform: Callable[
162
- [sa.Select[tuple[OrmObj]]], sa.Select[tuple[OrmObj]]
131
+ [sa.Select[tuple[OrmT]]], sa.Select[tuple[OrmT]]
163
132
  ]
164
133
  | None = None,
165
134
  existing_session: sa_orm.Session | None = None,
166
- ) -> Ok[list[ProtoObj]] | NotFoundError | InvalidArgumentError:
135
+ ) -> Ok[list[ProtoT]] | NotFoundError | InvalidArgumentError:
167
136
  """List sources that exist in storage."""
168
137
  orm_class = cls.orm_class()
169
- with _create_or_join_session(client, existing_session) as session:
138
+ with (
139
+ contextlib.nullcontext(existing_session)
140
+ if existing_session
141
+ else orm.Session(client.sa_engine) as session
142
+ ):
170
143
  query = sa.select(orm_class).order_by(sa.desc(orm_class.created_at))
171
144
  if limit is not None:
172
145
  if limit < 0:
173
146
  return InvalidArgumentError("limit cannot be negative")
174
147
  query = query.limit(limit)
175
- if room_id:
176
- if session.get(eorm.Room, room_id) is None:
177
- return NotFoundError("room not found", room_id=room_id)
178
- query = query.filter_by(room_id=room_id)
179
148
  if created_before:
180
149
  query = query.filter(orm_class.created_at < created_before)
181
- if ids is not None:
182
- query = query.filter(orm_class.id.in_(ids))
183
150
  if additional_query_transform:
184
151
  query = additional_query_transform(query)
185
152
  extra_orm_loaders = cls.orm_load_options()
@@ -198,7 +165,7 @@ class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
198
165
  This overwrites the entry at id in the database so that future readers will see
199
166
  this object. One of `id` or `derived_from_id` cannot be empty or None.
200
167
  """
201
- with eorm.Session(self.client.sa_engine) as session:
168
+ with orm.Session(self.client.sa_engine) as session:
202
169
  try:
203
170
  new_orm_self = self.proto_to_orm(
204
171
  self.proto_self, session
@@ -230,7 +197,7 @@ class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
230
197
  return InvalidArgumentError.from_(err)
231
198
 
232
199
  def add_to_session(
233
- self, session: eorm.Session
200
+ self, session: orm.Session
234
201
  ) -> Ok[None] | InvalidArgumentError | UnavailableError:
235
202
  """Like commit, but just calls session.flush to check for database errors.
236
203
 
@@ -246,8 +213,39 @@ class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
246
213
  return self._dbapi_error_to_result(err)
247
214
  return Ok(None)
248
215
 
216
+
217
+ class HasIdOrmBackedProto(
218
+ Generic[OrmIdT, ProtoHasIdT, OrmHasIdT],
219
+ UsesOrmID[OrmIdT, ProtoHasIdT],
220
+ OrmBackedProto[ProtoHasIdT, OrmHasIdT],
221
+ ):
222
+ @classmethod
223
+ @abc.abstractmethod
224
+ def delete_by_ids(
225
+ cls, ids: Sequence[OrmIdT], session: orm.Session
226
+ ) -> Ok[None] | InvalidArgumentError: ...
227
+
228
+ @classmethod
229
+ def load_proto_for(
230
+ cls,
231
+ obj_id: OrmIdT,
232
+ client: system.Client,
233
+ existing_session: sa_orm.Session | None = None,
234
+ ) -> Ok[ProtoHasIdT] | NotFoundError:
235
+ """Create a model object by loading it from the database."""
236
+ with (
237
+ contextlib.nullcontext(existing_session)
238
+ if existing_session
239
+ else orm.Session(client.sa_engine) as session
240
+ ):
241
+ orm_self = session.get(cls.orm_class(), obj_id)
242
+ if orm_self is None:
243
+ return NotFoundError("object with given id does not exist", id=obj_id)
244
+ proto_self = cls.orm_to_proto(orm_self)
245
+ return Ok(proto_self)
246
+
249
247
  def delete(self) -> Ok[Self] | NotFoundError | InvalidArgumentError:
250
- with eorm.Session(
248
+ with orm.Session(
251
249
  self.client.sa_engine, expire_on_commit=False, autoflush=False
252
250
  ) as session:
253
251
  try:
@@ -270,24 +268,33 @@ class BaseModel(Generic[IdType, ProtoObj, OrmObj], UsesOrmID[IdType, ProtoObj]):
270
268
  )
271
269
  )
272
270
 
273
-
274
- class BelongsToOrgModel(
275
- Generic[IdType, ProtoBelongsToOrgObj, OrmBelongsToOrgObj],
276
- BaseModel[IdType, ProtoBelongsToOrgObj, OrmBelongsToOrgObj],
277
- ):
278
- """Base for orm wrappers with org mixin providing a unified update mechanism."""
279
-
280
- @property
281
- def org_id(self) -> eorm.OrgID:
282
- return eorm.OrgID().from_str(self.proto_self.org_id)
283
-
284
-
285
- class BelongsToRoomModel(
286
- Generic[IdType, ProtoBelongsToRoomObj, OrmBelongsToRoomObj],
287
- BelongsToOrgModel[IdType, ProtoBelongsToRoomObj, OrmBelongsToRoomObj],
288
- ):
289
- """Base for orm wrappers with room mixin providing a unified update mechanism."""
290
-
291
- @property
292
- def room_id(self) -> eorm.RoomID:
293
- return eorm.RoomID().from_str(self.proto_self.room_id)
271
+ @classmethod
272
+ def list_as_proto(
273
+ cls,
274
+ client: system.Client,
275
+ *,
276
+ limit: int | None = None,
277
+ created_before: datetime.datetime | None = None,
278
+ ids: Iterable[OrmIdT] | None = None,
279
+ additional_query_transform: Callable[
280
+ [sa.Select[tuple[OrmHasIdT]]], sa.Select[tuple[OrmHasIdT]]
281
+ ]
282
+ | None = None,
283
+ existing_session: sa_orm.Session | None = None,
284
+ ) -> Ok[list[ProtoHasIdT]] | NotFoundError | InvalidArgumentError:
285
+ def query_transform(
286
+ query: sa.Select[tuple[OrmHasIdT]],
287
+ ) -> sa.Select[tuple[OrmHasIdT]]:
288
+ if ids:
289
+ query = query.where(cls.orm_class().id.in_(ids))
290
+ if additional_query_transform:
291
+ query = additional_query_transform(query)
292
+ return query
293
+
294
+ return super().list_as_proto(
295
+ client,
296
+ limit=limit,
297
+ created_before=created_before,
298
+ additional_query_transform=query_transform,
299
+ existing_session=existing_session,
300
+ )
File without changes