haystack-ml-stack 0.2.4__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/PKG-INFO +10 -8
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/pyproject.toml +9 -4
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/__init__.py +14 -0
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/_serializers.py +368 -0
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/app.py +133 -38
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/cache.py +2 -2
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/dynamo.py +326 -0
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/exceptions.py +5 -0
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/generated/__init__.py +0 -0
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/generated/v1/__init__.py +0 -0
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/settings.py +2 -1
- haystack_ml_stack-0.3.0/src/haystack_ml_stack/utils.py +675 -0
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/PKG-INFO +10 -8
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/SOURCES.txt +5 -0
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/requires.txt +3 -0
- haystack_ml_stack-0.3.0/tests/test_serializers.py +152 -0
- haystack_ml_stack-0.3.0/tests/test_utils.py +510 -0
- haystack_ml_stack-0.2.4/src/haystack_ml_stack/__init__.py +0 -4
- haystack_ml_stack-0.2.4/src/haystack_ml_stack/dynamo.py +0 -194
- haystack_ml_stack-0.2.4/src/haystack_ml_stack/utils.py +0 -312
- haystack_ml_stack-0.2.4/tests/test_utils.py +0 -76
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/README.md +0 -0
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/setup.cfg +0 -0
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/model_store.py +0 -0
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/dependency_links.txt +0 -0
- {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/top_level.txt +0 -0
|
@@ -1,18 +1,20 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: haystack-ml-stack
|
|
3
|
-
Version: 0.2.4
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: Functions related to Haystack ML
|
|
5
5
|
Author-email: Oscar Vega <oscar@haystack.tv>
|
|
6
6
|
License: MIT
|
|
7
7
|
Requires-Python: >=3.11
|
|
8
8
|
Description-Content-Type: text/markdown
|
|
9
|
-
Requires-Dist: pydantic==2.5.0
|
|
10
|
-
Requires-Dist: cachetools==5.5.2
|
|
11
|
-
Requires-Dist: cloudpickle==2.2.1
|
|
12
|
-
Requires-Dist: aioboto3==12.0.0
|
|
13
|
-
Requires-Dist: fastapi==0.104.1
|
|
14
|
-
Requires-Dist: pydantic-settings==2.2
|
|
15
|
-
Requires-Dist: newrelic==11.1.0
|
|
9
|
+
Requires-Dist: protobuf==6.33.2
|
|
10
|
+
Provides-Extra: server
|
|
11
|
+
Requires-Dist: pydantic==2.5.0; extra == "server"
|
|
12
|
+
Requires-Dist: cachetools==5.5.2; extra == "server"
|
|
13
|
+
Requires-Dist: cloudpickle==2.2.1; extra == "server"
|
|
14
|
+
Requires-Dist: aioboto3==12.0.0; extra == "server"
|
|
15
|
+
Requires-Dist: fastapi==0.104.1; extra == "server"
|
|
16
|
+
Requires-Dist: pydantic-settings==2.2; extra == "server"
|
|
17
|
+
Requires-Dist: newrelic==11.1.0; extra == "server"
|
|
16
18
|
|
|
17
19
|
# Haystack ML Stack
|
|
18
20
|
|
|
@@ -5,18 +5,23 @@ build-backend = "setuptools.build_meta"
|
|
|
5
5
|
|
|
6
6
|
[project]
|
|
7
7
|
name = "haystack-ml-stack"
|
|
8
|
-
version = "0.2.4"
|
|
8
|
+
version = "0.3.0"
|
|
9
9
|
description = "Functions related to Haystack ML"
|
|
10
10
|
readme = "README.md"
|
|
11
11
|
authors = [{ name = "Oscar Vega", email = "oscar@haystack.tv" }]
|
|
12
12
|
requires-python = ">=3.11"
|
|
13
13
|
dependencies = [
|
|
14
|
+
"protobuf==6.33.2",
|
|
15
|
+
]
|
|
16
|
+
license = { text = "MIT" }
|
|
17
|
+
|
|
18
|
+
[project.optional-dependencies]
|
|
19
|
+
server = [
|
|
14
20
|
"pydantic==2.5.0",
|
|
15
21
|
"cachetools==5.5.2",
|
|
16
22
|
"cloudpickle==2.2.1",
|
|
17
23
|
"aioboto3==12.0.0",
|
|
18
24
|
"fastapi==0.104.1",
|
|
19
25
|
"pydantic-settings==2.2",
|
|
20
|
-
"newrelic==11.1.0"
|
|
21
|
-
]
|
|
22
|
-
license = { text = "MIT" }
|
|
26
|
+
"newrelic==11.1.0",
|
|
27
|
+
]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# The application factory is exported only when importing `.app` succeeds;
# an ImportError there is tolerated (presumably because the packages it needs
# are optional -- TODO confirm against the `server` extra).
try:
    from .app import create_app
except ImportError:
    __all__ = []
else:
    __all__ = ["create_app"]

# The serializer registry is always part of the public API.
from ._serializers import SerializerRegistry, FeatureRegistryId

__all__ += ["SerializerRegistry", "FeatureRegistryId"]

__version__ = "0.3.0"
|
|
@@ -0,0 +1,368 @@
|
|
|
1
|
+
from .generated.v1 import features_pb2 as features_pb2_v1
|
|
2
|
+
from google.protobuf.message import Message
|
|
3
|
+
from google.protobuf.json_format import ParseDict as ProtoParseDict
|
|
4
|
+
import typing as _t
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
|
|
7
|
+
MessageType = _t.TypeVar("MessageType", bound=Message)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Serializer(ABC):
    """Abstract interface implemented by every feature serializer."""

    @abstractmethod
    def serialize(self, value) -> bytes:
        """Encode ``value`` and return the encoded form."""

    @abstractmethod
    def deserialize(self, value: bytes) -> _t.Any:
        """Decode ``value`` and return the decoded object."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SimpleSerializer(Serializer, _t.Generic[MessageType]):
    """Generic dict <-> protobuf serializer built on google's `ParseDict`.

    Serialization goes through `ParseDict`, which keeps the code short but is
    very slow. Use this class directly for PoCs only; production serializers
    should override `serialize` and set the message fields by hand (early
    tests showed roughly a 10x speedup for manual serialization).

    Deserialization is fine as-is: the bytes are parsed straight into the
    message, so no intermediate dictionary is created.
    """

    def __init__(self, msg_class: type[MessageType]):
        # Message class instantiated for every (de)serialization call.
        self.msg_class = msg_class

    def serialize(self, value) -> bytes:
        """Parse the dict ``value`` into a fresh message and return its bytes."""
        empty_msg = self.msg_class()
        filled_msg = ProtoParseDict(value, message=empty_msg)
        return filled_msg.SerializeToString()

    def deserialize(self, value) -> MessageType:
        """Parse the bytes ``value`` into a fresh message and return it."""
        decoded: Message = self.msg_class()
        decoded.ParseFromString(value)
        return decoded
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class StreamPWatchedSerializerV1(SimpleSerializer):
    """Serializer for version 1 `StreamPWatched` messages, filled field by field."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.StreamPWatched)

    def serialize(self, value):
        """Build the message for ``value`` and return its serialized bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.StreamPWatched:
        """Return a `StreamPWatched` message populated from the dict ``value``.

        ``value`` must carry ``"version"`` equal to 1 and a ``"data"`` mapping
        of entry-context name -> counts with ``"attempts"`` and ``"watched"``.
        """
        pwatched_msg = self.msg_class()
        version = value["version"]
        assert version == 1, "Wrong version given!"
        pwatched_msg.version = version
        for context_name, context_counts in value["data"].items():
            # Each entry context is an attribute of the `data` sub-message.
            counts_msg: features_pb2_v1.EntryContextCounts = getattr(
                pwatched_msg.data, context_name
            )
            counts_msg.attempts = int(context_counts["attempts"])
            counts_msg.watched = int(context_counts["watched"])
        return pwatched_msg


# User-level pwatched features reuse the stream serializer unchanged.
UserPWatchedSerializerV1 = StreamPWatchedSerializerV1
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class StreamPWatchedSerializerV0(Serializer):
    """Deserialize-only adapter turning v0 pwatched dicts into V1 messages."""

    # Shared V1 serializer used to build the resulting message.
    serializer_v1 = StreamPWatchedSerializerV1()

    def serialize(self, value) -> bytes:
        """Always raises: v0 payloads are never written."""
        raise NotImplementedError(
            "This serializer should never be used for serialization!"
        )

    def deserialize(self, value) -> features_pb2_v1.StreamPWatched:
        """Wrap the v0 mapping ``value`` in the V1 layout and build the message.

        Spaces in entry-context names are replaced with underscores before the
        names are handed to the V1 builder.
        """
        renamed_counts = {}
        for entry_context, counts in value.items():
            renamed_counts[entry_context.replace(" ", "_")] = counts
        v1_payload = {"data": renamed_counts, "version": 1}
        return self.serializer_v1.build_msg(v1_payload)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class StreamPSelectSerializerV1(SimpleSerializer):
    """Serializer for version 1 `StreamPSelect` messages, filled field by field."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.StreamPSelect)

    def serialize(self, value) -> bytes:
        """Build the message for ``value`` and return its serialized bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.StreamPSelect:
        """Return a `StreamPSelect` message populated from the dict ``value``.

        ``value`` must carry ``"version"`` equal to 1 and a ``"data"`` mapping
        of browsed-debias key -> position name -> select counts.
        """
        pselect_msg: features_pb2_v1.StreamPSelect = self.msg_class()
        assert value["version"] == 1, "Wrong version given!"
        pselect_msg.version = 1
        for debias_key, per_position in value["data"].items():
            positions_msg: features_pb2_v1.PositionPSelect = getattr(
                pselect_msg.data, debias_key
            )
            for position_name, raw_counts in per_position.items():
                counts_msg = getattr(positions_msg, position_name)
                counts_msg.total_selects = int(raw_counts["total_selects"])
                counts_msg.total_browsed = int(raw_counts["total_browsed"])
                counts_msg.total_selects_and_watched = int(
                    raw_counts["total_selects_and_watched"]
                )
        return pselect_msg


# User-level pselect features reuse the stream serializer unchanged.
UserPSelectSerializerV1 = StreamPSelectSerializerV1
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class StreamPSelectSerializerV0(Serializer):
    """Deserialize-only adapter turning v0 pselect dicts into V1 messages."""

    # Shared V1 serializer used to build the resulting message.
    serializer_v1 = StreamPSelectSerializerV1()

    # v0 position keys -> field names expected by the V1 builder.
    _POSITION_KEY_MAPPING = {
        "0": "first_pos",
        "1": "second_pos",
        "2": "third_pos",
        "3+": "rest_pos",
    }

    def serialize(self, value) -> bytes:
        """Always raises: v0 payloads are never written."""
        raise NotImplementedError(
            "This serializer should never be used for serialization!"
        )

    def _rename_positions(self, position_counts):
        """Return a shallow copy of ``position_counts`` with v0 keys renamed."""
        renamed = dict(position_counts)
        for old_key, new_key in self._POSITION_KEY_MAPPING.items():
            if old_key in renamed:
                renamed[new_key] = renamed.pop(old_key)
        return renamed

    def deserialize(self, value):
        """Convert the v0 mapping ``value`` into a V1 `StreamPSelect` message.

        Only the ``"4_browsed"`` and ``"all_browsed"`` entries are used; a
        missing entry raises ``KeyError``. The renaming is done on copies, so
        the caller's dictionaries are left untouched.
        """
        out = {
            "data": {
                "up_to_4_browsed": self._rename_positions(value["4_browsed"]),
                "all_browsed": self._rename_positions(value["all_browsed"]),
            },
            "version": 1,
        }
        return self.serializer_v1.build_msg(value=out)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class StreamSimilaritySerializerV1(SimpleSerializer):
    """Serializer for version 1 `StreamSimilarityScores` messages."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.StreamSimilarityScores)

    def serialize(self, value):
        """Build the message for ``value`` and return its serialized bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.StreamSimilarityScores:
        """Return a `StreamSimilarityScores` message populated from ``value``.

        ``value`` must carry ``"version"`` equal to 1 and a ``"data"`` mapping
        of key -> score; each pair is copied into the message's ``data`` field.
        """
        scores_msg = self.msg_class()
        version = value["version"]
        assert version == 1, "Wrong version given!"
        scores_msg.version = version
        for score_key, score_value in value["data"].items():
            scores_msg.data[score_key] = score_value
        return scores_msg
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class StreamSimilaritySerializerV0(Serializer):
    """Deserialize-only adapter turning v0 similarity dicts into V1 messages."""

    # Shared V1 serializer used to build the resulting message.
    serializer_v1 = StreamSimilaritySerializerV1()

    def serialize(self, value):
        """Always raises: v0 payloads are never written."""
        raise NotImplementedError(
            "This serializer should never be used for serialization!"
        )

    def deserialize(self, value):
        """Wrap the v0 mapping ``value`` as V1 ``data`` and build the message."""
        v1_payload = {"data": value, "version": 1}
        return self.serializer_v1.build_msg(v1_payload)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class UserPersonalizingPWatchedSerializerV1(SimpleSerializer):
    """Serializer for version 1 `UserPersonalizingPWatched` messages."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.UserPersonalizingPWatched)

    def serialize(self, value: dict) -> bytes:
        """Build the message for ``value`` and return its serialized bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.UserPersonalizingPWatched:
        """Return a `UserPersonalizingPWatched` message populated from ``value``.

        ``value`` must carry ``"version"`` equal to 1 and a ``"data"`` mapping
        of personalizing key -> entry-context name -> counts with
        ``"attempts"`` and ``"watched"``.
        """
        user_msg = features_pb2_v1.UserPersonalizingPWatched()
        version = value["version"]
        assert version == 1, "Wrong version given!"
        user_msg.version = version
        for personalizing_key, per_context in value["data"].items():
            # `data` is indexed by personalizing key.
            key_msg = user_msg.data[personalizing_key]
            for context_name, context_counts in per_context.items():
                counts_msg = getattr(key_msg, context_name)
                counts_msg.attempts = int(context_counts["attempts"])
                counts_msg.watched = int(context_counts["watched"])
        return user_msg
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class UserPersonalizingPSelectSerializerV1(SimpleSerializer):
    """Serializer for version 1 `UserPersonalizingPSelect` messages."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.UserPersonalizingPSelect)

    def serialize(self, value):
        """Build the message for ``value`` and return its serialized bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.UserPersonalizingPSelect:
        """Return a `UserPersonalizingPSelect` message populated from ``value``.

        ``value`` must carry a ``"version"`` and a ``"data"`` mapping of
        personalizing key -> browsed-debias key -> position name -> select
        counts. The version is copied as given; no version check is done here.
        """
        root_msg = self.msg_class()
        root_msg.version = value["version"]
        for personalizing_key, browsed_debiased_pselects in value["data"].items():
            # `data` is indexed by personalizing key.
            personalizing_msg = root_msg.data[personalizing_key]
            for (
                browsed_debias_key,
                position_pselects,
            ) in browsed_debiased_pselects.items():
                position_pselects_msg = getattr(personalizing_msg, browsed_debias_key)
                for position, select_counts in position_pselects.items():
                    select_counts_msg = getattr(position_pselects_msg, position)
                    select_counts_msg.total_selects = int(
                        select_counts["total_selects"]
                    )
                    select_counts_msg.total_browsed = int(
                        select_counts["total_browsed"]
                    )
                    select_counts_msg.total_selects_and_watched = int(
                        select_counts["total_selects_and_watched"]
                    )
        return root_msg
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class PassThroughSerializer(Serializer):
    """Serializer that returns its input unchanged in both directions."""

    def serialize(self, value):
        """Return ``value`` as-is."""
        return value

    def deserialize(self, value):
        """Return ``value`` as-is."""
        return value
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# Module-level serializer instances referenced by the registry tables below.

# Stream features.
stream_pwatched_serializer_v0 = StreamPWatchedSerializerV0()
stream_pwatched_serializer_v1 = StreamPWatchedSerializerV1()
stream_pselect_serializer_v0 = StreamPSelectSerializerV0()
stream_pselect_serializer_v1 = StreamPSelectSerializerV1()
stream_similarity_scores_serializer_v0 = StreamSimilaritySerializerV0()
stream_similarity_scores_serializer_v1 = StreamSimilaritySerializerV1()

# User features.
user_pwatched_serializer_v1 = UserPWatchedSerializerV1()
user_pselect_serializer_v1 = UserPSelectSerializerV1()
user_personalizing_pwatched_serializer_v1 = UserPersonalizingPWatchedSerializerV1()
user_personalizing_pselect_serializer_v1 = UserPersonalizingPSelectSerializerV1()
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class FeatureRegistryId(_t.NamedTuple):
    """Key identifying one feature/version pair in the serializer registry."""

    # Entity the feature belongs to; "PASS_THROUGH" is the value used by the
    # registry's pass-through entry.
    entity_type: _t.Literal["STREAM", "USER", "PASS_THROUGH"]
    # Feature identifier, e.g. "PWATCHED#24H".
    feature_id: str
    # Version tag, e.g. "v0" or "v1".
    version: str
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
# Registry keys grouped by the serializer that handles them. Every list is a
# distinct object, in the same order as the feature ids are written here.

stream_pwatched_v0_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="STREAM", feature_id=name, version="v0")
    for name in ("PWATCHED#24H", "PWATCHED#24H#TV", "PWATCHED#24H#MOBILE")
]

stream_pwatched_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="STREAM", feature_id=name, version="v1")
    for name in ("PWATCHED#24H", "PWATCHED#24H#TV", "PWATCHED#24H#MOBILE")
]

stream_pselect_v0_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="STREAM", feature_id=name, version="v0")
    for name in ("PSELECT#24H", "PSELECT#24H#MOBILE", "PSELECT#24H#TV")
]

stream_pselect_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="STREAM", feature_id=name, version="v1")
    for name in ("PSELECT#24H", "PSELECT#24H#MOBILE", "PSELECT#24H#TV")
]

stream_similarity_v0_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="STREAM", feature_id=name, version="v0")
    for name in ("SIMILARITY", "SIMILARITY#WEATHER_ALERT")
]

stream_similarity_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="STREAM", feature_id=name, version="v1")
    for name in ("SIMILARITY#GEMINI", "SIMILARITY#WEATHER_ALERT")
]

user_personalizing_pwatched_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="USER", feature_id=name, version="v1")
    for name in (
        "PWATCHED#6M#CATEGORY",
        "PWATCHED#6M#AUTHOR_SHOW",
        "PWATCHED#6M#GEMINI_CATEGORY",
    )
]

user_personalizing_pselect_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="USER", feature_id=name, version="v1")
    for name in (
        "PSELECT#6M#CATEGORY",
        "PSELECT#6M#AUTHOR_SHOW",
        "PSELECT#6M#GEMINI_CATEGORY",
    )
]

user_bias_pwatched_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId("USER", "PWATCHED#6M", "v1"),
]

user_bias_pselect_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId("USER", "PSELECT#6M", "v1"),
]
|
|
343
|
+
|
|
344
|
+
# (feature ids, serializer) pairs: every id in a list is handled by the
# serializer paired with it.
features_serializer_tuples: list[tuple[list[FeatureRegistryId], Serializer]] = [
    # Stream features
    (stream_pwatched_v0_features, stream_pwatched_serializer_v0),
    (stream_pwatched_v1_features, stream_pwatched_serializer_v1),
    (stream_pselect_v0_features, stream_pselect_serializer_v0),
    (stream_pselect_v1_features, stream_pselect_serializer_v1),
    (stream_similarity_v0_features, stream_similarity_scores_serializer_v0),
    (stream_similarity_v1_features, stream_similarity_scores_serializer_v1),
    # User features
    (user_personalizing_pwatched_v1_features, user_personalizing_pwatched_serializer_v1),
    (user_bias_pwatched_v1_features, user_pwatched_serializer_v1),
    (user_personalizing_pselect_v1_features, user_personalizing_pselect_serializer_v1),
    (user_bias_pselect_v1_features, user_pselect_serializer_v1),
]

# Lookup table from feature id to serializer. The pass-through entry is
# inserted first, then one entry per registered feature id.
SerializerRegistry: dict[FeatureRegistryId, Serializer] = {}
SerializerRegistry[
    FeatureRegistryId("PASS_THROUGH", "PASS_THROUGH", "v1")
] = PassThroughSerializer()

for feature_ids, serializer in features_serializer_tuples:
    for feature_id in feature_ids:
        SerializerRegistry[feature_id] = serializer
|
|
@@ -5,17 +5,22 @@ import sys
|
|
|
5
5
|
from http import HTTPStatus
|
|
6
6
|
from typing import Any, Dict, List, Optional
|
|
7
7
|
import time
|
|
8
|
+
from contextlib import asynccontextmanager, AsyncExitStack
|
|
8
9
|
|
|
9
10
|
import aiobotocore.session
|
|
11
|
+
from aiobotocore.config import AioConfig
|
|
10
12
|
from fastapi import FastAPI, HTTPException, Request, Response
|
|
11
13
|
from fastapi.encoders import jsonable_encoder
|
|
12
14
|
import newrelic.agent
|
|
13
15
|
|
|
14
16
|
|
|
15
17
|
from .cache import make_features_cache
|
|
16
|
-
from .dynamo import
|
|
18
|
+
from .dynamo import set_all_features, FeatureRetrievalMeta
|
|
17
19
|
from .model_store import download_and_load_model
|
|
18
20
|
from .settings import Settings
|
|
21
|
+
from . import exceptions
|
|
22
|
+
from ._serializers import SerializerRegistry
|
|
23
|
+
from google.protobuf import text_format
|
|
19
24
|
|
|
20
25
|
logging.basicConfig(
|
|
21
26
|
level=logging.INFO,
|
|
@@ -25,7 +30,62 @@ logging.basicConfig(
|
|
|
25
30
|
)
|
|
26
31
|
|
|
27
32
|
logger = logging.getLogger(__name__)
|
|
28
|
-
|
|
33
|
+
# Connection-pool size for the DynamoDB client; the MAX_POOL_CONNECTIONS
# environment variable overrides the default of 50.
MAX_POOL_CONNECTIONS = int(os.getenv("MAX_POOL_CONNECTIONS", "50"))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class StreamLoggerProxy:
    """Wrap a stream mapping so its log text is only built when rendered.

    Values stored under one of ``feature_ids`` are rendered with protobuf's
    one-line text format; every other value is rendered with ``repr``.
    """

    def __init__(self, stream, feature_ids):
        """Store the stream mapping and the keys that hold feature values."""
        self._stream = stream
        # Frozen into a set once so the per-key membership test in __repr__
        # does not rescan the whole sequence for every key.
        self._feature_ids = frozenset(feature_ids)

    def __repr__(self):
        """Return a dict-like rendering of the wrapped stream."""
        parts = []
        for key, val in self._stream.items():
            if key in self._feature_ids:
                # Format only when needed for the log output
                formatted_val = text_format.MessageToString(val, as_one_line=True)
                parts.append(f"'{key}': '{formatted_val}'")
            else:
                parts.append(f"'{key}': {val!r}")
        return "{" + ", ".join(parts) + "}"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
async def load_model(state, cfg: Settings) -> None:
    """Download the model named by ``cfg.s3_model_path`` and store it in ``state``.

    On success ``state["model"]``, ``state["stream_features"]`` and
    ``state["user_features"]`` are populated. Each requested feature must match
    the (entity type, feature id) of some ``SerializerRegistry`` key.

    Raises:
        exceptions.InvalidFeaturesException: if the model requests a feature
            with no registry entry. Any other failure is logged and swallowed,
            leaving the service without a usable model.
    """
    if not cfg.s3_model_path:
        logger.critical("S3_MODEL_PATH not set; service will be unhealthy.")
        return

    try:
        # Pass the persistent session/client if needed
        state["model"] = await download_and_load_model(
            cfg.s3_model_path, aio_session=state["session"]
        )
        model = state["model"]
        stream_features = model.get("stream_features", [])
        user_features = model.get("user_features", [])
        state["stream_features"] = stream_features
        state["user_features"] = user_features

        # The registry is keyed by (entity, feature, version); validation
        # ignores the version component.
        valid_features = {
            (entity_type, feature_id)
            for entity_type, feature_id, _ in SerializerRegistry
        }
        all_features = {
            (entity_type, feature_name)
            for entity_type, feature_names in (
                ("STREAM", stream_features),
                ("USER", user_features),
            )
            for feature_name in feature_names
        }
        invalid_features = all_features - valid_features
        if invalid_features:
            raise exceptions.InvalidFeaturesException(
                f"Received invalid features: {invalid_features}"
            )

        newrelic.agent.add_custom_attribute(
            "total_stream_features", len(stream_features)
        )
        newrelic.agent.add_custom_attribute(
            "total_user_features", len(user_features)
        )
        logger.info("Model loaded successfully.")
    except exceptions.InvalidFeaturesException as e:
        logger.error("%s", e)
        # Bare raise keeps the original traceback intact.
        raise
    except Exception as e:
        # Best effort: the failure is logged, not re-raised. The traceback is
        # included so the swallowed error can still be diagnosed.
        logger.critical("Failed to load model: %s", e, exc_info=True)
|
|
29
89
|
|
|
30
90
|
|
|
31
91
|
def create_app(
|
|
@@ -39,40 +99,51 @@ def create_app(
|
|
|
39
99
|
"""
|
|
40
100
|
cfg = settings or Settings()
|
|
41
101
|
|
|
42
|
-
app = FastAPI(
|
|
43
|
-
title="ML Stream Scorer",
|
|
44
|
-
description="Scores video streams using a pre-trained ML model and DynamoDB features.",
|
|
45
|
-
version="1.0.0",
|
|
46
|
-
)
|
|
47
|
-
|
|
48
102
|
# Mutable state: cache + model
|
|
49
|
-
|
|
103
|
+
stream_features_cache = make_features_cache(cfg.stream_cache_maxsize)
|
|
104
|
+
user_features_cache = make_features_cache(cfg.user_cache_maxsize)
|
|
105
|
+
aws_session = aiobotocore.session.get_session()
|
|
50
106
|
state: Dict[str, Any] = {
|
|
51
107
|
"model": preloaded_model,
|
|
52
|
-
"session":
|
|
108
|
+
"session": aws_session,
|
|
53
109
|
"model_name": (
|
|
54
110
|
os.path.basename(cfg.s3_model_path) if cfg.s3_model_path else None
|
|
55
111
|
),
|
|
56
112
|
}
|
|
57
113
|
|
|
58
|
-
@
|
|
59
|
-
async def
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
114
|
+
@asynccontextmanager
|
|
115
|
+
async def lifespan(app_server: FastAPI):
|
|
116
|
+
"""
|
|
117
|
+
Handles startup and shutdown logic.
|
|
118
|
+
Everything before 'yield' runs on startup.
|
|
119
|
+
Everything after 'yield' runs on shutdown.
|
|
120
|
+
"""
|
|
121
|
+
# 1. Load ML Model
|
|
122
|
+
if state["model"] is None:
|
|
123
|
+
await load_model(state, cfg)
|
|
124
|
+
async with AsyncExitStack() as stack:
|
|
125
|
+
# 2. Initialize DynamoDB Client (Persistent Pool)
|
|
126
|
+
session = state["session"]
|
|
127
|
+
state["dynamo_client"] = await stack.enter_async_context(
|
|
128
|
+
session.create_client(
|
|
129
|
+
"dynamodb",
|
|
130
|
+
# Ensure the pool is large enough for ML concurrency
|
|
131
|
+
config=AioConfig(max_pool_connections=MAX_POOL_CONNECTIONS),
|
|
132
|
+
)
|
|
133
|
+
)
|
|
134
|
+
logger.info("DynamoDB persistent client initialized.")
|
|
135
|
+
yield
|
|
63
136
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
137
|
+
# 3. Shutdown Logic
|
|
138
|
+
# The AsyncExitStack automatically closes the DynamoDB client pool here
|
|
139
|
+
logger.info("Shutting down: Connection pools closed.")
|
|
67
140
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
except Exception as e:
|
|
75
|
-
logger.critical("Failed to load model: %s", e)
|
|
141
|
+
app = FastAPI(
|
|
142
|
+
title="ML Stream Scorer",
|
|
143
|
+
description="Scores video streams using a pre-trained ML model and DynamoDB features.",
|
|
144
|
+
version="1.0.0",
|
|
145
|
+
lifespan=lifespan,
|
|
146
|
+
)
|
|
76
147
|
|
|
77
148
|
@app.get("/health", status_code=HTTPStatus.OK)
|
|
78
149
|
async def health():
|
|
@@ -85,7 +156,8 @@ def create_app(
|
|
|
85
156
|
return {
|
|
86
157
|
"status": "ok",
|
|
87
158
|
"model_loaded": True,
|
|
88
|
-
"
|
|
159
|
+
"stream_cache_size": len(stream_features_cache),
|
|
160
|
+
"user_cache_size": len(user_features_cache),
|
|
89
161
|
"model_name": state.get("model_name"),
|
|
90
162
|
"stream_features": state.get("stream_features", []),
|
|
91
163
|
}
|
|
@@ -120,8 +192,11 @@ def create_app(
|
|
|
120
192
|
# Feature fetch (optional based on model)
|
|
121
193
|
model = state["model"]
|
|
122
194
|
stream_features = model.get("stream_features", []) or []
|
|
195
|
+
user_features = model.get("user_features", []) or []
|
|
123
196
|
retrieval_meta = FeatureRetrievalMeta(
|
|
124
197
|
cache_misses=0,
|
|
198
|
+
stream_cache_misses=0,
|
|
199
|
+
user_cache_misses=0,
|
|
125
200
|
retrieval_ms=0,
|
|
126
201
|
success=True,
|
|
127
202
|
cache_delay_minutes=0,
|
|
@@ -129,21 +204,40 @@ def create_app(
|
|
|
129
204
|
parsing_ms=0,
|
|
130
205
|
)
|
|
131
206
|
if stream_features:
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
207
|
+
try:
|
|
208
|
+
retrieval_meta = await set_all_features(
|
|
209
|
+
dynamo_client=state["dynamo_client"],
|
|
210
|
+
user=user,
|
|
211
|
+
streams=streams,
|
|
212
|
+
stream_features=stream_features,
|
|
213
|
+
user_features=user_features,
|
|
214
|
+
stream_features_cache=stream_features_cache,
|
|
215
|
+
user_features_cache=user_features_cache,
|
|
216
|
+
features_table=cfg.features_table,
|
|
217
|
+
cache_sep=cfg.cache_separator,
|
|
218
|
+
)
|
|
219
|
+
except exceptions.InvalidFeaturesException as e:
|
|
220
|
+
logger.error(
|
|
221
|
+
"The following features are not present in the SerializerRegistry %s",
|
|
222
|
+
e,
|
|
223
|
+
)
|
|
224
|
+
raise HTTPException(
|
|
225
|
+
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
|
226
|
+
detail=f"Received invalid features from feature store: {e}",
|
|
227
|
+
) from e
|
|
141
228
|
|
|
142
229
|
random_number = random.random()
|
|
143
230
|
userid = user.get("userid", "")
|
|
144
231
|
# Sampling logs
|
|
145
232
|
if random_number < cfg.logs_fraction:
|
|
146
|
-
logger.info(
|
|
233
|
+
logger.info(
|
|
234
|
+
"User %s streams: %s",
|
|
235
|
+
user.get("userid", ""),
|
|
236
|
+
[
|
|
237
|
+
StreamLoggerProxy(s, stream_features + user_features)
|
|
238
|
+
for s in streams
|
|
239
|
+
],
|
|
240
|
+
)
|
|
147
241
|
|
|
148
242
|
# Synchronous model execution (user code)
|
|
149
243
|
try:
|
|
@@ -168,8 +262,9 @@ def create_app(
|
|
|
168
262
|
newrelic.agent.record_custom_event(
|
|
169
263
|
"Inference",
|
|
170
264
|
{
|
|
171
|
-
"app_name": APP_NAME,
|
|
172
265
|
"cache_misses": retrieval_meta.cache_misses,
|
|
266
|
+
"user_cache_misses": retrieval_meta.user_cache_misses,
|
|
267
|
+
"stream_cache_misses": retrieval_meta.stream_cache_misses,
|
|
173
268
|
"retrieval_success": int(retrieval_meta.success),
|
|
174
269
|
"cache_delay_minutes": retrieval_meta.cache_delay_minutes,
|
|
175
270
|
"dynamo_ms": retrieval_meta.dynamo_ms,
|