haystack-ml-stack 0.2.5__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/PKG-INFO +10 -8
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/pyproject.toml +9 -4
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/__init__.py +14 -0
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/_serializers.py +368 -0
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/app.py +103 -38
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/dynamo.py +326 -0
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/exceptions.py +5 -0
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/generated/__init__.py +0 -0
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/generated/v1/__init__.py +0 -0
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/settings.py +2 -1
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/utils.py +675 -0
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/PKG-INFO +10 -8
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/SOURCES.txt +5 -0
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/requires.txt +3 -0
- haystack_ml_stack-0.3.0/tests/test_serializers.py +152 -0
- haystack_ml_stack-0.3.0/tests/test_utils.py +510 -0
- haystack_ml_stack-0.2.5/src/haystack_ml_stack/__init__.py +0 -4
- haystack_ml_stack-0.2.5/src/haystack_ml_stack/dynamo.py +0 -194
- haystack_ml_stack-0.2.5/src/haystack_ml_stack/utils.py +0 -312
- haystack_ml_stack-0.2.5/tests/test_utils.py +0 -76
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/README.md +0 -0
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/setup.cfg +0 -0
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/cache.py +0 -0
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/model_store.py +0 -0
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/dependency_links.txt +0 -0
- {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/top_level.txt +0 -0
|
@@ -1,18 +1,20 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: haystack-ml-stack
|
|
3
|
-
Version: 0.2.5
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: Functions related to Haystack ML
|
|
5
5
|
Author-email: Oscar Vega <oscar@haystack.tv>
|
|
6
6
|
License: MIT
|
|
7
7
|
Requires-Python: >=3.11
|
|
8
8
|
Description-Content-Type: text/markdown
|
|
9
|
-
Requires-Dist:
|
|
10
|
-
|
|
11
|
-
Requires-Dist:
|
|
12
|
-
Requires-Dist:
|
|
13
|
-
Requires-Dist:
|
|
14
|
-
Requires-Dist:
|
|
15
|
-
Requires-Dist:
|
|
9
|
+
Requires-Dist: protobuf==6.33.2
|
|
10
|
+
Provides-Extra: server
|
|
11
|
+
Requires-Dist: pydantic==2.5.0; extra == "server"
|
|
12
|
+
Requires-Dist: cachetools==5.5.2; extra == "server"
|
|
13
|
+
Requires-Dist: cloudpickle==2.2.1; extra == "server"
|
|
14
|
+
Requires-Dist: aioboto3==12.0.0; extra == "server"
|
|
15
|
+
Requires-Dist: fastapi==0.104.1; extra == "server"
|
|
16
|
+
Requires-Dist: pydantic-settings==2.2; extra == "server"
|
|
17
|
+
Requires-Dist: newrelic==11.1.0; extra == "server"
|
|
16
18
|
|
|
17
19
|
# Haystack ML Stack
|
|
18
20
|
|
|
@@ -5,18 +5,23 @@ build-backend = "setuptools.build_meta"
|
|
|
5
5
|
|
|
6
6
|
[project]
|
|
7
7
|
name = "haystack-ml-stack"
|
|
8
|
-
version = "0.2.5"
|
|
8
|
+
version = "0.3.0"
|
|
9
9
|
description = "Functions related to Haystack ML"
|
|
10
10
|
readme = "README.md"
|
|
11
11
|
authors = [{ name = "Oscar Vega", email = "oscar@haystack.tv" }]
|
|
12
12
|
requires-python = ">=3.11"
|
|
13
13
|
dependencies = [
|
|
14
|
+
"protobuf==6.33.2",
|
|
15
|
+
]
|
|
16
|
+
license = { text = "MIT" }
|
|
17
|
+
|
|
18
|
+
[project.optional-dependencies]
|
|
19
|
+
server = [
|
|
14
20
|
"pydantic==2.5.0",
|
|
15
21
|
"cachetools==5.5.2",
|
|
16
22
|
"cloudpickle==2.2.1",
|
|
17
23
|
"aioboto3==12.0.0",
|
|
18
24
|
"fastapi==0.104.1",
|
|
19
25
|
"pydantic-settings==2.2",
|
|
20
|
-
"newrelic==11.1.0"
|
|
21
|
-
]
|
|
22
|
-
license = { text = "MIT" }
|
|
26
|
+
"newrelic==11.1.0",
|
|
27
|
+
]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Public API of the package. ``create_app`` is only exported when ``.app``
# can be imported; an ImportError while importing it is ignored so the
# serializer registry stays usable on its own.
try:
    from .app import create_app
except ImportError:
    __all__ = []
else:
    __all__ = ["create_app"]

from ._serializers import SerializerRegistry, FeatureRegistryId

__all__ = [*__all__, "SerializerRegistry", "FeatureRegistryId"]

__version__ = "0.3.0"
|
|
@@ -0,0 +1,368 @@
|
|
|
1
|
+
from .generated.v1 import features_pb2 as features_pb2_v1
|
|
2
|
+
from google.protobuf.message import Message
|
|
3
|
+
from google.protobuf.json_format import ParseDict as ProtoParseDict
|
|
4
|
+
import typing as _t
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
|
|
7
|
+
MessageType = _t.TypeVar("MessageType", bound=Message)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Serializer(ABC):
    """Abstract interface for converting feature values to and from bytes."""

    @abstractmethod
    def serialize(self, value) -> bytes:
        """Encode ``value`` into its binary representation."""

    @abstractmethod
    def deserialize(self, value: bytes) -> _t.Any:
        """Decode the payload ``value`` into a feature object."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SimpleSerializer(Serializer, _t.Generic[MessageType]):
    """Generic serializer built on protobuf's ``ParseDict``.

    ``ParseDict`` keeps the code short but is slow, so use this class directly
    for proofs of concept only. Production serializers should override
    ``serialize`` and set message fields by hand; early tests showed roughly a
    10x speedup from manual serialization.

    Deserialization needs no such treatment: the binary payload is parsed
    straight into the message, without building an intermediate dictionary.
    """

    def __init__(self, msg_class: type[MessageType]):
        """Remember the protobuf message class used for both directions."""
        self.msg_class = msg_class

    def serialize(self, value) -> bytes:
        """Populate a fresh message from the dict ``value`` and encode it."""
        populated = ProtoParseDict(value, message=self.msg_class())
        return populated.SerializeToString()

    def deserialize(self, value) -> MessageType:
        """Parse the binary payload ``value`` into a new message instance."""
        decoded: Message = self.msg_class()
        decoded.ParseFromString(value)
        return decoded
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class StreamPWatchedSerializerV1(SimpleSerializer):
    """Hand-written serializer for ``StreamPWatched`` (version 1) messages."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.StreamPWatched)

    def serialize(self, value):
        """Build the message from the dict ``value`` and return its bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.StreamPWatched:
        """Copy ``value["version"]`` and ``value["data"]`` into a new message.

        Each key of ``value["data"]`` names an attribute of ``message.data``;
        its ``attempts`` and ``watched`` entries are stored as integers.
        Fails the version assertion unless ``value["version"]`` is 1.
        """
        message = self.msg_class()
        assert value["version"] == 1, "Wrong version given!"
        message.version = value["version"]
        for context_name, tallies in value["data"].items():
            context_counts: features_pb2_v1.EntryContextCounts = getattr(
                message.data, context_name
            )
            context_counts.attempts = int(tallies["attempts"])
            context_counts.watched = int(tallies["watched"])
        return message


# Same class exposed under the user-level name.
UserPWatchedSerializerV1 = StreamPWatchedSerializerV1
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class StreamPWatchedSerializerV0(Serializer):
    """Read-only adapter that turns v0 pwatched dicts into v1 messages."""

    # Shared v1 serializer that performs the actual message construction.
    serializer_v1 = StreamPWatchedSerializerV1()

    def serialize(self, value) -> bytes:
        """Always raises: v0 data must never be written."""
        raise NotImplementedError(
            "This serializer should never be used for serialization!"
        )

    def deserialize(self, value) -> features_pb2_v1.StreamPWatched:
        """Convert the v0 mapping ``value`` into a ``StreamPWatched`` message.

        Spaces in the entry-context keys are replaced with underscores before
        the data is handed to the v1 builder as a version-1 payload.
        """
        normalized = {}
        for entry_context, counts in value.items():
            normalized[entry_context.replace(" ", "_")] = counts
        payload = {"data": normalized, "version": 1}
        return self.serializer_v1.build_msg(payload)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class StreamPSelectSerializerV1(SimpleSerializer):
    """Hand-written serializer for ``StreamPSelect`` (version 1) messages."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.StreamPSelect)

    def serialize(self, value) -> bytes:
        """Build the message from the dict ``value`` and return its bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.StreamPSelect:
        """Copy the nested select counters of ``value["data"]`` into a message.

        ``value["data"]`` maps a browsed-debias key to a mapping of position
        name to counters; both levels are resolved as attributes on the
        message. Fails the version assertion unless ``value["version"]`` is 1.
        """
        message: features_pb2_v1.StreamPSelect = self.msg_class()
        assert value["version"] == 1, "Wrong version given!"
        message.version = 1
        for debias_key, per_position in value["data"].items():
            debias_msg: features_pb2_v1.PositionPSelect = getattr(
                message.data, debias_key
            )
            for position_name, tallies in per_position.items():
                counts_msg = getattr(debias_msg, position_name)
                counts_msg.total_selects = int(tallies["total_selects"])
                counts_msg.total_browsed = int(tallies["total_browsed"])
                counts_msg.total_selects_and_watched = int(
                    tallies["total_selects_and_watched"]
                )
        return message


# Same class exposed under the user-level name.
UserPSelectSerializerV1 = StreamPSelectSerializerV1
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class StreamPSelectSerializerV0(Serializer):
    """Read-only adapter that turns v0 pselect dicts into v1 messages."""

    # Shared v1 serializer that performs the actual message construction.
    serializer_v1 = StreamPSelectSerializerV1()

    # v0 position keys and the v1 field names they map to.
    _POSITION_KEY_MAPPING = {
        "0": "first_pos",
        "1": "second_pos",
        "2": "third_pos",
        "3+": "rest_pos",
    }

    def serialize(self, value) -> bytes:
        """Always raises: v0 data must never be written."""
        raise NotImplementedError(
            "This serializer should never be used for serialization!"
        )

    @classmethod
    def _rename_positions(cls, positions) -> dict:
        """Return a copy of ``positions`` with v0 position keys renamed."""
        # Work on a shallow copy so the caller's mapping is left untouched.
        renamed = dict(positions)
        for old_key, new_key in cls._POSITION_KEY_MAPPING.items():
            if old_key in renamed:
                renamed[new_key] = renamed.pop(old_key)
        return renamed

    def deserialize(self, value):
        """Convert the v0 mapping ``value`` into a ``StreamPSelect`` message.

        ``value`` must contain the keys ``"4_browsed"`` and ``"all_browsed"``,
        each mapping v0 position keys to select counters. ``"4_browsed"`` is
        emitted as ``"up_to_4_browsed"``. The input is not modified.

        Raises:
            KeyError: if either required top-level key is missing.
        """
        out = {
            "data": {
                "up_to_4_browsed": self._rename_positions(value["4_browsed"]),
                "all_browsed": self._rename_positions(value["all_browsed"]),
            },
            "version": 1,
        }
        return self.serializer_v1.build_msg(value=out)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class StreamSimilaritySerializerV1(SimpleSerializer):
    """Hand-written serializer for ``StreamSimilarityScores`` (version 1)."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.StreamSimilarityScores)

    def serialize(self, value):
        """Build the message from the dict ``value`` and return its bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.StreamSimilarityScores:
        """Copy every ``value["data"]`` item into the message's ``data`` map.

        Fails the version assertion unless ``value["version"]`` is 1.
        """
        message = self.msg_class()
        assert value["version"] == 1, "Wrong version given!"
        message.version = value["version"]
        scores = value["data"]
        for name in scores:
            message.data[name] = scores[name]
        return message
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class StreamSimilaritySerializerV0(Serializer):
    """Read-only adapter that turns v0 similarity dicts into v1 messages."""

    # Shared v1 serializer that performs the actual message construction.
    serializer_v1 = StreamSimilaritySerializerV1()

    def serialize(self, value):
        """Always raises: v0 data must never be written."""
        raise NotImplementedError(
            "This serializer should never be used for serialization!"
        )

    def deserialize(self, value):
        """Wrap the v0 mapping ``value`` as version-1 data and build the message."""
        payload = {"data": value, "version": 1}
        return self.serializer_v1.build_msg(payload)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class UserPersonalizingPWatchedSerializerV1(SimpleSerializer):
    """Hand-written serializer for ``UserPersonalizingPWatched`` (version 1)."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.UserPersonalizingPWatched)

    def serialize(self, value: dict) -> bytes:
        """Build the message from the dict ``value`` and return its bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.UserPersonalizingPWatched:
        """Copy the per-key watch counters of ``value["data"]`` into a message.

        ``value["data"]`` maps a personalizing key (used to index the
        message's ``data`` map) to a mapping of entry-context name to its
        ``attempts`` / ``watched`` counters. Fails the version assertion
        unless ``value["version"]`` is 1.
        """
        root_msg = features_pb2_v1.UserPersonalizingPWatched()
        assert value["version"] == 1, "Wrong version given!"
        root_msg.version = value["version"]
        for key, per_context in value["data"].items():
            key_msg = root_msg.data[key]
            for context_name, tallies in per_context.items():
                context_counts = getattr(key_msg, context_name)
                context_counts.attempts = int(tallies["attempts"])
                context_counts.watched = int(tallies["watched"])
        return root_msg
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class UserPersonalizingPSelectSerializerV1(SimpleSerializer):
    """Hand-written serializer for ``UserPersonalizingPSelect`` (version 1)."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.UserPersonalizingPSelect)

    def serialize(self, value):
        """Build the message from the dict ``value`` and return its bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.UserPersonalizingPSelect:
        """Copy the nested select counters of ``value["data"]`` into a message.

        ``value["data"]`` maps a personalizing key (used to index the
        message's ``data`` map) to a mapping of browsed-debias key to a
        mapping of position name to counters. The two inner levels are
        resolved as attributes on the message.

        Fails the version assertion unless ``value["version"]`` is 1.
        """
        root_msg = features_pb2_v1.UserPersonalizingPSelect()
        # A v1 serializer must only ever encode version-1 payloads.
        assert value["version"] == 1, "Wrong version given!"
        root_msg.version = value["version"]
        for personalizing_key, browsed_debiased_pselects in value["data"].items():
            personalizing_msg = root_msg.data[personalizing_key]
            for (
                browsed_debias_key,
                position_pselects,
            ) in browsed_debiased_pselects.items():
                position_pselects_msg = getattr(personalizing_msg, browsed_debias_key)
                for position, select_counts in position_pselects.items():
                    select_counts_msg = getattr(position_pselects_msg, position)
                    select_counts_msg.total_selects = int(
                        select_counts["total_selects"]
                    )
                    select_counts_msg.total_browsed = int(
                        select_counts["total_browsed"]
                    )
                    select_counts_msg.total_selects_and_watched = int(
                        select_counts["total_selects_and_watched"]
                    )
        return root_msg
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class PassThroughSerializer(Serializer):
    """Serializer that hands values through unchanged in both directions."""

    def serialize(self, value):
        """Return ``value`` as given."""
        return value

    def deserialize(self, value):
        """Return ``value`` as given."""
        return value
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# Module-level serializer instances shared by the registry below.

# Stream serializers.
stream_pwatched_serializer_v0 = StreamPWatchedSerializerV0()
stream_pwatched_serializer_v1 = StreamPWatchedSerializerV1()
stream_pselect_serializer_v0 = StreamPSelectSerializerV0()
stream_pselect_serializer_v1 = StreamPSelectSerializerV1()
stream_similarity_scores_serializer_v0 = StreamSimilaritySerializerV0()
stream_similarity_scores_serializer_v1 = StreamSimilaritySerializerV1()

# User serializers.
user_pwatched_serializer_v1 = UserPWatchedSerializerV1()
user_pselect_serializer_v1 = UserPSelectSerializerV1()
user_personalizing_pwatched_serializer_v1 = UserPersonalizingPWatchedSerializerV1()
user_personalizing_pselect_serializer_v1 = UserPersonalizingPSelectSerializerV1()
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class FeatureRegistryId(_t.NamedTuple):
    """Key identifying one feature/version pair in the serializer registry."""

    # Kind of entity the feature belongs to. "PASS_THROUGH" is included
    # because the registry's pass-through entry is created with that value.
    entity_type: _t.Literal["STREAM", "USER", "PASS_THROUGH"]
    # Feature identifier, e.g. "PWATCHED#24H".
    feature_id: str
    # Version tag, e.g. "v0" or "v1".
    version: str
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
# Registry ids grouped by the serializer that handles them. Each list holds
# one FeatureRegistryId per feature id, in the order given.

stream_pwatched_v0_features: list[FeatureRegistryId] = [
    FeatureRegistryId("STREAM", name, "v0")
    for name in ("PWATCHED#24H", "PWATCHED#24H#TV", "PWATCHED#24H#MOBILE")
]

stream_pwatched_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId("STREAM", name, "v1")
    for name in ("PWATCHED#24H", "PWATCHED#24H#TV", "PWATCHED#24H#MOBILE")
]

stream_pselect_v0_features: list[FeatureRegistryId] = [
    FeatureRegistryId("STREAM", name, "v0")
    for name in ("PSELECT#24H", "PSELECT#24H#MOBILE", "PSELECT#24H#TV")
]

stream_pselect_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId("STREAM", name, "v1")
    for name in ("PSELECT#24H", "PSELECT#24H#MOBILE", "PSELECT#24H#TV")
]

stream_similarity_v0_features: list[FeatureRegistryId] = [
    FeatureRegistryId("STREAM", name, "v0")
    for name in ("SIMILARITY", "SIMILARITY#WEATHER_ALERT")
]

stream_similarity_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId("STREAM", name, "v1")
    for name in ("SIMILARITY#GEMINI", "SIMILARITY#WEATHER_ALERT")
]

user_personalizing_pwatched_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId("USER", name, "v1")
    for name in (
        "PWATCHED#6M#CATEGORY",
        "PWATCHED#6M#AUTHOR_SHOW",
        "PWATCHED#6M#GEMINI_CATEGORY",
    )
]

user_personalizing_pselect_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId("USER", name, "v1")
    for name in (
        "PSELECT#6M#CATEGORY",
        "PSELECT#6M#AUTHOR_SHOW",
        "PSELECT#6M#GEMINI_CATEGORY",
    )
]

user_bias_pwatched_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId("USER", "PWATCHED#6M", "v1")
]

user_bias_pselect_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId("USER", "PSELECT#6M", "v1")
]
|
|
343
|
+
|
|
344
|
+
# Pairs each group of registry ids with the serializer instance that handles it.
features_serializer_tuples: list[tuple[list[FeatureRegistryId], Serializer]] = [
    (stream_pwatched_v0_features, stream_pwatched_serializer_v0),
    (stream_pwatched_v1_features, stream_pwatched_serializer_v1),
    (stream_pselect_v0_features, stream_pselect_serializer_v0),
    (stream_pselect_v1_features, stream_pselect_serializer_v1),
    (stream_similarity_v0_features, stream_similarity_scores_serializer_v0),
    (stream_similarity_v1_features, stream_similarity_scores_serializer_v1),
    (user_personalizing_pwatched_v1_features, user_personalizing_pwatched_serializer_v1),
    (user_bias_pwatched_v1_features, user_pwatched_serializer_v1),
    (user_personalizing_pselect_v1_features, user_personalizing_pselect_serializer_v1),
    (user_bias_pselect_v1_features, user_pselect_serializer_v1),
]

# Lookup table from registry id to serializer, seeded with the pass-through
# entry and then filled from the pairs above.
SerializerRegistry: dict[FeatureRegistryId, Serializer] = {
    FeatureRegistryId("PASS_THROUGH", "PASS_THROUGH", "v1"): PassThroughSerializer()
}

for feature_ids, serializer in features_serializer_tuples:
    # Every id in the group maps to the same serializer instance.
    SerializerRegistry.update(dict.fromkeys(feature_ids, serializer))
|
|
@@ -15,9 +15,12 @@ import newrelic.agent
|
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
from .cache import make_features_cache
|
|
18
|
-
from .dynamo import
|
|
18
|
+
from .dynamo import set_all_features, FeatureRetrievalMeta
|
|
19
19
|
from .model_store import download_and_load_model
|
|
20
20
|
from .settings import Settings
|
|
21
|
+
from . import exceptions
|
|
22
|
+
from ._serializers import SerializerRegistry
|
|
23
|
+
from google.protobuf import text_format
|
|
21
24
|
|
|
22
25
|
logging.basicConfig(
|
|
23
26
|
level=logging.INFO,
|
|
@@ -30,6 +33,61 @@ logger = logging.getLogger(__name__)
|
|
|
30
33
|
MAX_POOL_CONNECTIONS = int(os.environ.get("MAX_POOL_CONNECTIONS", 50))
|
|
31
34
|
|
|
32
35
|
|
|
36
|
+
class StreamLoggerProxy:
    """Wrap a stream mapping so it is formatted only when actually logged.

    Values stored under one of ``feature_ids`` are rendered with protobuf's
    one-line text format; every other value uses ``repr``. The formatting work
    happens in ``__repr__``, so it is skipped if the log record is never
    rendered.
    """

    def __init__(self, stream, feature_ids):
        """Store the stream mapping and the ids whose values are messages.

        Args:
            stream: mapping of key to value describing one stream.
            feature_ids: iterable of hashable keys expected to hold protobuf
                messages.
        """
        self._stream = stream
        # A frozenset gives constant-time membership tests in __repr__.
        self._feature_ids = frozenset(feature_ids)

    def __repr__(self):
        """Return a dict-like string with feature messages on one line each."""
        parts = []
        for k, v in self._stream.items():
            # Only real protobuf messages (which expose DESCRIPTOR) can go
            # through text_format; anything else stored under a feature key,
            # such as a pass-through value or None, falls back to repr().
            if k in self._feature_ids and hasattr(v, "DESCRIPTOR"):
                formatted_v = text_format.MessageToString(v, as_one_line=True)
                parts.append(f"'{k}': '{formatted_v}'")
            else:
                parts.append(f"'{k}': {repr(v)}")
        return "{" + ", ".join(parts) + "}"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
async def load_model(state, cfg: Settings) -> None:
    """Download the model and record its feature lists in ``state``.

    On success ``state["model"]``, ``state["stream_features"]`` and
    ``state["user_features"]`` are set. If ``cfg.s3_model_path`` is empty, or
    loading fails for any reason other than invalid features, the problem is
    logged as critical and the function returns without raising.

    Args:
        state: mutable mapping holding at least ``"session"``.
        cfg: settings providing ``s3_model_path``.

    Raises:
        exceptions.InvalidFeaturesException: if the model asks for a feature
            that has no entry in ``SerializerRegistry``.
    """
    if not cfg.s3_model_path:
        logger.critical("S3_MODEL_PATH not set; service will be unhealthy.")
        return

    try:
        # Pass the persistent session/client if needed
        state["model"] = await download_and_load_model(
            cfg.s3_model_path, aio_session=state["session"]
        )
        # "or []" also covers a model that stores None for a feature list,
        # matching how the request handler reads the same keys.
        state["stream_features"] = state["model"].get("stream_features", []) or []
        state["user_features"] = state["model"].get("user_features", []) or []

        # Registry keys are (entity_type, feature_id, version); a feature is
        # valid when any version of it is registered.
        valid_features = {
            (entity_type, feature_id)
            for entity_type, feature_id, _ in SerializerRegistry
        }
        all_features = {
            ("STREAM", feature_name) for feature_name in state["stream_features"]
        } | {("USER", feature_name) for feature_name in state["user_features"]}
        invalid_features = all_features - valid_features
        if invalid_features:
            raise exceptions.InvalidFeaturesException(
                f"Received invalid features: {invalid_features}"
            )

        newrelic.agent.add_custom_attribute(
            "total_stream_features", len(state["stream_features"])
        )
        newrelic.agent.add_custom_attribute(
            "total_user_features", len(state["user_features"])
        )
        logger.info("Model loaded successfully.")
    except exceptions.InvalidFeaturesException as e:
        logger.error("%s", e)
        # Bare raise re-raises the active exception with its traceback intact.
        raise
    except Exception as e:
        # Deliberately best-effort: keep the process up, but keep the
        # traceback in the log so the failure can be diagnosed.
        logger.critical("Failed to load model: %s", e, exc_info=True)
|
|
89
|
+
|
|
90
|
+
|
|
33
91
|
def create_app(
|
|
34
92
|
settings: Optional[Settings] = None,
|
|
35
93
|
*,
|
|
@@ -42,10 +100,12 @@ def create_app(
|
|
|
42
100
|
cfg = settings or Settings()
|
|
43
101
|
|
|
44
102
|
# Mutable state: cache + model
|
|
45
|
-
|
|
103
|
+
stream_features_cache = make_features_cache(cfg.stream_cache_maxsize)
|
|
104
|
+
user_features_cache = make_features_cache(cfg.user_cache_maxsize)
|
|
105
|
+
aws_session = aiobotocore.session.get_session()
|
|
46
106
|
state: Dict[str, Any] = {
|
|
47
107
|
"model": preloaded_model,
|
|
48
|
-
"session":
|
|
108
|
+
"session": aws_session,
|
|
49
109
|
"model_name": (
|
|
50
110
|
os.path.basename(cfg.s3_model_path) if cfg.s3_model_path else None
|
|
51
111
|
),
|
|
@@ -58,8 +118,11 @@ def create_app(
|
|
|
58
118
|
Everything before 'yield' runs on startup.
|
|
59
119
|
Everything after 'yield' runs on shutdown.
|
|
60
120
|
"""
|
|
121
|
+
# 1. Load ML Model
|
|
122
|
+
if state["model"] is None:
|
|
123
|
+
await load_model(state, cfg)
|
|
61
124
|
async with AsyncExitStack() as stack:
|
|
62
|
-
#
|
|
125
|
+
# 2. Initialize DynamoDB Client (Persistent Pool)
|
|
63
126
|
session = state["session"]
|
|
64
127
|
state["dynamo_client"] = await stack.enter_async_context(
|
|
65
128
|
session.create_client(
|
|
@@ -69,29 +132,6 @@ def create_app(
|
|
|
69
132
|
)
|
|
70
133
|
)
|
|
71
134
|
logger.info("DynamoDB persistent client initialized.")
|
|
72
|
-
|
|
73
|
-
# 2. Load ML Model
|
|
74
|
-
if state["model"] is None:
|
|
75
|
-
if not cfg.s3_model_path:
|
|
76
|
-
logger.critical("S3_MODEL_PATH not set; service will be unhealthy.")
|
|
77
|
-
else:
|
|
78
|
-
try:
|
|
79
|
-
# Pass the persistent session/client if needed
|
|
80
|
-
state["model"] = await download_and_load_model(
|
|
81
|
-
cfg.s3_model_path, aio_session=state["session"]
|
|
82
|
-
)
|
|
83
|
-
state["stream_features"] = state["model"].get(
|
|
84
|
-
"stream_features", []
|
|
85
|
-
)
|
|
86
|
-
state["user_features"] = state["model"].get("user_features", [])
|
|
87
|
-
|
|
88
|
-
newrelic.agent.add_custom_attribute(
|
|
89
|
-
"total_stream_features", len(state["stream_features"])
|
|
90
|
-
)
|
|
91
|
-
logger.info("Model loaded successfully.")
|
|
92
|
-
except Exception as e:
|
|
93
|
-
logger.critical("Failed to load model: %s", e)
|
|
94
|
-
|
|
95
135
|
yield
|
|
96
136
|
|
|
97
137
|
# 3. Shutdown Logic
|
|
@@ -116,7 +156,8 @@ def create_app(
|
|
|
116
156
|
return {
|
|
117
157
|
"status": "ok",
|
|
118
158
|
"model_loaded": True,
|
|
119
|
-
"
|
|
159
|
+
"stream_cache_size": len(stream_features_cache),
|
|
160
|
+
"user_cache_size": len(user_features_cache),
|
|
120
161
|
"model_name": state.get("model_name"),
|
|
121
162
|
"stream_features": state.get("stream_features", []),
|
|
122
163
|
}
|
|
@@ -151,8 +192,11 @@ def create_app(
|
|
|
151
192
|
# Feature fetch (optional based on model)
|
|
152
193
|
model = state["model"]
|
|
153
194
|
stream_features = model.get("stream_features", []) or []
|
|
195
|
+
user_features = model.get("user_features", []) or []
|
|
154
196
|
retrieval_meta = FeatureRetrievalMeta(
|
|
155
197
|
cache_misses=0,
|
|
198
|
+
stream_cache_misses=0,
|
|
199
|
+
user_cache_misses=0,
|
|
156
200
|
retrieval_ms=0,
|
|
157
201
|
success=True,
|
|
158
202
|
cache_delay_minutes=0,
|
|
@@ -160,21 +204,40 @@ def create_app(
|
|
|
160
204
|
parsing_ms=0,
|
|
161
205
|
)
|
|
162
206
|
if stream_features:
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
207
|
+
try:
|
|
208
|
+
retrieval_meta = await set_all_features(
|
|
209
|
+
dynamo_client=state["dynamo_client"],
|
|
210
|
+
user=user,
|
|
211
|
+
streams=streams,
|
|
212
|
+
stream_features=stream_features,
|
|
213
|
+
user_features=user_features,
|
|
214
|
+
stream_features_cache=stream_features_cache,
|
|
215
|
+
user_features_cache=user_features_cache,
|
|
216
|
+
features_table=cfg.features_table,
|
|
217
|
+
cache_sep=cfg.cache_separator,
|
|
218
|
+
)
|
|
219
|
+
except exceptions.InvalidFeaturesException as e:
|
|
220
|
+
logger.error(
|
|
221
|
+
"The following features are not present in the SerializerRegistry %s",
|
|
222
|
+
e,
|
|
223
|
+
)
|
|
224
|
+
raise HTTPException(
|
|
225
|
+
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
|
226
|
+
detail=f"Received invalid features from feature store: {e}",
|
|
227
|
+
) from e
|
|
172
228
|
|
|
173
229
|
random_number = random.random()
|
|
174
230
|
userid = user.get("userid", "")
|
|
175
231
|
# Sampling logs
|
|
176
232
|
if random_number < cfg.logs_fraction:
|
|
177
|
-
logger.info(
|
|
233
|
+
logger.info(
|
|
234
|
+
"User %s streams: %s",
|
|
235
|
+
user.get("userid", ""),
|
|
236
|
+
[
|
|
237
|
+
StreamLoggerProxy(s, stream_features + user_features)
|
|
238
|
+
for s in streams
|
|
239
|
+
],
|
|
240
|
+
)
|
|
178
241
|
|
|
179
242
|
# Synchronous model execution (user code)
|
|
180
243
|
try:
|
|
@@ -200,6 +263,8 @@ def create_app(
|
|
|
200
263
|
"Inference",
|
|
201
264
|
{
|
|
202
265
|
"cache_misses": retrieval_meta.cache_misses,
|
|
266
|
+
"user_cache_misses": retrieval_meta.user_cache_misses,
|
|
267
|
+
"stream_cache_misses": retrieval_meta.stream_cache_misses,
|
|
203
268
|
"retrieval_success": int(retrieval_meta.success),
|
|
204
269
|
"cache_delay_minutes": retrieval_meta.cache_delay_minutes,
|
|
205
270
|
"dynamo_ms": retrieval_meta.dynamo_ms,
|