corvic-engine 0.3.0rc67-cp38-abi3-win_amd64.whl → 0.3.0rc69-cp38-abi3-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- corvic/context/__init__.py +0 -8
- corvic/engine/_native.pyd +0 -0
- corvic/model/_base_model.py +3 -4
- corvic/model/_completion_model.py +2 -4
- corvic/model/_feature_view.py +5 -6
- corvic/model/_pipeline.py +1 -2
- corvic/model/_resource.py +1 -2
- corvic/model/_source.py +1 -2
- corvic/model/_space.py +26 -2
- corvic/op_graph/row_filters/_jsonlogic.py +32 -1
- corvic/orm/base.py +4 -5
- corvic/orm/ids.py +1 -2
- corvic/orm/mixins.py +6 -8
- corvic/pa_scalar/_temporal.py +1 -1
- corvic/result/__init__.py +1 -2
- corvic/system/_column_encoding.py +215 -0
- corvic/system/_embedder.py +24 -2
- corvic/system/_image_embedder.py +38 -0
- corvic/system/_planner.py +6 -3
- corvic/system/_text_embedder.py +21 -0
- corvic/system/client.py +2 -1
- corvic/system/in_memory_executor.py +503 -507
- corvic/system/op_graph_executor.py +7 -3
- corvic/system/storage.py +1 -3
- corvic/table/table.py +5 -5
- {corvic_engine-0.3.0rc67.dist-info → corvic_engine-0.3.0rc69.dist-info}/METADATA +3 -4
- {corvic_engine-0.3.0rc67.dist-info → corvic_engine-0.3.0rc69.dist-info}/RECORD +35 -34
- corvic_generated/feature/v2/feature_view_pb2.py +21 -21
- corvic_generated/feature/v2/space_pb2.py +59 -51
- corvic_generated/feature/v2/space_pb2.pyi +12 -6
- corvic_generated/ingest/v2/resource_pb2.py +25 -25
- corvic_generated/orm/v1/agent_pb2.py +2 -2
- corvic_generated/orm/v1/agent_pb2.pyi +4 -0
- {corvic_engine-0.3.0rc67.dist-info → corvic_engine-0.3.0rc69.dist-info}/WHEEL +0 -0
- {corvic_engine-0.3.0rc67.dist-info → corvic_engine-0.3.0rc69.dist-info}/licenses/LICENSE +0 -0
corvic/context/__init__.py
CHANGED
@@ -4,7 +4,6 @@ Affect things like logging and the names of metrics.
 """
 
 import contextvars
-import uuid
 from dataclasses import dataclass
 
 # These are sentinels used only in the Requester object below rather than actual org
@@ -21,17 +20,11 @@ class Requester:
 
 
 _SERVICE_NAME = contextvars.ContextVar("service_name", default="corvic")
-_TRACE_ID = contextvars.ContextVar("trace_id", default="")
 _REQUESTER = contextvars.ContextVar(
     "requester_identity", default=Requester(org_id=NOBODY_ORG_ID)
 )
 
 
-def get_trace_id() -> str:
-    """Get current trace id."""
-    return _TRACE_ID.get()
-
-
 def get_service_name() -> str:
     """Get current service name."""
     return _SERVICE_NAME.get()
@@ -45,7 +38,6 @@ def get_requester() -> Requester:
 def reset_context(*, service_name: str):
     """Reset contextvars for a new request."""
     _SERVICE_NAME.set(service_name)
-    _TRACE_ID.set(str(uuid.uuid4()))
     _REQUESTER.set(Requester(org_id=NOBODY_ORG_ID))
 
 
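With _TRACE_ID gone, reset_context now only resets the service name and requester identity; it no longer mints a per-request UUID. A minimal sketch of the surviving surface, assuming corvic.context is importable as packaged:

    from corvic import context

    context.reset_context(service_name="ingest-worker")
    print(context.get_service_name())  # "ingest-worker"
    print(context.get_requester())     # Requester carrying the NOBODY_ORG_ID sentinel
    # get_trace_id() was removed in this release; trace ids must come from elsewhere.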
corvic/engine/_native.pyd
CHANGED
Binary file
corvic/model/_base_model.py
CHANGED
@@ -5,12 +5,11 @@ import datetime
 import functools
 import uuid
 from collections.abc import Callable, Iterable, Iterator, Sequence
-from typing import Final, Generic
+from typing import Final, Generic, Self
 
 import sqlalchemy as sa
 import sqlalchemy.orm as sa_orm
 import structlog
-from typing_extensions import Self
 
 from corvic import orm, system
 from corvic.model._proto_orm_convert import (
@@ -54,7 +53,7 @@ class HasProtoSelf(Generic[ProtoObj], abc.ABC):
     @property
     def created_at(self) -> datetime.datetime | None:
         if self.proto_self.created_at:
-            return self.proto_self.created_at.ToDatetime(tzinfo=datetime.timezone.utc)
+            return self.proto_self.created_at.ToDatetime(tzinfo=datetime.UTC)
         return None
 
 
@@ -119,7 +118,7 @@ class BaseModel(Generic[ID, ProtoObj, OrmObj], UsesOrmID[ID, ProtoObj]):
         while True:
             try:
                 yield from it
-            except Exception:
+            except Exception:
                 _logger.exception(
                     "omitting source from list: "
                     + "failed to parse source from database entry",
corvic/model/_completion_model.py
CHANGED
@@ -116,16 +116,14 @@ class CompletionModel(
     @property
     def last_validation_time(self) -> datetime.datetime | None:
         if self.proto_self.last_validation_time != UNIX_TIMESTAMP_START_DATETIME:
-            return self.proto_self.last_validation_time.ToDatetime(
-                tzinfo=datetime.timezone.utc
-            )
+            return self.proto_self.last_validation_time.ToDatetime(tzinfo=datetime.UTC)
         return None
 
     @property
     def last_successful_validation(self) -> datetime.datetime | None:
         if self.proto_self.last_successful_validation != UNIX_TIMESTAMP_START_DATETIME:
             return self.proto_self.last_successful_validation.ToDatetime(
-                tzinfo=datetime.timezone.utc
+                tzinfo=datetime.UTC
             )
         return None
 
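The recurring swap from typing_extensions.Self to typing.Self and from datetime.timezone.utc to datetime.UTC across these modules suggests the package now assumes a Python 3.11+ interpreter, where both names are in the standard library. A small equivalence sketch (the class below is illustrative, not from the package):

    import datetime
    from typing import Self  # stdlib in Python 3.11+; older versions need typing_extensions

    assert datetime.UTC is datetime.timezone.utc  # UTC is just the shorter alias

    class Touchable:
        updated_at: datetime.datetime | None = None

        def touch(self) -> Self:
            self.updated_at = datetime.datetime.now(tz=datetime.UTC)
            return self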
corvic/model/_feature_view.py
CHANGED
@@ -7,15 +7,14 @@ import dataclasses
 import datetime
 import functools
 import uuid
-from collections.abc import Iterable, Mapping, MutableMapping, Sequence
-from typing import Any, Final, TypeAlias
+from collections.abc import AsyncIterable, Iterable, Mapping, MutableMapping, Sequence
+from typing import Any, Final, Self, TypeAlias
 
 import pyarrow as pa
 from google.protobuf import struct_pb2
 from more_itertools import flatten
 from sqlalchemy import orm as sa_orm
 from sqlalchemy.orm.interfaces import LoaderOption
-from typing_extensions import Self
 
 from corvic import op_graph, orm, system
 from corvic.model._base_model import BelongsToRoomModel, UsesOrmID
@@ -324,7 +323,7 @@ class Relationship:
             how="inner",
         )
 
-    def edge_list(self) -> Iterable[tuple[Any, Any]]:
+    async def edge_list(self) -> AsyncIterable[tuple[Any, Any]]:
         start_pk = self.start_fv_source.table.schema.get_primary_key()
         end_pk = self.end_fv_source.table.schema.get_primary_key()
 
@@ -340,8 +339,8 @@ class Relationship:
 
         result = self.joined_table().select(result_columns)
 
-        for batch in result.to_polars(
-            room_id=self.start_source.room_id
+        for batch in (
+            await result.to_polars(room_id=self.start_source.room_id)
         ).unwrap_or_raise():
             for row in batch.rows(named=True):
                 yield (row[result_columns[0]], row[result_columns[1]])
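Because edge_list is now an async generator, callers iterate it with async for rather than a plain loop. An illustrative sketch, assuming a Relationship instance named relationship:

    import asyncio

    async def collect_edges(relationship) -> list[tuple]:
        # each yielded item is a (start primary key, end primary key) pair
        return [edge async for edge in relationship.edge_list()]

    # edges = asyncio.run(collect_edges(relationship))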
corvic/model/_pipeline.py
CHANGED
@@ -6,12 +6,11 @@ import datetime
 import functools
 import uuid
 from collections.abc import Iterable, Mapping, Sequence
-from typing import TypeAlias, cast
+from typing import Self, TypeAlias, cast
 
 import polars as pl
 from sqlalchemy import orm as sa_orm
 from sqlalchemy.orm.interfaces import LoaderOption
-from typing_extensions import Self
 
 import corvic.table
 from corvic import op_graph, orm, system
corvic/model/_resource.py
CHANGED
@@ -6,13 +6,12 @@ import copy
 import datetime
 import uuid
 from collections.abc import Iterable, Sequence
-from typing import TypeAlias
+from typing import Self, TypeAlias
 
 import polars as pl
 import sqlalchemy as sa
 from sqlalchemy import orm as sa_orm
 from sqlalchemy.orm.interfaces import LoaderOption
-from typing_extensions import Self
 
 from corvic import orm, system
 from corvic.model._base_model import BelongsToRoomModel
corvic/model/_source.py
CHANGED
@@ -6,13 +6,12 @@ import copy
 import datetime
 import functools
 from collections.abc import Iterable, Mapping, Sequence
-from typing import TypeAlias
+from typing import Self, TypeAlias
 
 import polars as pl
 import sqlalchemy as sa
 import sqlalchemy.orm as sa_orm
 from sqlalchemy.orm.interfaces import LoaderOption
-from typing_extensions import Self
 
 from corvic import op_graph, orm, system
 from corvic.model._base_model import BelongsToRoomModel
corvic/model/_space.py
CHANGED
@@ -3,15 +3,15 @@
 from __future__ import annotations
 
 import abc
+import copy
 import datetime
 import uuid
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Final, Literal, TypeAlias
+from typing import Final, Literal, Self, TypeAlias
 
 import pyarrow as pa
 import sqlalchemy as sa
 from sqlalchemy import orm as sa_orm
-from typing_extensions import Self
 
 from corvic import op_graph, orm, system
 from corvic.model._base_model import BelongsToRoomModel
@@ -183,6 +183,30 @@ class Space(BelongsToRoomModel[SpaceID, models_pb2.Space, orm.Space]):
             auto_sync=auto_sync,
         )
 
+    def with_name(self, name: str):
+        proto_self = copy.copy(self.proto_self)
+
+        proto_self.name = name
+
+        return Ok(
+            self.__class__(
+                self.feature_view.client,
+                proto_self,
+            )
+        )
+
+    def with_description(self, description: str):
+        proto_self = copy.copy(self.proto_self)
+
+        proto_self.description = description
+
+        return Ok(
+            self.__class__(
+                self.feature_view.client,
+                proto_self,
+            )
+        )
+
     @classmethod
     def from_id(
         cls,
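The new with_name and with_description helpers follow a copy-on-write pattern: copy the backing proto, set one field, and return a fresh instance wrapped in Ok, leaving the original object untouched. A usage sketch, assuming an existing space instance and the package's Ok result type:

    renamed = space.with_name("churn-analysis").unwrap_or_raise()
    final = renamed.with_description("Space for churn experiments").unwrap_or_raise()
    # `space` itself is unchanged; each call returns a new Space.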
corvic/op_graph/row_filters/_jsonlogic.py
CHANGED
@@ -16,7 +16,7 @@ from corvic.op_graph.row_filters._row_filters import (
     lt,
     ne,
 )
-from corvic.pa_scalar import from_value
+from corvic.pa_scalar import from_value, to_value
 from corvic.result import Error, InvalidArgumentError, Ok
 
 
@@ -76,6 +76,35 @@ def _var_name(value: struct_pb2.Value) -> str:
     raise _Error("unexpected operation type")
 
 
+def _coerce_literal(literal: struct_pb2.Value, dtype: pa.DataType) -> struct_pb2.Value:
+    # Attempt to coerce the literal to the type it needs to be compared against,
+    # if the types don't already align.
+    match literal.WhichOneof("kind"):
+        case "null_value":
+            types_match = pa.types.is_null(dtype)
+        case "bool_value":
+            types_match = pa.types.is_boolean(dtype)
+        case "list_value":
+            # TODO(aneesh): inner checks for nested types
+            types_match = pa.types.is_list(dtype)
+        case "number_value":
+            types_match = pa.types.is_integer(dtype) or pa.types.is_floating(dtype)
+        case "string_value":
+            types_match = pa.types.is_string(dtype)
+        case "struct_value":
+            # TODO(aneesh): inner checks for nested types
+            types_match = pa.types.is_struct(dtype)
+        case None:
+            raise _Error("Unknown literal type")
+    if not types_match:
+        match from_value(literal, dtype):
+            case Ok(coerced_literal):
+                literal = to_value(coerced_literal)
+            case err:
+                raise err
+    return literal
+
+
 def _simple_compare(
     op: Literal["==", "!=", "<=", ">=", "<", ">"],
     operands: Sequence[struct_pb2.Value],
@@ -92,6 +121,8 @@ def _simple_compare(
     if dtype is None:
         raise _Error("unknown literal type", column_name=column_name)
 
+    literal = _coerce_literal(literal, dtype)
+
     match op:
         case "==":
             return eq(column_name, literal, dtype)
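The practical effect is that a row-filter literal whose JSON kind disagrees with the column's Arrow type is re-parsed against that type before the comparison is built, instead of producing a mismatched filter. A small sketch of the kind/dtype check only (from_value and to_value are the package's own pa_scalar helpers and are not reproduced here):

    import pyarrow as pa
    from google.protobuf import struct_pb2

    literal = struct_pb2.Value(string_value="42")   # JSON-logic literals often arrive as strings
    column_dtype = pa.int64()                       # while the filtered column holds integers

    kind = literal.WhichOneof("kind")               # "string_value"
    types_match = pa.types.is_string(column_dtype)  # False, so the coercion path would run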
corvic/orm/base.py
CHANGED
@@ -3,14 +3,13 @@
 from __future__ import annotations
 
 import uuid
-from datetime import datetime, timezone
-from typing import Any, ClassVar, Protocol, runtime_checkable
+from datetime import UTC, datetime
+from typing import Any, ClassVar, Protocol, Self, runtime_checkable
 
 import sqlalchemy as sa
 import sqlalchemy.orm as sa_orm
 from google.protobuf import timestamp_pb2
 from sqlalchemy.ext import hybrid
-from typing_extensions import Self
 
 from corvic.orm._proto_columns import ProtoMessageDecorator
 from corvic.orm.func import utc_now
@@ -151,7 +150,7 @@ class Base(sa_orm.MappedAsDataclass, sa_orm.DeclarativeBase):
     def created_at(self) -> datetime | None:
         if not self._created_at:
             return None
-        return self._created_at.replace(tzinfo=timezone.utc)
+        return self._created_at.replace(tzinfo=UTC)
 
     @created_at.inplace.expression
     @classmethod
@@ -162,7 +161,7 @@ class Base(sa_orm.MappedAsDataclass, sa_orm.DeclarativeBase):
     def updated_at(self) -> datetime | None:
         if not self._updated_at:
             return None
-        return self._updated_at.replace(tzinfo=timezone.utc)
+        return self._updated_at.replace(tzinfo=UTC)
 
     @updated_at.inplace.expression
     @classmethod
corvic/orm/ids.py
CHANGED
@@ -3,11 +3,10 @@
 from __future__ import annotations
 
 import abc
-from typing import Any, Generic, TypeVar
+from typing import Any, Generic, Self, TypeVar
 
 import sqlalchemy as sa
 import sqlalchemy.types as sa_types
-from typing_extensions import Self
 
 import corvic.context
 from corvic.orm.errors import InvalidORMIdentifierError
corvic/orm/mixins.py
CHANGED
@@ -3,8 +3,8 @@
 from __future__ import annotations
 
 from collections.abc import Callable, Sequence
-from datetime import datetime, timezone
-from typing import Any, cast
+from datetime import UTC, datetime
+from typing import Any, LiteralString, cast
 
 import sqlalchemy as sa
 from google.protobuf import timestamp_pb2
@@ -12,7 +12,6 @@ from sqlalchemy import event, exc
 from sqlalchemy import orm as sa_orm
 from sqlalchemy.ext import hybrid
 from sqlalchemy.ext.hybrid import hybrid_property
-from typing_extensions import LiteralString
 
 import corvic.context
 from corvic.orm.base import EventBase, EventKey, OrgBase
@@ -137,7 +136,7 @@ class SoftDeleteMixin(sa_orm.MappedAsDataclass):
     def deleted_at(self) -> datetime | None:
         if not self._deleted_at:
             return None
-        return self._deleted_at.replace(tzinfo=timezone.utc)
+        return self._deleted_at.replace(tzinfo=UTC)
 
     def reset_delete(self):
         self._deleted_at = None
@@ -170,7 +169,7 @@ class SoftDeleteMixin(sa_orm.MappedAsDataclass):
         # set is_live to None instead of False so that orm objects can use it to
        # build uniqueness constraints that are only enforced on non-deleted objects
         self.is_live = None
-        self._deleted_at = datetime.now(tz=timezone.utc)
+        self._deleted_at = datetime.now(tz=UTC)
 
     @hybrid_property
     def is_deleted(self) -> bool:
@@ -343,7 +342,7 @@ class Session(sa_orm.Session):
 def _timestamp_or_utc_now(timestamp: datetime | None = None):
     if timestamp is not None:
         return timestamp
-    return datetime.now(tz=timezone.utc)
+    return datetime.now(tz=UTC)
 
 
 class EventLoggerMixin(sa_orm.MappedAsDataclass):
@@ -383,12 +382,11 @@ class EventLoggerMixin(sa_orm.MappedAsDataclass):
             # this can occur when an event is set on a new object
             if not self._event_src_id:
                 obj_session.flush()
-            from datetime import timezone
 
             obj_session.add(
                 EventBase(
                     event=event.event_type,
-                    timestamp=event.timestamp.ToDatetime(tzinfo=timezone.utc),
+                    timestamp=event.timestamp.ToDatetime(tzinfo=UTC),
                     regarding=event.regarding,
                     reason=event.reason,
                     event_key=str(self.event_key),
corvic/pa_scalar/_temporal.py
CHANGED
corvic/result/__init__.py
CHANGED
@@ -55,12 +55,11 @@ from typing import (
     Literal,
     NoReturn,
     ParamSpec,
+    Self,
     TypeVar,
     overload,
 )
 
-from typing_extensions import Self
-
 from corvic.well_known_types import JSONAble, JSONExpressable, to_json
 
 T_co = TypeVar("T_co", covariant=True)
corvic/system/_column_encoding.py
ADDED
@@ -0,0 +1,215 @@
+import math
+from typing import Final, cast
+
+import numpy as np
+import polars as pl
+import structlog
+
+REFERENCE_YEAR: Final = 1900
+"""Reference year for normalizing year in Datetime encoder"""
+
+MAX_NUMBER_OF_YEARS: Final = 200
+"""Maximum number of years for normalizing year in Datetime encoder"""
+
+_logger = structlog.get_logger()
+
+
+def encode_one_hot(to_encode: pl.Series) -> tuple[pl.Series, list[str]]:
+    encoded = to_encode.to_dummies()
+    return (
+        encoded.select(
+            pl.concat_list(pl.all()).alias("val").cast(pl.List(pl.Boolean))
+        ).to_series(),
+        encoded.columns,
+    )
+
+
+def encode_min_max_scale(
+    to_encode: pl.Series, range_min: float, range_max: float
+) -> pl.Series:
+    from sklearn.preprocessing import MinMaxScaler
+
+    encoder = MinMaxScaler(
+        feature_range=(
+            range_min,
+            range_max,
+        )
+    )
+    return pl.Series(
+        encoder.fit_transform(to_encode.to_numpy().reshape(-1, 1)).flatten()
+    )
+
+
+def encode_label_boolean(
+    to_encode: pl.Series, neg_label: int, pos_label: int
+) -> pl.Series:
+    from sklearn.preprocessing import LabelBinarizer
+
+    encoder = LabelBinarizer(
+        neg_label=neg_label,
+        pos_label=pos_label,
+    )
+    return pl.Series(encoder.fit_transform(to_encode.to_numpy().reshape(-1)))
+
+
+def encode_label(to_encode: pl.Series, *, normalize: bool) -> pl.Series:
+    from sklearn.preprocessing import LabelEncoder
+
+    encoder = LabelEncoder()
+    encoded = encoder.fit_transform(to_encode.to_numpy().reshape(-1)).flatten()
+    # `classes_` is only set after fit,
+    # Creating custom typestubs will not solve this typing issue.
+    if normalize and hasattr(encoder, "classes_"):
+        classes_ = cast(list[int], encoder.classes_)  # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType]
+        max_class: int = len(classes_) - 1
+        if max_class > 0:
+            encoded = encoded.astype(np.float64)
+            encoded /= max_class
+
+    return pl.Series(encoded)
+
+
+def encode_kbins(
+    to_encode: pl.Series, n_bins: int, method: str, strategy: str
+) -> pl.Series:
+    from sklearn.preprocessing import KBinsDiscretizer
+
+    encoder = KBinsDiscretizer(
+        n_bins=n_bins,
+        encode=method,
+        strategy=strategy,
+        dtype=np.float32,
+    )
+    return pl.Series(
+        encoder.fit_transform(to_encode.to_numpy().reshape(-1, 1)).flatten()
+    )
+
+
+def encode_boolean(to_encode: pl.Series, threshold: float) -> pl.Series:
+    from sklearn.preprocessing import Binarizer
+
+    encoder = Binarizer(
+        threshold=threshold,
+    )
+    return pl.Series(
+        encoder.fit_transform(to_encode.to_numpy().reshape(-1, 1)).flatten()
+    )
+
+
+def encode_max_abs_scale(to_encode: pl.Series) -> pl.Series:
+    from sklearn.preprocessing import MaxAbsScaler
+
+    encoder = MaxAbsScaler()
+    try:
+        encoded = encoder.fit_transform(
+            np.nan_to_num(to_encode.to_numpy()).reshape(-1, 1)
+        ).flatten()
+    except ValueError:
+        encoded = np.array([])
+
+    return pl.Series(encoded)
+
+
+def encode_standard_scale(
+    to_encode: pl.Series, *, with_mean: bool, with_std: bool
+) -> pl.Series:
+    from sklearn.preprocessing import StandardScaler
+
+    encoder = StandardScaler(
+        with_mean=with_mean,
+        with_std=with_std,
+    )
+    return pl.Series(
+        encoder.fit_transform(to_encode.to_numpy().reshape(-1, 1)).flatten()
+    )
+
+
+def encode_duration(to_encode: pl.Series) -> pl.Series:
+    if to_encode.dtype != pl.Duration:
+        raise ValueError("Invalid arguments, expected a duration series")
+    if to_encode.is_null().all():
+        return pl.zeros(len(to_encode), dtype=pl.Float32, eager=True)
+
+    return to_encode.dt.total_seconds().cast(pl.Float32).fill_null(0.0)
+
+
+def _get_cyclic_encoding(
+    to_encode: pl.Series,
+    period: int,
+) -> tuple[pl.Series, pl.Series]:
+    sine_series = (
+        (2 * math.pi * to_encode / period).sin().alias(f"{to_encode.name}_sine")
+    )
+    cosine_series = (
+        (2 * math.pi * to_encode / period).cos().alias(f"{to_encode.name}_cosine")
+    )
+    return sine_series, cosine_series
+
+
+def encode_datetime(to_encode: pl.Series) -> pl.Series:
+    match to_encode.dtype:
+        case pl.Date | pl.Time:
+            pass
+        case pl.Datetime:
+            to_encode = to_encode.dt.replace_time_zone("UTC")
+        case _:
+            raise ValueError(
+                "Invalid arguments column could not be endoded as datetime"
+            )
+
+    if to_encode.is_null().all():
+        zero_vector = pl.zeros(11, dtype=pl.Float32, eager=True)
+        return pl.Series([zero_vector] * len(to_encode), dtype=pl.List(pl.Float32))
+
+    n = len(to_encode)
+    year_norm = pl.zeros(n, dtype=pl.Float32, eager=True).alias("year")
+    month_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("month_sine")
+    month_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("month_cosine")
+    day_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("day_sine")
+    day_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("day_cosine")
+    hour_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("hour_sine")
+    hour_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("hour_cosine")
+    minute_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("minute_sine")
+    minute_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("minute_cosine")
+    second_sine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("second_sine")
+    second_cosine = pl.zeros(n, dtype=pl.Float32, eager=True).alias("second_cosine")
+
+    if to_encode.dtype in [pl.Date, pl.Datetime]:
+        try:
+            year = to_encode.dt.year().cast(pl.Float32).alias("year")
+            month = to_encode.dt.month().cast(pl.Float32).alias("month")
+            day = to_encode.dt.day().cast(pl.Float32).alias("day")
+
+            year_norm = (year - REFERENCE_YEAR) / MAX_NUMBER_OF_YEARS
+            month_sine, month_cosine = _get_cyclic_encoding(month, 12)
+            day_sine, day_cosine = _get_cyclic_encoding(day, 31)
+        except pl.exceptions.PanicException as e:
+            _logger.exception("Error extracting datetime", exc_info=e)
+
+    if to_encode.dtype in [pl.Time, pl.Datetime]:
+        try:
+            hour = to_encode.dt.hour().cast(pl.Float32).alias("hour")
+            minute = to_encode.dt.minute().cast(pl.Float32).alias("minute")
+            second = to_encode.dt.second().cast(pl.Float32).alias("second")
+
+            hour_sine, hour_cosine = _get_cyclic_encoding(hour, 24)
+            minute_sine, minute_cosine = _get_cyclic_encoding(minute, 60)
+            second_sine, second_cosine = _get_cyclic_encoding(second, 60)
+        except pl.exceptions.PanicException as e:
+            _logger.exception("Error extracting datetime", exc_info=e)
+
+    return pl.DataFrame(
+        [
+            year_norm.fill_null(0.0),
+            month_sine.fill_null(0.0),
+            month_cosine.fill_null(0.0),
+            day_sine.fill_null(0.0),
+            day_cosine.fill_null(0.0),
+            hour_sine.fill_null(0.0),
+            hour_cosine.fill_null(0.0),
+            minute_sine.fill_null(0.0),
+            minute_cosine.fill_null(0.0),
+            second_sine.fill_null(0.0),
+            second_cosine.fill_null(0.0),
+        ]
+    ).select(pl.concat_list(pl.all()).alias(to_encode.name))[to_encode.name]
corvic/system/_embedder.py
CHANGED
@@ -1,10 +1,11 @@
+import asyncio
 import dataclasses
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Any, Literal
+from concurrent.futures import ThreadPoolExecutor
+from typing import TYPE_CHECKING, Any, Literal, Protocol
 
 import numpy as np
 import polars as pl
-from typing_extensions import Protocol
 
 from corvic import orm
 from corvic.result import InternalError, InvalidArgumentError, Ok
@@ -43,6 +44,12 @@ class TextEmbedder(Protocol):
         self, context: EmbedTextContext
     ) -> Ok[EmbedTextResult] | InvalidArgumentError | InternalError: ...
 
+    async def aembed(
+        self,
+        context: EmbedTextContext,
+        worker_threads: ThreadPoolExecutor | None = None,
+    ) -> Ok[EmbedTextResult] | InvalidArgumentError | InternalError: ...
+
 
 @dataclasses.dataclass
 class EmbedImageContext:
@@ -69,6 +76,12 @@ class ImageEmbedder(Protocol):
         self, context: EmbedImageContext
     ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError: ...
 
+    async def aembed(
+        self,
+        context: EmbedImageContext,
+        worker_threads: ThreadPoolExecutor | None = None,
+    ) -> Ok[EmbedImageResult] | InvalidArgumentError | InternalError: ...
+
 
 @dataclasses.dataclass
 class ClipModels:
@@ -142,3 +155,12 @@ class ClipText(TextEmbedder):
             ),
         )
     )
+
+    async def aembed(
+        self,
+        context: EmbedTextContext,
+        worker_threads: ThreadPoolExecutor | None = None,
+    ) -> Ok[EmbedTextResult] | InvalidArgumentError | InternalError:
+        return await asyncio.get_running_loop().run_in_executor(
+            worker_threads, self.embed, context
+        )