corvic-engine 0.3.0rc67-cp38-abi3-win_amd64.whl → 0.3.0rc68-cp38-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
corvic/engine/_native.pyd CHANGED
Binary file
@@ -5,12 +5,11 @@ import datetime
 import functools
 import uuid
 from collections.abc import Callable, Iterable, Iterator, Sequence
-from typing import Final, Generic
+from typing import Final, Generic, Self

 import sqlalchemy as sa
 import sqlalchemy.orm as sa_orm
 import structlog
-from typing_extensions import Self

 from corvic import orm, system
 from corvic.model._proto_orm_convert import (
@@ -54,7 +53,7 @@ class HasProtoSelf(Generic[ProtoObj], abc.ABC):
     @property
     def created_at(self) -> datetime.datetime | None:
         if self.proto_self.created_at:
-            return self.proto_self.created_at.ToDatetime(tzinfo=datetime.timezone.utc)
+            return self.proto_self.created_at.ToDatetime(tzinfo=datetime.UTC)
         return None


@@ -119,7 +118,7 @@ class BaseModel(Generic[ID, ProtoObj, OrmObj], UsesOrmID[ID, ProtoObj]):
         while True:
             try:
                 yield from it
-            except Exception:  # noqa: PERF203
+            except Exception:
                 _logger.exception(
                     "omitting source from list: "
                     + "failed to parse source from database entry",
@@ -116,16 +116,14 @@ class CompletionModel(
     @property
     def last_validation_time(self) -> datetime.datetime | None:
         if self.proto_self.last_validation_time != UNIX_TIMESTAMP_START_DATETIME:
-            return self.proto_self.last_validation_time.ToDatetime(
-                tzinfo=datetime.timezone.utc
-            )
+            return self.proto_self.last_validation_time.ToDatetime(tzinfo=datetime.UTC)
         return None

     @property
     def last_successful_validation(self) -> datetime.datetime | None:
         if self.proto_self.last_successful_validation != UNIX_TIMESTAMP_START_DATETIME:
             return self.proto_self.last_successful_validation.ToDatetime(
-                tzinfo=datetime.timezone.utc
+                tzinfo=datetime.UTC
             )
         return None

@@ -7,15 +7,14 @@ import dataclasses
 import datetime
 import functools
 import uuid
-from collections.abc import Iterable, Mapping, MutableMapping, Sequence
-from typing import Any, Final, TypeAlias
+from collections.abc import AsyncIterable, Iterable, Mapping, MutableMapping, Sequence
+from typing import Any, Final, Self, TypeAlias

 import pyarrow as pa
 from google.protobuf import struct_pb2
 from more_itertools import flatten
 from sqlalchemy import orm as sa_orm
 from sqlalchemy.orm.interfaces import LoaderOption
-from typing_extensions import Self

 from corvic import op_graph, orm, system
 from corvic.model._base_model import BelongsToRoomModel, UsesOrmID
@@ -324,7 +323,7 @@ class Relationship:
             how="inner",
         )

-    def edge_list(self) -> Iterable[tuple[Any, Any]]:
+    async def edge_list(self) -> AsyncIterable[tuple[Any, Any]]:
         start_pk = self.start_fv_source.table.schema.get_primary_key()
         end_pk = self.end_fv_source.table.schema.get_primary_key()

@@ -340,8 +339,8 @@ class Relationship:

         result = self.joined_table().select(result_columns)

-        for batch in result.to_polars(
-            room_id=self.start_source.room_id
+        for batch in (
+            await result.to_polars(room_id=self.start_source.room_id)
         ).unwrap_or_raise():
             for row in batch.rows(named=True):
                 yield (row[result_columns[0]], row[result_columns[1]])
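Since edge_list is now an async generator (and to_polars is awaited), callers consume it with `async for` inside a coroutine. A minimal hedged sketch with a stand-in class, not the corvic Relationship itself:

    import asyncio
    from collections.abc import AsyncIterable


    class FakeRelationship:
        # Stand-in with the same shape as the new edge_list: an async generator of pairs.
        async def edge_list(self) -> AsyncIterable[tuple[int, int]]:
            for pair in [(1, 2), (2, 3)]:
                yield pair


    async def main() -> None:
        async for start, end in FakeRelationship().edge_list():
            print(start, end)


    asyncio.run(main())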
corvic/model/_pipeline.py CHANGED
@@ -6,12 +6,11 @@ import datetime
 import functools
 import uuid
 from collections.abc import Iterable, Mapping, Sequence
-from typing import TypeAlias, cast
+from typing import Self, TypeAlias, cast

 import polars as pl
 from sqlalchemy import orm as sa_orm
 from sqlalchemy.orm.interfaces import LoaderOption
-from typing_extensions import Self

 import corvic.table
 from corvic import op_graph, orm, system
corvic/model/_resource.py CHANGED
@@ -6,13 +6,12 @@ import copy
 import datetime
 import uuid
 from collections.abc import Iterable, Sequence
-from typing import TypeAlias
+from typing import Self, TypeAlias

 import polars as pl
 import sqlalchemy as sa
 from sqlalchemy import orm as sa_orm
 from sqlalchemy.orm.interfaces import LoaderOption
-from typing_extensions import Self

 from corvic import orm, system
 from corvic.model._base_model import BelongsToRoomModel
corvic/model/_source.py CHANGED
@@ -6,13 +6,12 @@ import copy
 import datetime
 import functools
 from collections.abc import Iterable, Mapping, Sequence
-from typing import TypeAlias
+from typing import Self, TypeAlias

 import polars as pl
 import sqlalchemy as sa
 import sqlalchemy.orm as sa_orm
 from sqlalchemy.orm.interfaces import LoaderOption
-from typing_extensions import Self

 from corvic import op_graph, orm, system
 from corvic.model._base_model import BelongsToRoomModel
corvic/model/_space.py CHANGED
@@ -6,12 +6,11 @@ import abc
 import datetime
 import uuid
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Final, Literal, TypeAlias
+from typing import Final, Literal, Self, TypeAlias

 import pyarrow as pa
 import sqlalchemy as sa
 from sqlalchemy import orm as sa_orm
-from typing_extensions import Self

 from corvic import op_graph, orm, system
 from corvic.model._base_model import BelongsToRoomModel
corvic/orm/base.py CHANGED
@@ -3,14 +3,13 @@
 from __future__ import annotations

 import uuid
-from datetime import datetime, timezone
-from typing import Any, ClassVar, Protocol, runtime_checkable
+from datetime import UTC, datetime
+from typing import Any, ClassVar, Protocol, Self, runtime_checkable

 import sqlalchemy as sa
 import sqlalchemy.orm as sa_orm
 from google.protobuf import timestamp_pb2
 from sqlalchemy.ext import hybrid
-from typing_extensions import Self

 from corvic.orm._proto_columns import ProtoMessageDecorator
 from corvic.orm.func import utc_now
@@ -151,7 +150,7 @@ class Base(sa_orm.MappedAsDataclass, sa_orm.DeclarativeBase):
     def created_at(self) -> datetime | None:
         if not self._created_at:
             return None
-        return self._created_at.replace(tzinfo=timezone.utc)
+        return self._created_at.replace(tzinfo=UTC)

     @created_at.inplace.expression
     @classmethod
@@ -162,7 +161,7 @@ class Base(sa_orm.MappedAsDataclass, sa_orm.DeclarativeBase):
     def updated_at(self) -> datetime | None:
         if not self._updated_at:
             return None
-        return self._updated_at.replace(tzinfo=timezone.utc)
+        return self._updated_at.replace(tzinfo=UTC)

     @updated_at.inplace.expression
     @classmethod
corvic/orm/ids.py CHANGED
@@ -3,11 +3,10 @@
 from __future__ import annotations

 import abc
-from typing import Any, Generic, TypeVar
+from typing import Any, Generic, Self, TypeVar

 import sqlalchemy as sa
 import sqlalchemy.types as sa_types
-from typing_extensions import Self

 import corvic.context
 from corvic.orm.errors import InvalidORMIdentifierError
corvic/orm/mixins.py CHANGED
@@ -3,8 +3,8 @@
 from __future__ import annotations

 from collections.abc import Callable, Sequence
-from datetime import datetime, timezone
-from typing import Any, cast
+from datetime import UTC, datetime
+from typing import Any, LiteralString, cast

 import sqlalchemy as sa
 from google.protobuf import timestamp_pb2
@@ -12,7 +12,6 @@ from sqlalchemy import event, exc
 from sqlalchemy import orm as sa_orm
 from sqlalchemy.ext import hybrid
 from sqlalchemy.ext.hybrid import hybrid_property
-from typing_extensions import LiteralString

 import corvic.context
 from corvic.orm.base import EventBase, EventKey, OrgBase
@@ -137,7 +136,7 @@ class SoftDeleteMixin(sa_orm.MappedAsDataclass):
     def deleted_at(self) -> datetime | None:
         if not self._deleted_at:
             return None
-        return self._deleted_at.replace(tzinfo=timezone.utc)
+        return self._deleted_at.replace(tzinfo=UTC)

     def reset_delete(self):
         self._deleted_at = None
@@ -170,7 +169,7 @@ class SoftDeleteMixin(sa_orm.MappedAsDataclass):
         # set is_live to None instead of False so that orm objects can use it to
         # build uniqueness constraints that are only enforced on non-deleted objects
         self.is_live = None
-        self._deleted_at = datetime.now(tz=timezone.utc)
+        self._deleted_at = datetime.now(tz=UTC)

     @hybrid_property
     def is_deleted(self) -> bool:
@@ -343,7 +342,7 @@ class Session(sa_orm.Session):
 def _timestamp_or_utc_now(timestamp: datetime | None = None):
     if timestamp is not None:
         return timestamp
-    return datetime.now(tz=timezone.utc)
+    return datetime.now(tz=UTC)


 class EventLoggerMixin(sa_orm.MappedAsDataclass):
@@ -383,12 +382,11 @@ class EventLoggerMixin(sa_orm.MappedAsDataclass):
         # this can occur when an event is set on a new object
         if not self._event_src_id:
             obj_session.flush()
-        from datetime import timezone

         obj_session.add(
             EventBase(
                 event=event.event_type,
-                timestamp=event.timestamp.ToDatetime(tzinfo=timezone.utc),
+                timestamp=event.timestamp.ToDatetime(tzinfo=UTC),
                 regarding=event.regarding,
                 reason=event.reason,
                 event_key=str(self.event_key),
@@ -86,7 +86,7 @@ def datetime_fromisoformat(s: str) -> datetime.datetime:

     tzinfo = None
     if m.group("utc"):
-        tzinfo = datetime.timezone.utc
+        tzinfo = datetime.UTC

     tz_hour = m.group("tz_hour")
     if tz_hour:
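The timezone rewrites throughout this release swap datetime.timezone.utc for the datetime.UTC alias introduced in Python 3.11; both name the same tzinfo object, so the produced aware datetimes are unchanged. A small illustrative check:

    import datetime

    # datetime.UTC (Python 3.11+) is the same object as datetime.timezone.utc.
    assert datetime.UTC is datetime.timezone.utc
    now = datetime.datetime.now(tz=datetime.UTC)
    print(now.tzinfo)  # UTC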
corvic/result/__init__.py CHANGED
@@ -55,12 +55,11 @@ from typing import (
     Literal,
     NoReturn,
     ParamSpec,
+    Self,
     TypeVar,
     overload,
 )

-from typing_extensions import Self
-
 from corvic.well_known_types import JSONAble, JSONExpressable, to_json

 T_co = TypeVar("T_co", covariant=True)
@@ -0,0 +1,215 @@
+import math
+from typing import Final, cast
+
+import numpy as np
+import polars as pl
+import structlog
+
+REFERENCE_YEAR: Final = 1900
+"""Reference year for normalizing year in Datetime encoder"""
+
+MAX_NUMBER_OF_YEARS: Final = 200
+"""Maximum number of years for normalizing year in Datetime encoder"""
+
+_logger = structlog.get_logger()
+
+
+def encode_one_hot(to_encode: pl.Series) -> tuple[pl.Series, list[str]]:
+    encoded = to_encode.to_dummies()
+    return (
+        encoded.select(
+            pl.concat_list(pl.all()).alias("val").cast(pl.List(pl.Boolean))
+        ).to_series(),
+        encoded.columns,
+    )
+
+
+def encode_min_max_scale(
+    to_encode: pl.Series, range_min: float, range_max: float
+) -> pl.Series:
+    from sklearn.preprocessing import MinMaxScaler
+
+    encoder = MinMaxScaler(
+        feature_range=(
+            range_min,
+            range_max,
+        )
+    )
+    return pl.Series(
+        encoder.fit_transform(to_encode.to_numpy().reshape(-1, 1)).flatten()
+    )
+
+
+def encode_label_boolean(
+    to_encode: pl.Series, neg_label: int, pos_label: int
+) -> pl.Series:
+    from sklearn.preprocessing import LabelBinarizer
+
+    encoder = LabelBinarizer(
+        neg_label=neg_label,
+        pos_label=pos_label,
+    )
+    return pl.Series(encoder.fit_transform(to_encode.to_numpy().reshape(-1)))
+
+
+def encode_label(to_encode: pl.Series, *, normalize: bool) -> pl.Series:
+    from sklearn.preprocessing import LabelEncoder
+
+    encoder = LabelEncoder()
+    encoded = encoder.fit_transform(to_encode.to_numpy().reshape(-1)).flatten()
+    # `classes_` is only set after fit,
+    # Creating custom typestubs will not solve this typing issue.
+    if normalize and hasattr(encoder, "classes_"):
+        classes_ = cast(list[int], encoder.classes_)  # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType]
+        max_class: int = len(classes_) - 1
+        if max_class > 0:
+            encoded = encoded.astype(np.float64)
+            encoded /= max_class
+
+    return pl.Series(encoded)
+
+
+def encode_kbins(
+    to_encode: pl.Series, n_bins: int, method: str, strategy: str
+) -> pl.Series:
+    from sklearn.preprocessing import KBinsDiscretizer
+
+    encoder = KBinsDiscretizer(
+        n_bins=n_bins,
+        encode=method,
+        strategy=strategy,
+        dtype=np.float32,
+    )
+    return pl.Series(
+        encoder.fit_transform(to_encode.to_numpy().reshape(-1, 1)).flatten()
+    )
+
+
+def encode_boolean(to_encode: pl.Series, threshold: float) -> pl.Series:
+    from sklearn.preprocessing import Binarizer
+
+    encoder = Binarizer(
+        threshold=threshold,
+    )
+    return pl.Series(
+        encoder.fit_transform(to_encode.to_numpy().reshape(-1, 1)).flatten()
+    )
+
+
+def encode_max_abs_scale(to_encode: pl.Series) -> pl.Series:
+    from sklearn.preprocessing import MaxAbsScaler
+
+    encoder = MaxAbsScaler()
+    try:
+        encoded = encoder.fit_transform(
+            np.nan_to_num(to_encode.to_numpy()).reshape(-1, 1)
+        ).flatten()
+    except ValueError:
+        encoded = np.array([])
+
+    return pl.Series(encoded)
+
+
+def encode_standard_scale(
+    to_encode: pl.Series, *, with_mean: bool, with_std: bool
+) -> pl.Series:
+    from sklearn.preprocessing import StandardScaler
+
+    encoder = StandardScaler(
+        with_mean=with_mean,
+        with_std=with_std,
+    )
+    return pl.Series(
+        encoder.fit_transform(to_encode.to_numpy().reshape(-1, 1)).flatten()
+    )
+
+
+def encode_duration(to_encode: pl.Series) -> pl.Series:
+    if to_encode.dtype != pl.Duration:
+        raise ValueError("Invalid arguments, expected a duration series")
+    if to_encode.is_null().all():
+        return pl.zeros(len(to_encode), dtype=pl.Float32, eager=True)
+
+    return to_encode.dt.total_seconds().cast(pl.Float32).fill_null(0.0)
+
+
+def _get_cyclic_encoding(
+    to_encode: pl.Series,
+    period: int,
+) -> tuple[pl.Series, pl.Series]:
+    sine_series = (
+        (2 * math.pi * to_encode / period).sin().alias(f"{to_encode.name}_sine")
+    )
+    cosine_series = (
+        (2 * math.pi * to_encode / period).cos().alias(f"{to_encode.name}_cosine")
+    )
+    return sine_series, cosine_series
+
+
+def encode_datetime(to_encode: pl.Series) -> pl.Series:
+    match to_encode.dtype:
+        case pl.Date | pl.Time:
+            pass
+        case pl.Datetime:
+            to_encode = to_encode.dt.replace_time_zone("UTC")
+        case _:
+            raise ValueError(
+                "Invalid arguments column could not be endoded as datetime"
+            )
+
+    if to_encode.is_null().all():
+        zero_vector = pl.zeros(11, dtype=pl.Float32, eager=True)
+        return pl.Series([zero_vector] * len(to_encode), dtype=pl.List(pl.Float32))
+
+    n = len(to_encode)
+    year_norm = pl.zeros(n, dtype=pl.Float32, eager=True).alias("year")
+    month_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("month_sine")
+    month_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("month_cosine")
+    day_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("day_sine")
+    day_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("day_cosine")
+    hour_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("hour_sine")
+    hour_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("hour_cosine")
+    minute_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("minute_sine")
+    minute_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("minute_cosine")
+    second_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("second_sine")
+    second_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("second_cosine")
+
+    if to_encode.dtype in [pl.Date, pl.Datetime]:
+        try:
+            year = to_encode.dt.year().cast(pl.Float32).alias("year")
+            month = to_encode.dt.month().cast(pl.Float32).alias("month")
+            day = to_encode.dt.day().cast(pl.Float32).alias("day")
+
+            year_norm = (year - REFERENCE_YEAR) / MAX_NUMBER_OF_YEARS
+            month_sine, month_cosine = _get_cyclic_encoding(month, 12)
+            day_sine, day_cosine = _get_cyclic_encoding(day, 31)
+        except pl.exceptions.PanicException as e:
+            _logger.exception("Error extracting datetime", exc_info=e)
+
+    if to_encode.dtype in [pl.Time, pl.Datetime]:
+        try:
+            hour = to_encode.dt.hour().cast(pl.Float32).alias("hour")
+            minute = to_encode.dt.minute().cast(pl.Float32).alias("minute")
+            second = to_encode.dt.second().cast(pl.Float32).alias("second")
+
+            hour_sine, hour_cosine = _get_cyclic_encoding(hour, 24)
+            minute_sine, minute_cosine = _get_cyclic_encoding(minute, 60)
+            second_sine, second_cosine = _get_cyclic_encoding(second, 60)
+        except pl.exceptions.PanicException as e:
+            _logger.exception("Error extracting datetime", exc_info=e)
+
+    return pl.DataFrame(
+        [
+            year_norm.fill_null(0.0),
+            month_sine.fill_null(0.0),
+            month_cosine.fill_null(0.0),
+            day_sine.fill_null(0.0),
+            day_cosine.fill_null(0.0),
+            hour_sine.fill_null(0.0),
+            hour_cosine.fill_null(0.0),
+            minute_sine.fill_null(0.0),
+            minute_cosine.fill_null(0.0),
+            second_sine.fill_null(0.0),
+            second_cosine.fill_null(0.0),
+        ]
+    ).select(pl.concat_list(pl.all()).alias(to_encode.name))[to_encode.name]
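The _get_cyclic_encoding helper in the new module maps a periodic field onto the unit circle so that values at the ends of a period (for example December and January) remain close in feature space. A minimal sketch of the same sine/cosine transform, outside corvic:

    import math

    import polars as pl

    months = pl.Series("month", [1.0, 6.0, 12.0], dtype=pl.Float32)
    # Same transform as _get_cyclic_encoding with period=12.
    month_sine = (2 * math.pi * months / 12).sin().alias("month_sine")
    month_cosine = (2 * math.pi * months / 12).cos().alias("month_cosine")
    print(month_sine.to_list(), month_cosine.to_list())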
@@ -1,10 +1,11 @@
+import asyncio
 import dataclasses
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Any, Literal
+from concurrent.futures import ThreadPoolExecutor
+from typing import TYPE_CHECKING, Any, Literal, Protocol

 import numpy as np
 import polars as pl
-from typing_extensions import Protocol

 from corvic import orm
 from corvic.result import InternalError, InvalidArgumentError, Ok
@@ -43,6 +44,12 @@ class TextEmbedder(Protocol):
         self, context: EmbedTextContext
     ) -> Ok[EmbedTextResult] | InvalidArgumentError | InternalError: ...

+    async def aembed(
+        self,
+        context: EmbedTextContext,
+        worker_threads: ThreadPoolExecutor | None = None,
+    ) -> Ok[EmbedTextResult] | InvalidArgumentError | InternalError: ...
+

 @dataclasses.dataclass
 class EmbedImageContext:
@@ -69,6 +76,12 @@ class ImageEmbedder(Protocol):
         self, context: EmbedImageContext
     ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError: ...

+    async def aembed(
+        self,
+        context: EmbedImageContext,
+        worker_threads: ThreadPoolExecutor | None = None,
+    ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError: ...
+

 @dataclasses.dataclass
 class ClipModels:
@@ -142,3 +155,12 @@ class ClipText(TextEmbedder):
                ),
            )
        )
+
+    async def aembed(
+        self,
+        context: EmbedTextContext,
+        worker_threads: ThreadPoolExecutor | None = None,
+    ) -> Ok[EmbedTextResult] | InvalidArgumentError | InternalError:
+        return await asyncio.get_running_loop().run_in_executor(
+            worker_threads, self.embed, context
+        )
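Every new aembed method in this release follows the same pattern: the existing synchronous embed() is offloaded to an optional ThreadPoolExecutor via run_in_executor. A hedged, self-contained sketch of that pattern using a stand-in embedder rather than a corvic class:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor


    class SlowEmbedder:
        def embed(self, context: str) -> str:
            # Pretend this is heavy synchronous work.
            return f"embedded:{context}"

        async def aembed(
            self, context: str, worker_threads: ThreadPoolExecutor | None = None
        ) -> str:
            # worker_threads=None falls back to the event loop's default executor.
            return await asyncio.get_running_loop().run_in_executor(
                worker_threads, self.embed, context
            )


    async def main() -> None:
        with ThreadPoolExecutor(max_workers=2) as pool:
            print(await SlowEmbedder().aembed("hello", worker_threads=pool))


    asyncio.run(main())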
@@ -1,4 +1,6 @@
+import asyncio
 import dataclasses
+from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from typing import TYPE_CHECKING, Any

@@ -51,6 +53,15 @@ class RandomImageEmbedder(ImageEmbedder):
            )
        )

+    async def aembed(
+        self,
+        context: EmbedImageContext,
+        worker_threads: ThreadPoolExecutor | None = None,
+    ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError:
+        return await asyncio.get_running_loop().run_in_executor(
+            worker_threads, self.embed, context
+        )
+

 def image_from_bytes(
     image: bytes, mode: str = "RGB"
@@ -155,6 +166,15 @@ class Clip(ImageEmbedder):
            )
        )

+    async def aembed(
+        self,
+        context: EmbedImageContext,
+        worker_threads: ThreadPoolExecutor | None = None,
+    ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError:
+        return await asyncio.get_running_loop().run_in_executor(
+            worker_threads, self.embed, context
+        )
+

 class CombinedImageEmbedder(ImageEmbedder):
     def __init__(self):
@@ -168,6 +188,15 @@ class CombinedImageEmbedder(ImageEmbedder):
             return self._random_embedder.embed(context)
         return self._clip_embedder.embed(context)

+    async def aembed(
+        self,
+        context: EmbedImageContext,
+        worker_threads: ThreadPoolExecutor | None = None,
+    ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError:
+        return await asyncio.get_running_loop().run_in_executor(
+            worker_threads, self.embed, context
+        )
+

 class IdentityImageEmbedder(ImageEmbedder):
     """A deterministic image embedder.
@@ -251,6 +280,15 @@ class IdentityImageEmbedder(ImageEmbedder):
            )
        )

+    async def aembed(
+        self,
+        context: EmbedImageContext,
+        worker_threads: ThreadPoolExecutor | None = None,
+    ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError:
+        return await asyncio.get_running_loop().run_in_executor(
+            worker_threads, self.embed, context
+        )
+
     def preimage(
         self,
         embedding: list[float],
corvic/system/_planner.py CHANGED
@@ -1,3 +1,4 @@
+from concurrent.futures import ThreadPoolExecutor
 from typing import ClassVar

 from more_itertools import flatten
@@ -182,8 +183,10 @@ class ValidateFirstExecutor(OpGraphExecutor):
     def __init__(self, wrapped_executor: OpGraphExecutor):
         self._wrapped_executor = wrapped_executor

-    def execute(
-        self, context: ExecutionContext
+    async def execute(
+        self,
+        context: ExecutionContext,
+        worker_threads: ThreadPoolExecutor | None = None,
     ) -> (
         Ok[ExecutionResult]
         | InvalidArgumentError
@@ -201,4 +204,4 @@ class ValidateFirstExecutor(OpGraphExecutor):
             return InvalidArgumentError(
                 "sql_output_slice_args set but impossible on a given op_graph"
             )
-        return self._wrapped_executor.execute(context)
+        return await self._wrapped_executor.execute(context, worker_threads)
@@ -1,3 +1,6 @@
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+
 import numpy as np
 import polars as pl

@@ -40,6 +43,15 @@ class RandomTextEmbedder(TextEmbedder):
            )
        )

+    async def aembed(
+        self,
+        context: EmbedTextContext,
+        worker_threads: ThreadPoolExecutor | None = None,
+    ) -> Ok[EmbedTextResult] | InvalidArgumentError | InternalError:
+        return await asyncio.get_running_loop().run_in_executor(
+            worker_threads, self.embed, context
+        )
+

 class IdentityTextEmbedder(TextEmbedder):
     """A deterministic text embedder.
@@ -90,6 +102,15 @@ class IdentityTextEmbedder(TextEmbedder):
            )
        )

+    async def aembed(
+        self,
+        context: EmbedTextContext,
+        worker_threads: ThreadPoolExecutor | None = None,
+    ) -> Ok[EmbedTextResult] | InvalidArgumentError | InternalError:
+        return await asyncio.get_running_loop().run_in_executor(
+            worker_threads, self.embed, context
+        )
+
     def preimage(self, embedding: list[float], *, normalized: bool = False) -> str:
         """Reconstruct the text from a given embedding vector."""
         if normalized: