corvic-engine 0.3.0rc67__cp38-abi3-win_amd64.whl → 0.3.0rc69__cp38-abi3-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. corvic/context/__init__.py +0 -8
  2. corvic/engine/_native.pyd +0 -0
  3. corvic/model/_base_model.py +3 -4
  4. corvic/model/_completion_model.py +2 -4
  5. corvic/model/_feature_view.py +5 -6
  6. corvic/model/_pipeline.py +1 -2
  7. corvic/model/_resource.py +1 -2
  8. corvic/model/_source.py +1 -2
  9. corvic/model/_space.py +26 -2
  10. corvic/op_graph/row_filters/_jsonlogic.py +32 -1
  11. corvic/orm/base.py +4 -5
  12. corvic/orm/ids.py +1 -2
  13. corvic/orm/mixins.py +6 -8
  14. corvic/pa_scalar/_temporal.py +1 -1
  15. corvic/result/__init__.py +1 -2
  16. corvic/system/_column_encoding.py +215 -0
  17. corvic/system/_embedder.py +24 -2
  18. corvic/system/_image_embedder.py +38 -0
  19. corvic/system/_planner.py +6 -3
  20. corvic/system/_text_embedder.py +21 -0
  21. corvic/system/client.py +2 -1
  22. corvic/system/in_memory_executor.py +503 -507
  23. corvic/system/op_graph_executor.py +7 -3
  24. corvic/system/storage.py +1 -3
  25. corvic/table/table.py +5 -5
  26. {corvic_engine-0.3.0rc67.dist-info → corvic_engine-0.3.0rc69.dist-info}/METADATA +3 -4
  27. {corvic_engine-0.3.0rc67.dist-info → corvic_engine-0.3.0rc69.dist-info}/RECORD +35 -34
  28. corvic_generated/feature/v2/feature_view_pb2.py +21 -21
  29. corvic_generated/feature/v2/space_pb2.py +59 -51
  30. corvic_generated/feature/v2/space_pb2.pyi +12 -6
  31. corvic_generated/ingest/v2/resource_pb2.py +25 -25
  32. corvic_generated/orm/v1/agent_pb2.py +2 -2
  33. corvic_generated/orm/v1/agent_pb2.pyi +4 -0
  34. {corvic_engine-0.3.0rc67.dist-info → corvic_engine-0.3.0rc69.dist-info}/WHEEL +0 -0
  35. {corvic_engine-0.3.0rc67.dist-info → corvic_engine-0.3.0rc69.dist-info}/licenses/LICENSE +0 -0
corvic/context/__init__.py CHANGED
@@ -4,7 +4,6 @@ Affect things like logging and the names of metrics.
4
4
  """
5
5
 
6
6
  import contextvars
7
- import uuid
8
7
  from dataclasses import dataclass
9
8
 
10
9
  # These are sentinels used only in the Requester object below rather than actual org
@@ -21,17 +20,11 @@ class Requester:
21
20
 
22
21
 
23
22
  _SERVICE_NAME = contextvars.ContextVar("service_name", default="corvic")
24
- _TRACE_ID = contextvars.ContextVar("trace_id", default="")
25
23
  _REQUESTER = contextvars.ContextVar(
26
24
  "requester_identity", default=Requester(org_id=NOBODY_ORG_ID)
27
25
  )
28
26
 
29
27
 
30
- def get_trace_id() -> str:
31
- """Get current trace id."""
32
- return _TRACE_ID.get()
33
-
34
-
35
28
  def get_service_name() -> str:
36
29
  """Get current service name."""
37
30
  return _SERVICE_NAME.get()
@@ -45,7 +38,6 @@ def get_requester() -> Requester:
45
38
  def reset_context(*, service_name: str):
46
39
  """Reset contextvars for a new request."""
47
40
  _SERVICE_NAME.set(service_name)
48
- _TRACE_ID.set(str(uuid.uuid4()))
49
41
  _REQUESTER.set(Requester(org_id=NOBODY_ORG_ID))
50
42
 
51
43
 
corvic/engine/_native.pyd CHANGED
Binary file
corvic/model/_base_model.py CHANGED
@@ -5,12 +5,11 @@ import datetime
5
5
  import functools
6
6
  import uuid
7
7
  from collections.abc import Callable, Iterable, Iterator, Sequence
8
- from typing import Final, Generic
8
+ from typing import Final, Generic, Self
9
9
 
10
10
  import sqlalchemy as sa
11
11
  import sqlalchemy.orm as sa_orm
12
12
  import structlog
13
- from typing_extensions import Self
14
13
 
15
14
  from corvic import orm, system
16
15
  from corvic.model._proto_orm_convert import (
@@ -54,7 +53,7 @@ class HasProtoSelf(Generic[ProtoObj], abc.ABC):
54
53
  @property
55
54
  def created_at(self) -> datetime.datetime | None:
56
55
  if self.proto_self.created_at:
57
- return self.proto_self.created_at.ToDatetime(tzinfo=datetime.timezone.utc)
56
+ return self.proto_self.created_at.ToDatetime(tzinfo=datetime.UTC)
58
57
  return None
59
58
 
60
59
 
@@ -119,7 +118,7 @@ class BaseModel(Generic[ID, ProtoObj, OrmObj], UsesOrmID[ID, ProtoObj]):
119
118
  while True:
120
119
  try:
121
120
  yield from it
122
- except Exception: # noqa: PERF203
121
+ except Exception:
123
122
  _logger.exception(
124
123
  "omitting source from list: "
125
124
  + "failed to parse source from database entry",
corvic/model/_completion_model.py CHANGED
@@ -116,16 +116,14 @@ class CompletionModel(
116
116
  @property
117
117
  def last_validation_time(self) -> datetime.datetime | None:
118
118
  if self.proto_self.last_validation_time != UNIX_TIMESTAMP_START_DATETIME:
119
- return self.proto_self.last_validation_time.ToDatetime(
120
- tzinfo=datetime.timezone.utc
121
- )
119
+ return self.proto_self.last_validation_time.ToDatetime(tzinfo=datetime.UTC)
122
120
  return None
123
121
 
124
122
  @property
125
123
  def last_successful_validation(self) -> datetime.datetime | None:
126
124
  if self.proto_self.last_successful_validation != UNIX_TIMESTAMP_START_DATETIME:
127
125
  return self.proto_self.last_successful_validation.ToDatetime(
128
- tzinfo=datetime.timezone.utc
126
+ tzinfo=datetime.UTC
129
127
  )
130
128
  return None
131
129
 
corvic/model/_feature_view.py CHANGED
@@ -7,15 +7,14 @@ import dataclasses
7
7
  import datetime
8
8
  import functools
9
9
  import uuid
10
- from collections.abc import Iterable, Mapping, MutableMapping, Sequence
11
- from typing import Any, Final, TypeAlias
10
+ from collections.abc import AsyncIterable, Iterable, Mapping, MutableMapping, Sequence
11
+ from typing import Any, Final, Self, TypeAlias
12
12
 
13
13
  import pyarrow as pa
14
14
  from google.protobuf import struct_pb2
15
15
  from more_itertools import flatten
16
16
  from sqlalchemy import orm as sa_orm
17
17
  from sqlalchemy.orm.interfaces import LoaderOption
18
- from typing_extensions import Self
19
18
 
20
19
  from corvic import op_graph, orm, system
21
20
  from corvic.model._base_model import BelongsToRoomModel, UsesOrmID
@@ -324,7 +323,7 @@ class Relationship:
324
323
  how="inner",
325
324
  )
326
325
 
327
- def edge_list(self) -> Iterable[tuple[Any, Any]]:
326
+ async def edge_list(self) -> AsyncIterable[tuple[Any, Any]]:
328
327
  start_pk = self.start_fv_source.table.schema.get_primary_key()
329
328
  end_pk = self.end_fv_source.table.schema.get_primary_key()
330
329
 
@@ -340,8 +339,8 @@ class Relationship:
340
339
 
341
340
  result = self.joined_table().select(result_columns)
342
341
 
343
- for batch in result.to_polars(
344
- room_id=self.start_source.room_id
342
+ for batch in (
343
+ await result.to_polars(room_id=self.start_source.room_id)
345
344
  ).unwrap_or_raise():
346
345
  for row in batch.rows(named=True):
347
346
  yield (row[result_columns[0]], row[result_columns[1]])
corvic/model/_pipeline.py CHANGED
@@ -6,12 +6,11 @@ import datetime
6
6
  import functools
7
7
  import uuid
8
8
  from collections.abc import Iterable, Mapping, Sequence
9
- from typing import TypeAlias, cast
9
+ from typing import Self, TypeAlias, cast
10
10
 
11
11
  import polars as pl
12
12
  from sqlalchemy import orm as sa_orm
13
13
  from sqlalchemy.orm.interfaces import LoaderOption
14
- from typing_extensions import Self
15
14
 
16
15
  import corvic.table
17
16
  from corvic import op_graph, orm, system
corvic/model/_resource.py CHANGED
@@ -6,13 +6,12 @@ import copy
6
6
  import datetime
7
7
  import uuid
8
8
  from collections.abc import Iterable, Sequence
9
- from typing import TypeAlias
9
+ from typing import Self, TypeAlias
10
10
 
11
11
  import polars as pl
12
12
  import sqlalchemy as sa
13
13
  from sqlalchemy import orm as sa_orm
14
14
  from sqlalchemy.orm.interfaces import LoaderOption
15
- from typing_extensions import Self
16
15
 
17
16
  from corvic import orm, system
18
17
  from corvic.model._base_model import BelongsToRoomModel
corvic/model/_source.py CHANGED
@@ -6,13 +6,12 @@ import copy
6
6
  import datetime
7
7
  import functools
8
8
  from collections.abc import Iterable, Mapping, Sequence
9
- from typing import TypeAlias
9
+ from typing import Self, TypeAlias
10
10
 
11
11
  import polars as pl
12
12
  import sqlalchemy as sa
13
13
  import sqlalchemy.orm as sa_orm
14
14
  from sqlalchemy.orm.interfaces import LoaderOption
15
- from typing_extensions import Self
16
15
 
17
16
  from corvic import op_graph, orm, system
18
17
  from corvic.model._base_model import BelongsToRoomModel
corvic/model/_space.py CHANGED
@@ -3,15 +3,15 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import abc
6
+ import copy
6
7
  import datetime
7
8
  import uuid
8
9
  from collections.abc import Iterable, Mapping, Sequence
9
- from typing import Final, Literal, TypeAlias
10
+ from typing import Final, Literal, Self, TypeAlias
10
11
 
11
12
  import pyarrow as pa
12
13
  import sqlalchemy as sa
13
14
  from sqlalchemy import orm as sa_orm
14
- from typing_extensions import Self
15
15
 
16
16
  from corvic import op_graph, orm, system
17
17
  from corvic.model._base_model import BelongsToRoomModel
@@ -183,6 +183,30 @@ class Space(BelongsToRoomModel[SpaceID, models_pb2.Space, orm.Space]):
183
183
  auto_sync=auto_sync,
184
184
  )
185
185
 
186
+ def with_name(self, name: str):
187
+ proto_self = copy.copy(self.proto_self)
188
+
189
+ proto_self.name = name
190
+
191
+ return Ok(
192
+ self.__class__(
193
+ self.feature_view.client,
194
+ proto_self,
195
+ )
196
+ )
197
+
198
+ def with_description(self, description: str):
199
+ proto_self = copy.copy(self.proto_self)
200
+
201
+ proto_self.description = description
202
+
203
+ return Ok(
204
+ self.__class__(
205
+ self.feature_view.client,
206
+ proto_self,
207
+ )
208
+ )
209
+
186
210
  @classmethod
187
211
  def from_id(
188
212
  cls,
corvic/op_graph/row_filters/_jsonlogic.py CHANGED
@@ -16,7 +16,7 @@ from corvic.op_graph.row_filters._row_filters import (
16
16
  lt,
17
17
  ne,
18
18
  )
19
- from corvic.pa_scalar import from_value
19
+ from corvic.pa_scalar import from_value, to_value
20
20
  from corvic.result import Error, InvalidArgumentError, Ok
21
21
 
22
22
 
@@ -76,6 +76,35 @@ def _var_name(value: struct_pb2.Value) -> str:
76
76
  raise _Error("unexpected operation type")
77
77
 
78
78
 
79
+ def _coerce_literal(literal: struct_pb2.Value, dtype: pa.DataType) -> struct_pb2.Value:
80
+ # Attempt to coerce the literal to the type it needs to be compared against,
81
+ # if the types don't already align.
82
+ match literal.WhichOneof("kind"):
83
+ case "null_value":
84
+ types_match = pa.types.is_null(dtype)
85
+ case "bool_value":
86
+ types_match = pa.types.is_boolean(dtype)
87
+ case "list_value":
88
+ # TODO(aneesh): inner checks for nested types
89
+ types_match = pa.types.is_list(dtype)
90
+ case "number_value":
91
+ types_match = pa.types.is_integer(dtype) or pa.types.is_floating(dtype)
92
+ case "string_value":
93
+ types_match = pa.types.is_string(dtype)
94
+ case "struct_value":
95
+ # TODO(aneesh): inner checks for nested types
96
+ types_match = pa.types.is_struct(dtype)
97
+ case None:
98
+ raise _Error("Unknown literal type")
99
+ if not types_match:
100
+ match from_value(literal, dtype):
101
+ case Ok(coerced_literal):
102
+ literal = to_value(coerced_literal)
103
+ case err:
104
+ raise err
105
+ return literal
106
+
107
+
79
108
  def _simple_compare(
80
109
  op: Literal["==", "!=", "<=", ">=", "<", ">"],
81
110
  operands: Sequence[struct_pb2.Value],
@@ -92,6 +121,8 @@ def _simple_compare(
92
121
  if dtype is None:
93
122
  raise _Error("unknown literal type", column_name=column_name)
94
123
 
124
+ literal = _coerce_literal(literal, dtype)
125
+
95
126
  match op:
96
127
  case "==":
97
128
  return eq(column_name, literal, dtype)
corvic/orm/base.py CHANGED
@@ -3,14 +3,13 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import uuid
6
- from datetime import datetime, timezone
7
- from typing import Any, ClassVar, Protocol, runtime_checkable
6
+ from datetime import UTC, datetime
7
+ from typing import Any, ClassVar, Protocol, Self, runtime_checkable
8
8
 
9
9
  import sqlalchemy as sa
10
10
  import sqlalchemy.orm as sa_orm
11
11
  from google.protobuf import timestamp_pb2
12
12
  from sqlalchemy.ext import hybrid
13
- from typing_extensions import Self
14
13
 
15
14
  from corvic.orm._proto_columns import ProtoMessageDecorator
16
15
  from corvic.orm.func import utc_now
@@ -151,7 +150,7 @@ class Base(sa_orm.MappedAsDataclass, sa_orm.DeclarativeBase):
151
150
  def created_at(self) -> datetime | None:
152
151
  if not self._created_at:
153
152
  return None
154
- return self._created_at.replace(tzinfo=timezone.utc)
153
+ return self._created_at.replace(tzinfo=UTC)
155
154
 
156
155
  @created_at.inplace.expression
157
156
  @classmethod
@@ -162,7 +161,7 @@ class Base(sa_orm.MappedAsDataclass, sa_orm.DeclarativeBase):
162
161
  def updated_at(self) -> datetime | None:
163
162
  if not self._updated_at:
164
163
  return None
165
- return self._updated_at.replace(tzinfo=timezone.utc)
164
+ return self._updated_at.replace(tzinfo=UTC)
166
165
 
167
166
  @updated_at.inplace.expression
168
167
  @classmethod
corvic/orm/ids.py CHANGED
@@ -3,11 +3,10 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import abc
6
- from typing import Any, Generic, TypeVar
6
+ from typing import Any, Generic, Self, TypeVar
7
7
 
8
8
  import sqlalchemy as sa
9
9
  import sqlalchemy.types as sa_types
10
- from typing_extensions import Self
11
10
 
12
11
  import corvic.context
13
12
  from corvic.orm.errors import InvalidORMIdentifierError
corvic/orm/mixins.py CHANGED
@@ -3,8 +3,8 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  from collections.abc import Callable, Sequence
6
- from datetime import datetime, timezone
7
- from typing import Any, cast
6
+ from datetime import UTC, datetime
7
+ from typing import Any, LiteralString, cast
8
8
 
9
9
  import sqlalchemy as sa
10
10
  from google.protobuf import timestamp_pb2
@@ -12,7 +12,6 @@ from sqlalchemy import event, exc
12
12
  from sqlalchemy import orm as sa_orm
13
13
  from sqlalchemy.ext import hybrid
14
14
  from sqlalchemy.ext.hybrid import hybrid_property
15
- from typing_extensions import LiteralString
16
15
 
17
16
  import corvic.context
18
17
  from corvic.orm.base import EventBase, EventKey, OrgBase
@@ -137,7 +136,7 @@ class SoftDeleteMixin(sa_orm.MappedAsDataclass):
137
136
  def deleted_at(self) -> datetime | None:
138
137
  if not self._deleted_at:
139
138
  return None
140
- return self._deleted_at.replace(tzinfo=timezone.utc)
139
+ return self._deleted_at.replace(tzinfo=UTC)
141
140
 
142
141
  def reset_delete(self):
143
142
  self._deleted_at = None
@@ -170,7 +169,7 @@ class SoftDeleteMixin(sa_orm.MappedAsDataclass):
170
169
  # set is_live to None instead of False so that orm objects can use it to
171
170
  # build uniqueness constraints that are only enforced on non-deleted objects
172
171
  self.is_live = None
173
- self._deleted_at = datetime.now(tz=timezone.utc)
172
+ self._deleted_at = datetime.now(tz=UTC)
174
173
 
175
174
  @hybrid_property
176
175
  def is_deleted(self) -> bool:
@@ -343,7 +342,7 @@ class Session(sa_orm.Session):
343
342
  def _timestamp_or_utc_now(timestamp: datetime | None = None):
344
343
  if timestamp is not None:
345
344
  return timestamp
346
- return datetime.now(tz=timezone.utc)
345
+ return datetime.now(tz=UTC)
347
346
 
348
347
 
349
348
  class EventLoggerMixin(sa_orm.MappedAsDataclass):
@@ -383,12 +382,11 @@ class EventLoggerMixin(sa_orm.MappedAsDataclass):
383
382
  # this can occur when an event is set on a new object
384
383
  if not self._event_src_id:
385
384
  obj_session.flush()
386
- from datetime import timezone
387
385
 
388
386
  obj_session.add(
389
387
  EventBase(
390
388
  event=event.event_type,
391
- timestamp=event.timestamp.ToDatetime(tzinfo=timezone.utc),
389
+ timestamp=event.timestamp.ToDatetime(tzinfo=UTC),
392
390
  regarding=event.regarding,
393
391
  reason=event.reason,
394
392
  event_key=str(self.event_key),
corvic/pa_scalar/_temporal.py CHANGED
@@ -86,7 +86,7 @@ def datetime_fromisoformat(s: str) -> datetime.datetime:
86
86
 
87
87
  tzinfo = None
88
88
  if m.group("utc"):
89
- tzinfo = datetime.timezone.utc
89
+ tzinfo = datetime.UTC
90
90
 
91
91
  tz_hour = m.group("tz_hour")
92
92
  if tz_hour:
corvic/result/__init__.py CHANGED
@@ -55,12 +55,11 @@ from typing import (
55
55
  Literal,
56
56
  NoReturn,
57
57
  ParamSpec,
58
+ Self,
58
59
  TypeVar,
59
60
  overload,
60
61
  )
61
62
 
62
- from typing_extensions import Self
63
-
64
63
  from corvic.well_known_types import JSONAble, JSONExpressable, to_json
65
64
 
66
65
  T_co = TypeVar("T_co", covariant=True)
corvic/system/_column_encoding.py ADDED
@@ -0,0 +1,215 @@
1
+ import math
2
+ from typing import Final, cast
3
+
4
+ import numpy as np
5
+ import polars as pl
6
+ import structlog
7
+
8
+ REFERENCE_YEAR: Final = 1900
9
+ """Reference year for normalizing year in Datetime encoder"""
10
+
11
+ MAX_NUMBER_OF_YEARS: Final = 200
12
+ """Maximum number of years for normalizing year in Datetime encoder"""
13
+
14
+ _logger = structlog.get_logger()
15
+
16
+
17
+ def encode_one_hot(to_encode: pl.Series) -> tuple[pl.Series, list[str]]:
18
+ encoded = to_encode.to_dummies()
19
+ return (
20
+ encoded.select(
21
+ pl.concat_list(pl.all()).alias("val").cast(pl.List(pl.Boolean))
22
+ ).to_series(),
23
+ encoded.columns,
24
+ )
25
+
26
+
27
+ def encode_min_max_scale(
28
+ to_encode: pl.Series, range_min: float, range_max: float
29
+ ) -> pl.Series:
30
+ from sklearn.preprocessing import MinMaxScaler
31
+
32
+ encoder = MinMaxScaler(
33
+ feature_range=(
34
+ range_min,
35
+ range_max,
36
+ )
37
+ )
38
+ return pl.Series(
39
+ encoder.fit_transform(to_encode.to_numpy().reshape(-1, 1)).flatten()
40
+ )
41
+
42
+
43
+ def encode_label_boolean(
44
+ to_encode: pl.Series, neg_label: int, pos_label: int
45
+ ) -> pl.Series:
46
+ from sklearn.preprocessing import LabelBinarizer
47
+
48
+ encoder = LabelBinarizer(
49
+ neg_label=neg_label,
50
+ pos_label=pos_label,
51
+ )
52
+ return pl.Series(encoder.fit_transform(to_encode.to_numpy().reshape(-1)))
53
+
54
+
55
+ def encode_label(to_encode: pl.Series, *, normalize: bool) -> pl.Series:
56
+ from sklearn.preprocessing import LabelEncoder
57
+
58
+ encoder = LabelEncoder()
59
+ encoded = encoder.fit_transform(to_encode.to_numpy().reshape(-1)).flatten()
60
+ # `classes_` is only set after fit,
61
+ # Creating custom typestubs will not solve this typing issue.
62
+ if normalize and hasattr(encoder, "classes_"):
63
+ classes_ = cast(list[int], encoder.classes_) # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType]
64
+ max_class: int = len(classes_) - 1
65
+ if max_class > 0:
66
+ encoded = encoded.astype(np.float64)
67
+ encoded /= max_class
68
+
69
+ return pl.Series(encoded)
70
+
71
+
72
+ def encode_kbins(
73
+ to_encode: pl.Series, n_bins: int, method: str, strategy: str
74
+ ) -> pl.Series:
75
+ from sklearn.preprocessing import KBinsDiscretizer
76
+
77
+ encoder = KBinsDiscretizer(
78
+ n_bins=n_bins,
79
+ encode=method,
80
+ strategy=strategy,
81
+ dtype=np.float32,
82
+ )
83
+ return pl.Series(
84
+ encoder.fit_transform(to_encode.to_numpy().reshape(-1, 1)).flatten()
85
+ )
86
+
87
+
88
+ def encode_boolean(to_encode: pl.Series, threshold: float) -> pl.Series:
89
+ from sklearn.preprocessing import Binarizer
90
+
91
+ encoder = Binarizer(
92
+ threshold=threshold,
93
+ )
94
+ return pl.Series(
95
+ encoder.fit_transform(to_encode.to_numpy().reshape(-1, 1)).flatten()
96
+ )
97
+
98
+
99
+ def encode_max_abs_scale(to_encode: pl.Series) -> pl.Series:
100
+ from sklearn.preprocessing import MaxAbsScaler
101
+
102
+ encoder = MaxAbsScaler()
103
+ try:
104
+ encoded = encoder.fit_transform(
105
+ np.nan_to_num(to_encode.to_numpy()).reshape(-1, 1)
106
+ ).flatten()
107
+ except ValueError:
108
+ encoded = np.array([])
109
+
110
+ return pl.Series(encoded)
111
+
112
+
113
+ def encode_standard_scale(
114
+ to_encode: pl.Series, *, with_mean: bool, with_std: bool
115
+ ) -> pl.Series:
116
+ from sklearn.preprocessing import StandardScaler
117
+
118
+ encoder = StandardScaler(
119
+ with_mean=with_mean,
120
+ with_std=with_std,
121
+ )
122
+ return pl.Series(
123
+ encoder.fit_transform(to_encode.to_numpy().reshape(-1, 1)).flatten()
124
+ )
125
+
126
+
127
+ def encode_duration(to_encode: pl.Series) -> pl.Series:
128
+ if to_encode.dtype != pl.Duration:
129
+ raise ValueError("Invalid arguments, expected a duration series")
130
+ if to_encode.is_null().all():
131
+ return pl.zeros(len(to_encode), dtype=pl.Float32, eager=True)
132
+
133
+ return to_encode.dt.total_seconds().cast(pl.Float32).fill_null(0.0)
134
+
135
+
136
+ def _get_cyclic_encoding(
137
+ to_encode: pl.Series,
138
+ period: int,
139
+ ) -> tuple[pl.Series, pl.Series]:
140
+ sine_series = (
141
+ (2 * math.pi * to_encode / period).sin().alias(f"{to_encode.name}_sine")
142
+ )
143
+ cosine_series = (
144
+ (2 * math.pi * to_encode / period).cos().alias(f"{to_encode.name}_cosine")
145
+ )
146
+ return sine_series, cosine_series
147
+
148
+
149
+ def encode_datetime(to_encode: pl.Series) -> pl.Series:
150
+ match to_encode.dtype:
151
+ case pl.Date | pl.Time:
152
+ pass
153
+ case pl.Datetime:
154
+ to_encode = to_encode.dt.replace_time_zone("UTC")
155
+ case _:
156
+ raise ValueError(
157
+ "Invalid arguments column could not be endoded as datetime"
158
+ )
159
+
160
+ if to_encode.is_null().all():
161
+ zero_vector = pl.zeros(11, dtype=pl.Float32, eager=True)
162
+ return pl.Series([zero_vector] * len(to_encode), dtype=pl.List(pl.Float32))
163
+
164
+ n = len(to_encode)
165
+ year_norm = pl.zeros(n, dtype=pl.Float32, eager=True).alias("year")
166
+ month_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("month_sine")
167
+ month_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("month_cosine")
168
+ day_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("day_sine")
169
+ day_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("day_cosine")
170
+ hour_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("hour_sine")
171
+ hour_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("hour_cosine")
172
+ minute_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("minute_sine")
173
+ minute_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("minute_cosine")
174
+ second_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("second_sine")
175
+ second_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("second_cosine")
176
+
177
+ if to_encode.dtype in [pl.Date, pl.Datetime]:
178
+ try:
179
+ year = to_encode.dt.year().cast(pl.Float32).alias("year")
180
+ month = to_encode.dt.month().cast(pl.Float32).alias("month")
181
+ day = to_encode.dt.day().cast(pl.Float32).alias("day")
182
+
183
+ year_norm = (year - REFERENCE_YEAR) / MAX_NUMBER_OF_YEARS
184
+ month_sine, month_cosine = _get_cyclic_encoding(month, 12)
185
+ day_sine, day_cosine = _get_cyclic_encoding(day, 31)
186
+ except pl.exceptions.PanicException as e:
187
+ _logger.exception("Error extracting datetime", exc_info=e)
188
+
189
+ if to_encode.dtype in [pl.Time, pl.Datetime]:
190
+ try:
191
+ hour = to_encode.dt.hour().cast(pl.Float32).alias("hour")
192
+ minute = to_encode.dt.minute().cast(pl.Float32).alias("minute")
193
+ second = to_encode.dt.second().cast(pl.Float32).alias("second")
194
+
195
+ hour_sine, hour_cosine = _get_cyclic_encoding(hour, 24)
196
+ minute_sine, minute_cosine = _get_cyclic_encoding(minute, 60)
197
+ second_sine, second_cosine = _get_cyclic_encoding(second, 60)
198
+ except pl.exceptions.PanicException as e:
199
+ _logger.exception("Error extracting datetime", exc_info=e)
200
+
201
+ return pl.DataFrame(
202
+ [
203
+ year_norm.fill_null(0.0),
204
+ month_sine.fill_null(0.0),
205
+ month_cosine.fill_null(0.0),
206
+ day_sine.fill_null(0.0),
207
+ day_cosine.fill_null(0.0),
208
+ hour_sine.fill_null(0.0),
209
+ hour_cosine.fill_null(0.0),
210
+ minute_sine.fill_null(0.0),
211
+ minute_cosine.fill_null(0.0),
212
+ second_sine.fill_null(0.0),
213
+ second_cosine.fill_null(0.0),
214
+ ]
215
+ ).select(pl.concat_list(pl.all()).alias(to_encode.name))[to_encode.name]
corvic/system/_embedder.py CHANGED
@@ -1,10 +1,11 @@
1
+ import asyncio
1
2
  import dataclasses
2
3
  from collections.abc import Sequence
3
- from typing import TYPE_CHECKING, Any, Literal
4
+ from concurrent.futures import ThreadPoolExecutor
5
+ from typing import TYPE_CHECKING, Any, Literal, Protocol
4
6
 
5
7
  import numpy as np
6
8
  import polars as pl
7
- from typing_extensions import Protocol
8
9
 
9
10
  from corvic import orm
10
11
  from corvic.result import InternalError, InvalidArgumentError, Ok
@@ -43,6 +44,12 @@ class TextEmbedder(Protocol):
43
44
  self, context: EmbedTextContext
44
45
  ) -> Ok[EmbedTextResult] | InvalidArgumentError | InternalError: ...
45
46
 
47
+ async def aembed(
48
+ self,
49
+ context: EmbedTextContext,
50
+ worker_threads: ThreadPoolExecutor | None = None,
51
+ ) -> Ok[EmbedTextResult] | InvalidArgumentError | InternalError: ...
52
+
46
53
 
47
54
  @dataclasses.dataclass
48
55
  class EmbedImageContext:
@@ -69,6 +76,12 @@ class ImageEmbedder(Protocol):
69
76
  self, context: EmbedImageContext
70
77
  ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError: ...
71
78
 
79
+ async def aembed(
80
+ self,
81
+ context: EmbedImageContext,
82
+ worker_threads: ThreadPoolExecutor | None = None,
83
+ ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError: ...
84
+
72
85
 
73
86
  @dataclasses.dataclass
74
87
  class ClipModels:
@@ -142,3 +155,12 @@ class ClipText(TextEmbedder):
142
155
  ),
143
156
  )
144
157
  )
158
+
159
+ async def aembed(
160
+ self,
161
+ context: EmbedTextContext,
162
+ worker_threads: ThreadPoolExecutor | None = None,
163
+ ) -> Ok[EmbedTextResult] | InvalidArgumentError | InternalError:
164
+ return await asyncio.get_running_loop().run_in_executor(
165
+ worker_threads, self.embed, context
166
+ )