corvic-engine 0.3.0rc42__cp38-abi3-win_amd64.whl → 0.3.0rc44__cp38-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -84,8 +84,21 @@ def stable_rank(
84
84
  case Ok():
85
85
  pass
86
86
 
87
- numerator = norm(embeddings, order="fro")
88
- denominator = norm(embeddings, order=2) ** 2
87
+ try:
88
+ numerator = norm(embeddings, order="fro")
89
+ denominator = norm(embeddings, order=2) ** 2
90
+ except np.linalg.LinAlgError as err:
91
+ return InvalidArgumentError.from_(err)
92
+
93
+ if not np.isfinite(numerator):
94
+ return InvalidArgumentError("embeddings norm_fro is not finite")
95
+
96
+ if not np.isfinite(denominator):
97
+ return InvalidArgumentError("embeddings norm_2 is not finite")
98
+
99
+ if not denominator:
100
+ return InvalidArgumentError("embeddings norm_2 is zero")
101
+
89
102
  metric = numerator / denominator
90
103
  if normalize:
91
104
  metric = 1 - 1 / (1 + metric)
@@ -162,7 +175,10 @@ def ne_sum(
162
175
  pass
163
176
 
164
177
  covariance = np.cov(embeddings.T)
165
- eigenvalues = linalg.eigvals(covariance)
178
+ try:
179
+ eigenvalues = linalg.eigvals(covariance)
180
+ except np.linalg.LinAlgError as err:
181
+ return InvalidArgumentError.from_(err)
166
182
 
167
183
  # Discard imaginary part
168
184
  eigenvalues = eigenvalues.real
@@ -172,7 +188,7 @@ def ne_sum(
172
188
  eigenvalues = eigenvalues[sorted_indices]
173
189
 
174
190
  if eigenvalues[0] == 0:
175
- return Ok(0)
191
+ return Ok(0.0)
176
192
 
177
193
  ne_sum_value = float(np.sum(eigenvalues) / eigenvalues[0])
178
194
  if normalize:
@@ -227,7 +243,11 @@ def condition_number(
227
243
  case Ok():
228
244
  pass
229
245
 
230
- metric = float(np.linalg.cond(embeddings, p=p))
246
+ try:
247
+ metric = float(np.linalg.cond(embeddings, p=p))
248
+ except np.linalg.LinAlgError as err:
249
+ return InvalidArgumentError.from_(err)
250
+
231
251
  if normalize:
232
252
  metric = 1 - 1 / (1 + metric)
233
253
  return Ok(metric)
@@ -246,7 +266,11 @@ def rcondition_number(
246
266
  case Ok():
247
267
  pass
248
268
 
249
- metric = float(1 / np.linalg.cond(embeddings, p=p))
269
+ try:
270
+ metric = float(1 / np.linalg.cond(embeddings, p=p))
271
+ except np.linalg.LinAlgError as err:
272
+ return InvalidArgumentError.from_(err)
273
+
250
274
  if normalize:
251
275
  metric = 1 - 1 / (1 + metric)
252
276
  return Ok(metric)
corvic/engine/_native.pyd CHANGED
Binary file
@@ -5,7 +5,7 @@ from __future__ import annotations
5
5
  import copy
6
6
  import datetime
7
7
  from collections.abc import Iterable, Sequence
8
- from typing import TypeAlias
8
+ from typing import Literal, TypeAlias
9
9
 
10
10
  from sqlalchemy import orm as sa_orm
11
11
 
@@ -19,6 +19,7 @@ from corvic.model._proto_orm_convert import (
19
19
  )
20
20
  from corvic.result import InvalidArgumentError, NotFoundError, Ok
21
21
  from corvic_generated.model.v1alpha import models_pb2
22
+ from corvic_generated.orm.v1 import completion_model_pb2
22
23
 
23
24
  CompletionModelID: TypeAlias = orm.CompletionModelID
24
25
  OrgID: TypeAlias = orm.OrgID
@@ -62,12 +63,46 @@ class CompletionModel(
62
63
  return OrgID(self.proto_self.org_id)
63
64
 
64
65
  @property
65
- def model_name(self) -> str:
66
- return self.proto_self.model_name
66
+ def provider(self) -> Literal["openai-generic", "azure-openai"] | None:
67
+ match self.proto_self.parameters.WhichOneof("params"):
68
+ case "azure_openai_parameters":
69
+ return "azure-openai"
70
+ case "generic_openai_parameters":
71
+ return "openai-generic"
72
+ case _:
73
+ return None
67
74
 
68
75
  @property
69
- def endpoint(self) -> str:
70
- return self.proto_self.endpoint
76
+ def parameters(
77
+ self,
78
+ ) -> (
79
+ completion_model_pb2.AzureOpenAIParameters
80
+ | completion_model_pb2.GenericOpenAIParameters
81
+ | None
82
+ ):
83
+ match self.provider:
84
+ case "azure-openai":
85
+ return self.azure_openai_parameters
86
+ case "openai-generic":
87
+ return self.generic_openai_parameters
88
+ case None:
89
+ return None
90
+
91
+ @property
92
+ def azure_openai_parameters(
93
+ self,
94
+ ) -> completion_model_pb2.AzureOpenAIParameters | None:
95
+ if self.proto_self.parameters.HasField("azure_openai_parameters"):
96
+ return self.proto_self.parameters.azure_openai_parameters
97
+ return None
98
+
99
+ @property
100
+ def generic_openai_parameters(
101
+ self,
102
+ ) -> completion_model_pb2.GenericOpenAIParameters | None:
103
+ if self.proto_self.parameters.HasField("generic_openai_parameters"):
104
+ return self.proto_self.parameters.generic_openai_parameters
105
+ return None
71
106
 
72
107
  @property
73
108
  def secret_api_key(self) -> str:
@@ -83,8 +118,7 @@ class CompletionModel(
83
118
  *,
84
119
  name: str,
85
120
  description: str,
86
- model_name: str,
87
- endpoint: str,
121
+ parameters: completion_model_pb2.CompletionModelParameters,
88
122
  secret_api_key: str,
89
123
  client: system.Client | None = None,
90
124
  ):
@@ -94,8 +128,7 @@ class CompletionModel(
94
128
  models_pb2.CompletionModel(
95
129
  name=name,
96
130
  description=description,
97
- model_name=model_name,
98
- endpoint=endpoint,
131
+ parameters=parameters,
99
132
  secret_api_key=secret_api_key,
100
133
  ),
101
134
  )
@@ -153,14 +186,11 @@ class CompletionModel(
153
186
  proto_self.description = description
154
187
  return CompletionModel(self.client, proto_self)
155
188
 
156
- def with_model_name(self, model_name: str) -> CompletionModel:
157
- proto_self = copy.deepcopy(self.proto_self)
158
- proto_self.model_name = model_name
159
- return CompletionModel(self.client, proto_self)
160
-
161
- def with_endpoint(self, endpoint: str) -> CompletionModel:
189
+ def with_parameters(
190
+ self, parameters: completion_model_pb2.CompletionModelParameters
191
+ ) -> CompletionModel:
162
192
  proto_self = copy.deepcopy(self.proto_self)
163
- proto_self.endpoint = endpoint
193
+ proto_self.parameters.CopyFrom(parameters)
164
194
  return CompletionModel(self.client, proto_self)
165
195
 
166
196
  def with_secret_api_key(self, secret_api_key: str) -> CompletionModel:
@@ -1075,6 +1075,7 @@ class FeatureView(BaseModel[FeatureViewID, models_pb2.FeatureView, orm.FeatureVi
1075
1075
  )
1076
1076
 
1077
1077
  proto_feature_view_source = models_pb2.FeatureViewSource(
1078
+ room_id=str(source.room_id),
1078
1079
  table_op_graph=new_table.op_graph.to_proto(),
1079
1080
  drop_disconnected=drop_disconnected,
1080
1081
  source=source.proto_self,
@@ -66,6 +66,7 @@ def _translate_orm_ids(
66
66
  | models_pb2.Space()
67
67
  | models_pb2.Agent()
68
68
  | models_pb2.Pipeline()
69
+ | models_pb2.FeatureViewSource()
69
70
  ):
70
71
  room_id = orm.RoomID(proto_obj.room_id)
71
72
  match room_id.to_db():
@@ -73,7 +74,7 @@ def _translate_orm_ids(
73
74
  return err
74
75
  case Ok():
75
76
  pass
76
- case models_pb2.FeatureViewSource() | models_pb2.CompletionModel():
77
+ case models_pb2.CompletionModel():
77
78
  room_id = None
78
79
  case models_pb2.Room():
79
80
  room_id = cast(orm.RoomID, obj_id)
@@ -146,6 +147,7 @@ def feature_view_source_orm_to_proto(
146
147
  ) -> models_pb2.FeatureViewSource:
147
148
  return models_pb2.FeatureViewSource(
148
149
  id=str(feature_view_source_orm.id),
150
+ room_id=str(feature_view_source_orm.room_id),
149
151
  source=source_orm_to_proto(feature_view_source_orm.source),
150
152
  table_op_graph=feature_view_source_orm.table_op_graph,
151
153
  drop_disconnected=feature_view_source_orm.drop_disconnected,
@@ -225,8 +227,7 @@ def completion_model_orm_to_proto(
225
227
  name=completion_model_orm.name,
226
228
  description=completion_model_orm.description,
227
229
  org_id=str(completion_model_orm.org_id),
228
- model_name=completion_model_orm.model_name,
229
- endpoint=completion_model_orm.endpoint,
230
+ parameters=completion_model_orm.parameters,
230
231
  secret_api_key=completion_model_orm.secret_api_key,
231
232
  created_at=timestamp_orm_to_proto(completion_model_orm.created_at),
232
233
  )
@@ -263,7 +264,9 @@ def resource_proto_to_orm(
263
264
  latest_event=proto_obj.recent_events[-1] if proto_obj.recent_events else None,
264
265
  room_id=ids.room_id,
265
266
  source_associations=[
266
- orm.SourceResourceAssociation(source_id=src_id, resource_id=ids.obj_id)
267
+ orm.SourceResourceAssociation(
268
+ room_id=ids.room_id, source_id=src_id, resource_id=ids.obj_id
269
+ )
267
270
  for src_id in source_ids
268
271
  ],
269
272
  )
@@ -317,7 +320,10 @@ def pipeline_proto_to_orm( # noqa: C901
317
320
  session.add(resource_orm)
318
321
  session.merge(
319
322
  orm.PipelineInput(
320
- pipeline=orm_obj, resource=resource_orm, name=name
323
+ room_id=resource_orm.room_id,
324
+ pipeline=orm_obj,
325
+ resource=resource_orm,
326
+ name=name,
321
327
  )
322
328
  )
323
329
 
@@ -333,7 +339,12 @@ def pipeline_proto_to_orm( # noqa: C901
333
339
  else:
334
340
  session.add(source_orm)
335
341
  session.merge(
336
- orm.PipelineOutput(pipeline=orm_obj, source=source_orm, name=name)
342
+ orm.PipelineOutput(
343
+ room_id=source_orm.room_id,
344
+ pipeline=orm_obj,
345
+ source=source_orm,
346
+ name=name,
347
+ )
337
348
  )
338
349
  if proto_obj.org_id:
339
350
  org_id = orm.OrgID(proto_obj.org_id)
@@ -356,7 +367,9 @@ def source_proto_to_orm(
356
367
  resource_id = orm.ResourceID(proto_obj.resource_id)
357
368
  if resource_id:
358
369
  associations = [
359
- orm.SourceResourceAssociation(source_id=ids.obj_id, resource_id=resource_id)
370
+ orm.SourceResourceAssociation(
371
+ room_id=ids.room_id, source_id=ids.obj_id, resource_id=resource_id
372
+ )
360
373
  ]
361
374
  else:
362
375
  associations = list[orm.SourceResourceAssociation]()
@@ -499,6 +512,7 @@ def feature_view_source_proto_to_orm(
499
512
  else:
500
513
  session.add(source)
501
514
  orm_obj = orm.FeatureViewSource(
515
+ room_id=source.room_id,
502
516
  table_op_graph=proto_obj.table_op_graph,
503
517
  drop_disconnected=proto_obj.drop_disconnected,
504
518
  source=source,
@@ -543,8 +557,7 @@ def completion_model_proto_to_orm(
543
557
  id=ids.obj_id,
544
558
  name=proto_obj.name,
545
559
  description=proto_obj.description,
546
- model_name=proto_obj.model_name,
547
- endpoint=proto_obj.endpoint,
560
+ parameters=proto_obj.parameters,
548
561
  secret_api_key=proto_obj.secret_api_key,
549
562
  )
550
563
 
corvic/model/_space.py CHANGED
@@ -782,50 +782,67 @@ class TabularSpace(Space):
782
782
 
783
783
  embedding_column_tmp_name = f"__embed-{uuid.uuid4()}"
784
784
 
785
- op = (
786
- op.concat_list(
787
- column_names=embedding_column_tmp_names,
788
- concat_list_column_name=embedding_column_tmp_name,
789
- )
790
- .and_then(
791
- lambda t,
792
- embedding_column_name=embedding_column_tmp_name: op_graph.op.coordinates_from_embedding( # noqa: E501
793
- table=t,
794
- embedding_column_name=embedding_column_name,
795
- output_dims=parameters.ndim,
796
- )
785
+ # Avoid 0 padding for spaces with small numbers of columns
786
+ target_list_length = min(parameters.ndim, len(embedding_column_tmp_names))
787
+
788
+ def reduce_dimension(
789
+ op: op_graph.Op,
790
+ embedding_column_tmp_name=embedding_column_tmp_name,
791
+ target_list_length=target_list_length,
792
+ ):
793
+ return op.truncate_list(
794
+ list_column_name=embedding_column_tmp_name,
795
+ target_list_length=target_list_length,
796
+ padding_value=0,
797
797
  )
798
- .and_then(
799
- lambda t,
800
- pk_field=pk_field,
801
- embedding_column_tmp_name=embedding_column_tmp_name,
802
- output_source=output_source: t.select_columns(
803
- [pk_field.name, embedding_column_tmp_name]
804
- )
798
+
799
+ def select_columns(
800
+ op: op_graph.Op,
801
+ pk_field=pk_field,
802
+ embedding_column_tmp_name=embedding_column_tmp_name,
803
+ ):
804
+ return op.select_columns([pk_field.name, embedding_column_tmp_name])
805
+
806
+ def update_feature_types(
807
+ op: op_graph.Op,
808
+ embedding_column_tmp_name=embedding_column_tmp_name,
809
+ ):
810
+ return op.update_feature_types(
811
+ {embedding_column_tmp_name: op_graph.feature_type.embedding()}
805
812
  )
806
- .and_then(
807
- lambda t,
808
- embedding_column_tmp_name=embedding_column_tmp_name: t.update_feature_types( # noqa: E501
809
- {embedding_column_tmp_name: op_graph.feature_type.embedding()}
810
- )
813
+
814
+ def rename_columns(
815
+ op: op_graph.Op,
816
+ pk_field=pk_field,
817
+ embedding_column_tmp_name=embedding_column_tmp_name,
818
+ ):
819
+ return op.rename_columns(
820
+ {
821
+ pk_field.name: "entity_id",
822
+ embedding_column_tmp_name: "embedding",
823
+ }
811
824
  )
812
- .and_then(
813
- lambda t,
814
- pk_field=pk_field,
815
- embedding_column_tmp_name=embedding_column_tmp_name: t.rename_columns( # noqa: E501
816
- {
817
- pk_field.name: "entity_id",
818
- embedding_column_tmp_name: "embedding",
819
- }
820
- )
825
+
826
+ def add_literal_column(
827
+ op: op_graph.Op,
828
+ output_source=output_source,
829
+ ):
830
+ return op.add_literal_column(
831
+ "source_id",
832
+ str(output_source.id),
833
+ pa.string(),
821
834
  )
822
- .and_then(
823
- lambda t, output_source=output_source: t.add_literal_column(
824
- "source_id",
825
- str(output_source.id),
826
- pa.string(),
827
- )
835
+
836
+ op = (
837
+ op.concat_list(
838
+ column_names=embedding_column_tmp_names,
839
+ concat_list_column_name=embedding_column_tmp_name,
828
840
  )
841
+ .and_then(reduce_dimension)
842
+ .and_then(select_columns)
843
+ .and_then(update_feature_types)
844
+ .and_then(rename_columns)
845
+ .and_then(add_literal_column)
829
846
  )
830
847
 
831
848
  match op:
corvic/orm/__init__.py CHANGED
@@ -48,6 +48,7 @@ from corvic.orm.mixins import (
48
48
  from corvic_generated.orm.v1 import (
49
49
  agent_pb2,
50
50
  common_pb2,
51
+ completion_model_pb2,
51
52
  feature_view_pb2,
52
53
  pipeline_pb2,
53
54
  space_pb2,
@@ -104,6 +105,13 @@ class Room(BelongsToOrgMixin, SoftDeleteMixin, Base):
104
105
  return self.name
105
106
 
106
107
 
108
+ class BelongsToRoomMixin(sa_orm.MappedAsDataclass):
109
+ room_id: sa_orm.Mapped[RoomID | None] = sa_orm.mapped_column(
110
+ ForeignKey(Room).make(ondelete="CASCADE"),
111
+ nullable=True,
112
+ )
113
+
114
+
107
115
  class DefaultObjects(Base):
108
116
  """Holds the identifiers for default objects."""
109
117
 
@@ -117,7 +125,7 @@ class DefaultObjects(Base):
117
125
  version: sa_orm.Mapped[int | None] = primary_key_identity_column(type_=INT_PK_TYPE)
118
126
 
119
127
 
120
- class Resource(BelongsToOrgMixin, Base):
128
+ class Resource(BelongsToOrgMixin, BelongsToRoomMixin, Base):
121
129
  """A Resource is a reference to some durably stored file.
122
130
 
123
131
  E.g., a document could be a PDF file, an image, or a text transcript of a
@@ -129,9 +137,6 @@ class Resource(BelongsToOrgMixin, Base):
129
137
  name: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.Text)
130
138
  mime_type: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.Text)
131
139
  url: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.Text)
132
- room_id: sa_orm.Mapped[RoomID] = sa_orm.mapped_column(
133
- ForeignKey(Room).make(ondelete="CASCADE"), name="room_id"
134
- )
135
140
  md5: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.CHAR(32), nullable=True)
136
141
  size: sa_orm.Mapped[int] = sa_orm.mapped_column(nullable=True)
137
142
  original_path: sa_orm.Mapped[str] = sa_orm.mapped_column(nullable=True)
@@ -156,16 +161,13 @@ class Resource(BelongsToOrgMixin, Base):
156
161
  )
157
162
 
158
163
 
159
- class Source(BelongsToOrgMixin, Base):
164
+ class Source(BelongsToOrgMixin, BelongsToRoomMixin, Base):
160
165
  """A source."""
161
166
 
162
167
  __tablename__ = "source"
163
168
  __table_args__ = (sa.UniqueConstraint("name", "room_id"),)
164
169
 
165
170
  name: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.Text)
166
- room_id: sa_orm.Mapped[RoomID] = sa_orm.mapped_column(
167
- ForeignKey(Room).make(ondelete="CASCADE"),
168
- )
169
171
  # protobuf describing the operations required to construct a table
170
172
  table_op_graph: sa_orm.Mapped[table_pb2.TableComputeOp] = sa_orm.mapped_column()
171
173
  id: sa_orm.Mapped[SourceID | None] = primary_key_identity_column()
@@ -196,7 +198,7 @@ class Source(BelongsToOrgMixin, Base):
196
198
  return self.name
197
199
 
198
200
 
199
- class Pipeline(BelongsToOrgMixin, Base):
201
+ class Pipeline(BelongsToOrgMixin, BelongsToRoomMixin, Base):
200
202
  """A resource to source pipeline."""
201
203
 
202
204
  __tablename__ = "pipeline"
@@ -205,9 +207,6 @@ class Pipeline(BelongsToOrgMixin, Base):
205
207
  transformation: sa_orm.Mapped[pipeline_pb2.PipelineTransformation] = (
206
208
  sa_orm.mapped_column()
207
209
  )
208
- room_id: sa_orm.Mapped[RoomID] = sa_orm.mapped_column(
209
- ForeignKey(Room).make(ondelete="CASCADE")
210
- )
211
210
  name: sa_orm.Mapped[str] = sa_orm.mapped_column()
212
211
  description: sa_orm.Mapped[str | None] = sa_orm.mapped_column()
213
212
  id: sa_orm.Mapped[PipelineID | None] = primary_key_identity_column()
@@ -227,7 +226,7 @@ class Pipeline(BelongsToOrgMixin, Base):
227
226
  )
228
227
 
229
228
 
230
- class PipelineInput(BelongsToOrgMixin, Base):
229
+ class PipelineInput(BelongsToOrgMixin, BelongsToRoomMixin, Base):
231
230
  """Pipeline input resources."""
232
231
 
233
232
  __tablename__ = "pipeline_input"
@@ -250,7 +249,7 @@ class PipelineInput(BelongsToOrgMixin, Base):
250
249
  )
251
250
 
252
251
 
253
- class PipelineOutput(BelongsToOrgMixin, Base):
252
+ class PipelineOutput(BelongsToOrgMixin, BelongsToRoomMixin, Base):
254
253
  """Objects for tracking pipeline output sources."""
255
254
 
256
255
  __tablename__ = "pipeline_output"
@@ -273,7 +272,7 @@ class PipelineOutput(BelongsToOrgMixin, Base):
273
272
  )
274
273
 
275
274
 
276
- class SourceResourceAssociation(BelongsToOrgMixin, Base):
275
+ class SourceResourceAssociation(BelongsToOrgMixin, BelongsToRoomMixin, Base):
277
276
  __tablename__ = "source_resource_association"
278
277
 
279
278
  source_id: sa_orm.Mapped[SourceID | None] = (
@@ -294,7 +293,7 @@ class SourceResourceAssociation(BelongsToOrgMixin, Base):
294
293
  )
295
294
 
296
295
 
297
- class FeatureView(SoftDeleteMixin, BelongsToOrgMixin, Base):
296
+ class FeatureView(SoftDeleteMixin, BelongsToOrgMixin, BelongsToRoomMixin, Base):
298
297
  """A FeatureView is a logical collection of sources used by various spaces."""
299
298
 
300
299
  __tablename__ = "feature_view"
@@ -304,12 +303,6 @@ class FeatureView(SoftDeleteMixin, BelongsToOrgMixin, Base):
304
303
  name: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.Text, default=None)
305
304
  description: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.Text, default="")
306
305
 
307
- room_id: sa_orm.Mapped[RoomID | None] = sa_orm.mapped_column(
308
- ForeignKey(Room).make(ondelete="CASCADE"),
309
- nullable=True,
310
- init=True,
311
- default=None,
312
- )
313
306
  room: sa_orm.Mapped[Room] = sa_orm.relationship(
314
307
  back_populates="feature_views", init=False
315
308
  )
@@ -333,10 +326,11 @@ class FeatureView(SoftDeleteMixin, BelongsToOrgMixin, Base):
333
326
  )
334
327
 
335
328
 
336
- class FeatureViewSource(BelongsToOrgMixin, Base):
329
+ class FeatureViewSource(BelongsToOrgMixin, BelongsToRoomMixin, Base):
337
330
  """A source inside of a feature view."""
338
331
 
339
332
  __tablename__ = "feature_view_source"
333
+
340
334
  table_op_graph: sa_orm.Mapped[table_pb2.TableComputeOp] = sa_orm.mapped_column()
341
335
  id: sa_orm.Mapped[FeatureViewSourceID | None] = primary_key_identity_column()
342
336
  drop_disconnected: sa_orm.Mapped[bool] = sa_orm.mapped_column(default=False)
@@ -356,18 +350,12 @@ class FeatureViewSource(BelongsToOrgMixin, Base):
356
350
  )
357
351
 
358
352
 
359
- class Space(BelongsToOrgMixin, Base):
353
+ class Space(BelongsToOrgMixin, BelongsToRoomMixin, Base):
360
354
  """A space is a named evaluation of space parameters."""
361
355
 
362
356
  __tablename__ = "space"
363
357
  __table_args__ = (sa.UniqueConstraint("name", "room_id"),)
364
358
 
365
- room_id: sa_orm.Mapped[RoomID] = sa_orm.mapped_column(
366
- ForeignKey(Room).make(ondelete="CASCADE"),
367
- nullable=True,
368
- init=True,
369
- default=None,
370
- )
371
359
  room: sa_orm.Mapped[Room] = sa_orm.relationship(
372
360
  back_populates="spaces", init=True, default=None
373
361
  )
@@ -404,7 +392,7 @@ class Space(BelongsToOrgMixin, Base):
404
392
  return self.name
405
393
 
406
394
 
407
- class SpaceRun(BelongsToOrgMixin, Base):
395
+ class SpaceRun(BelongsToOrgMixin, BelongsToRoomMixin, Base):
408
396
  """A Space run."""
409
397
 
410
398
  __tablename__ = "space_run"
@@ -439,18 +427,12 @@ class SpaceRun(BelongsToOrgMixin, Base):
439
427
  )
440
428
 
441
429
 
442
- class Agent(SoftDeleteMixin, BelongsToOrgMixin, Base):
430
+ class Agent(SoftDeleteMixin, BelongsToOrgMixin, BelongsToRoomMixin, Base):
443
431
  """An Agent."""
444
432
 
445
433
  __tablename__ = "agent"
446
434
  __table_args__ = (live_unique_constraint("name", "room_id"),)
447
435
 
448
- room_id: sa_orm.Mapped[RoomID | None] = sa_orm.mapped_column(
449
- ForeignKey(Room).make(ondelete="CASCADE"),
450
- nullable=True,
451
- init=True,
452
- default=None,
453
- )
454
436
  id: sa_orm.Mapped[AgentID | None] = primary_key_identity_column()
455
437
  name: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.Text, default=None)
456
438
 
@@ -478,7 +460,7 @@ class Agent(SoftDeleteMixin, BelongsToOrgMixin, Base):
478
460
  )
479
461
 
480
462
 
481
- class AgentSpaceAssociation(BelongsToOrgMixin, Base):
463
+ class AgentSpaceAssociation(BelongsToOrgMixin, BelongsToRoomMixin, Base):
482
464
  __tablename__ = "agent_space_association"
483
465
 
484
466
  agent_id: sa_orm.Mapped[AgentID] = primary_key_foreign_column(
@@ -495,7 +477,7 @@ class AgentSpaceAssociation(BelongsToOrgMixin, Base):
495
477
  )
496
478
 
497
479
 
498
- class UserMessage(SoftDeleteMixin, BelongsToOrgMixin, Base):
480
+ class UserMessage(SoftDeleteMixin, BelongsToOrgMixin, BelongsToRoomMixin, Base):
499
481
  """A message sent by an user."""
500
482
 
501
483
  __tablename__ = "user_message"
@@ -509,7 +491,7 @@ class UserMessage(SoftDeleteMixin, BelongsToOrgMixin, Base):
509
491
  message: sa_orm.Mapped[str | None] = sa_orm.mapped_column(sa.Text, default=None)
510
492
 
511
493
 
512
- class AgentMessage(SoftDeleteMixin, BelongsToOrgMixin, Base):
494
+ class AgentMessage(SoftDeleteMixin, BelongsToOrgMixin, BelongsToRoomMixin, Base):
513
495
  """A message sent by an agent."""
514
496
 
515
497
  __tablename__ = "agent_message"
@@ -537,7 +519,7 @@ class AgentMessage(SoftDeleteMixin, BelongsToOrgMixin, Base):
537
519
  )
538
520
 
539
521
 
540
- class MessageEntry(SoftDeleteMixin, BelongsToOrgMixin, Base):
522
+ class MessageEntry(SoftDeleteMixin, BelongsToOrgMixin, BelongsToRoomMixin, Base):
541
523
  """A message either sent by an Agent or an User."""
542
524
 
543
525
  __tablename__ = "message_entry"
@@ -580,8 +562,9 @@ class CompletionModel(SoftDeleteMixin, BelongsToOrgMixin, Base):
580
562
  id: sa_orm.Mapped[CompletionModelID | None] = primary_key_identity_column()
581
563
  name: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.Text, default=None)
582
564
  description: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.Text, default=None)
583
- model_name: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.Text, default=None)
584
- endpoint: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.Text, default=None)
565
+ parameters: sa_orm.Mapped[completion_model_pb2.CompletionModelParameters | None] = (
566
+ sa_orm.mapped_column(default=None)
567
+ )
585
568
  secret_api_key: sa_orm.Mapped[str] = sa_orm.mapped_column(sa.Text, default=None)
586
569
 
587
570
  @property
corvic/orm/base.py CHANGED
@@ -41,6 +41,7 @@ from corvic.orm.keys import (
41
41
  from corvic_generated.orm.v1 import (
42
42
  agent_pb2,
43
43
  common_pb2,
44
+ completion_model_pb2,
44
45
  feature_view_pb2,
45
46
  pipeline_pb2,
46
47
  space_pb2,
@@ -112,6 +113,9 @@ class Base(sa_orm.MappedAsDataclass, sa_orm.DeclarativeBase):
112
113
  pipeline_pb2.PipelineTransformation()
113
114
  ),
114
115
  event_pb2.Event: ProtoMessageDecorator(event_pb2.Event()),
116
+ completion_model_pb2.CompletionModelParameters: ProtoMessageDecorator(
117
+ completion_model_pb2.CompletionModelParameters()
118
+ ),
115
119
  # ID types
116
120
  OrgID: StrIDDecorator(OrgID()),
117
121
  RoomID: IntIDDecorator(RoomID()),