haystack-ml-stack 0.2.5__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26) hide show
  1. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/PKG-INFO +10 -8
  2. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/pyproject.toml +9 -4
  3. haystack_ml_stack-0.3.0/src/haystack_ml_stack/__init__.py +14 -0
  4. haystack_ml_stack-0.3.0/src/haystack_ml_stack/_serializers.py +368 -0
  5. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/app.py +103 -38
  6. haystack_ml_stack-0.3.0/src/haystack_ml_stack/dynamo.py +326 -0
  7. haystack_ml_stack-0.3.0/src/haystack_ml_stack/exceptions.py +5 -0
  8. haystack_ml_stack-0.3.0/src/haystack_ml_stack/generated/__init__.py +0 -0
  9. haystack_ml_stack-0.3.0/src/haystack_ml_stack/generated/v1/__init__.py +0 -0
  10. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/settings.py +2 -1
  11. haystack_ml_stack-0.3.0/src/haystack_ml_stack/utils.py +675 -0
  12. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/PKG-INFO +10 -8
  13. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/SOURCES.txt +5 -0
  14. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/requires.txt +3 -0
  15. haystack_ml_stack-0.3.0/tests/test_serializers.py +152 -0
  16. haystack_ml_stack-0.3.0/tests/test_utils.py +510 -0
  17. haystack_ml_stack-0.2.5/src/haystack_ml_stack/__init__.py +0 -4
  18. haystack_ml_stack-0.2.5/src/haystack_ml_stack/dynamo.py +0 -194
  19. haystack_ml_stack-0.2.5/src/haystack_ml_stack/utils.py +0 -312
  20. haystack_ml_stack-0.2.5/tests/test_utils.py +0 -76
  21. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/README.md +0 -0
  22. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/setup.cfg +0 -0
  23. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/cache.py +0 -0
  24. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/model_store.py +0 -0
  25. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/dependency_links.txt +0 -0
  26. {haystack_ml_stack-0.2.5 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/top_level.txt +0 -0
@@ -1,18 +1,20 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: haystack-ml-stack
3
- Version: 0.2.5
3
+ Version: 0.3.0
4
4
  Summary: Functions related to Haystack ML
5
5
  Author-email: Oscar Vega <oscar@haystack.tv>
6
6
  License: MIT
7
7
  Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
- Requires-Dist: pydantic==2.5.0
10
- Requires-Dist: cachetools==5.5.2
11
- Requires-Dist: cloudpickle==2.2.1
12
- Requires-Dist: aioboto3==12.0.0
13
- Requires-Dist: fastapi==0.104.1
14
- Requires-Dist: pydantic-settings==2.2
15
- Requires-Dist: newrelic==11.1.0
9
+ Requires-Dist: protobuf==6.33.2
10
+ Provides-Extra: server
11
+ Requires-Dist: pydantic==2.5.0; extra == "server"
12
+ Requires-Dist: cachetools==5.5.2; extra == "server"
13
+ Requires-Dist: cloudpickle==2.2.1; extra == "server"
14
+ Requires-Dist: aioboto3==12.0.0; extra == "server"
15
+ Requires-Dist: fastapi==0.104.1; extra == "server"
16
+ Requires-Dist: pydantic-settings==2.2; extra == "server"
17
+ Requires-Dist: newrelic==11.1.0; extra == "server"
16
18
 
17
19
  # Haystack ML Stack
18
20
 
@@ -5,18 +5,23 @@ build-backend = "setuptools.build_meta"
5
5
 
6
6
  [project]
7
7
  name = "haystack-ml-stack"
8
- version = "0.2.5"
8
+ version = "0.3.0"
9
9
  description = "Functions related to Haystack ML"
10
10
  readme = "README.md"
11
11
  authors = [{ name = "Oscar Vega", email = "oscar@haystack.tv" }]
12
12
  requires-python = ">=3.11"
13
13
  dependencies = [
14
+ "protobuf==6.33.2",
15
+ ]
16
+ license = { text = "MIT" }
17
+
18
+ [project.optional-dependencies]
19
+ server = [
14
20
  "pydantic==2.5.0",
15
21
  "cachetools==5.5.2",
16
22
  "cloudpickle==2.2.1",
17
23
  "aioboto3==12.0.0",
18
24
  "fastapi==0.104.1",
19
25
  "pydantic-settings==2.2",
20
- "newrelic==11.1.0"
21
- ]
22
- license = { text = "MIT" }
26
+ "newrelic==11.1.0",
27
+ ]
@@ -0,0 +1,14 @@
1
+ __all__ = []
2
+
3
+ try:
4
+ from .app import create_app
5
+
6
+ __all__ = ["create_app"]
7
+ except ImportError:
8
+ pass
9
+
10
+ from ._serializers import SerializerRegistry, FeatureRegistryId
11
+
12
+ __all__ = [*__all__, "SerializerRegistry", "FeatureRegistryId"]
13
+
14
+ __version__ = "0.3.0"
@@ -0,0 +1,368 @@
1
+ from .generated.v1 import features_pb2 as features_pb2_v1
2
+ from google.protobuf.message import Message
3
+ from google.protobuf.json_format import ParseDict as ProtoParseDict
4
+ import typing as _t
5
+ from abc import ABC, abstractmethod
6
+
7
+ MessageType = _t.TypeVar("MessageType", bound=Message)
8
+
9
+
10
+ class Serializer(ABC):
11
+ @abstractmethod
12
+ def serialize(self, value) -> bytes: ...
13
+
14
+ @abstractmethod
15
+ def deserialize(self, value: bytes) -> _t.Any: ...
16
+
17
+
18
+ class SimpleSerializer(Serializer, _t.Generic[MessageType]):
19
+ """This simple serializer uses the function `ParseDict` provided by Google
20
+ to parse dictionaries. While it allows for simple code, it's very slow to run.
21
+ This class should be used directly for PoCs only, production serializers should have
22
+ custom implementations where fields are set directly. Early tests show that
23
+ manual serialization can provide 10x speedup.
24
+
25
+ Deserialization is fine since it deserializes from the binary into the message
26
+ itself; it doesn't need to create a dictionary."""
27
+
28
+ def __init__(self, msg_class: type[MessageType]):
29
+ self.msg_class = msg_class
30
+ return
31
+
32
+ def serialize(self, value) -> bytes:
33
+ msg = self.msg_class()
34
+ return ProtoParseDict(value, message=msg).SerializeToString()
35
+
36
+ def deserialize(self, value) -> MessageType:
37
+ msg: Message = self.msg_class()
38
+ msg.ParseFromString(value)
39
+ return msg
40
+
41
+
42
+ class StreamPWatchedSerializerV1(SimpleSerializer):
43
+ def __init__(self):
44
+ super().__init__(msg_class=features_pb2_v1.StreamPWatched)
45
+
46
+ def serialize(self, value):
47
+ root_msg = self.build_msg(value)
48
+ return root_msg.SerializeToString()
49
+
50
+ def build_msg(self, value) -> features_pb2_v1.StreamPWatched:
51
+ message = self.msg_class()
52
+ assert value["version"] == 1, "Wrong version given!"
53
+ message.version = value["version"]
54
+ for entry_context, counts in value["data"].items():
55
+ entry_context_msg: features_pb2_v1.EntryContextCounts = getattr(
56
+ message.data, entry_context
57
+ )
58
+ entry_context_msg.attempts = int(counts["attempts"])
59
+ entry_context_msg.watched = int(counts["watched"])
60
+ return message
61
+
62
+
63
+ UserPWatchedSerializerV1 = StreamPWatchedSerializerV1
64
+
65
+
66
+ class StreamPWatchedSerializerV0(Serializer):
67
+ serializer_v1 = StreamPWatchedSerializerV1()
68
+
69
+ def serialize(self, value) -> bytes:
70
+ raise NotImplementedError(
71
+ "This serializer should never be used for serialization!"
72
+ )
73
+
74
+ def deserialize(self, value) -> features_pb2_v1.StreamPWatched:
75
+ value = {
76
+ "data": {
77
+ entry_context.replace(" ", "_"): counts
78
+ for entry_context, counts in value.items()
79
+ },
80
+ "version": 1,
81
+ }
82
+ return self.serializer_v1.build_msg(value)
83
+
84
+
85
+ class StreamPSelectSerializerV1(SimpleSerializer):
86
+ def __init__(self):
87
+ super().__init__(msg_class=features_pb2_v1.StreamPSelect)
88
+ return
89
+
90
+ def serialize(self, value) -> bytes:
91
+ root_msg = self.build_msg(value)
92
+ return root_msg.SerializeToString()
93
+
94
+ def build_msg(self, value) -> features_pb2_v1.StreamPSelect:
95
+ message: features_pb2_v1.StreamPSelect = self.msg_class()
96
+ assert value["version"] == 1, "Wrong version given!"
97
+ message.version = 1
98
+ data = value["data"]
99
+ for (
100
+ browsed_debias_key,
101
+ position_pselects,
102
+ ) in data.items():
103
+ position_pselects_msg: features_pb2_v1.PositionPSelect = getattr(
104
+ message.data, browsed_debias_key
105
+ )
106
+ for position, select_counts in position_pselects.items():
107
+ select_counts_msg = getattr(position_pselects_msg, position)
108
+ select_counts_msg.total_selects = int(select_counts["total_selects"])
109
+ select_counts_msg.total_browsed = int(select_counts["total_browsed"])
110
+ select_counts_msg.total_selects_and_watched = int(
111
+ select_counts["total_selects_and_watched"]
112
+ )
113
+ return message
114
+
115
+
116
+ UserPSelectSerializerV1 = StreamPSelectSerializerV1
117
+
118
+
119
+ class StreamPSelectSerializerV0(Serializer):
120
+ serializer_v1 = StreamPSelectSerializerV1()
121
+
122
+ def serialize(self, value) -> bytes:
123
+ raise NotImplementedError(
124
+ "This serializer should never be used for serialization!"
125
+ )
126
+
127
+ def deserialize(self, value):
128
+ key_mapping = {
129
+ "0": "first_pos",
130
+ "1": "second_pos",
131
+ "2": "third_pos",
132
+ "3+": "rest_pos",
133
+ }
134
+ for browsed_debiasing in value.keys():
135
+ for old_key, new_key in key_mapping.items():
136
+ if old_key not in value[browsed_debiasing]:
137
+ continue
138
+ value[browsed_debiasing][new_key] = value[browsed_debiasing].pop(
139
+ old_key
140
+ )
141
+ out = {
142
+ "data": {
143
+ "up_to_4_browsed": value["4_browsed"],
144
+ "all_browsed": value["all_browsed"],
145
+ },
146
+ "version": 1,
147
+ }
148
+ msg = self.serializer_v1.build_msg(value=out)
149
+ return msg
150
+
151
+
152
+ class StreamSimilaritySerializerV1(SimpleSerializer):
153
+ def __init__(self):
154
+ super().__init__(msg_class=features_pb2_v1.StreamSimilarityScores)
155
+
156
+ def serialize(self, value):
157
+ msg = self.build_msg(value)
158
+ return msg.SerializeToString()
159
+
160
+ def build_msg(self, value) -> features_pb2_v1.StreamSimilarityScores:
161
+ message = self.msg_class()
162
+ assert value["version"] == 1, "Wrong version given!"
163
+ message.version = value["version"]
164
+ for key, score in value["data"].items():
165
+ message.data[key] = score
166
+ return message
167
+
168
+
169
+ class StreamSimilaritySerializerV0(Serializer):
170
+ serializer_v1 = StreamSimilaritySerializerV1()
171
+
172
+ def serialize(self, value):
173
+ raise NotImplementedError(
174
+ "This serializer should never be used for serialization!"
175
+ )
176
+
177
+ def deserialize(self, value):
178
+ value = {"data": value, "version": 1}
179
+ msg = self.serializer_v1.build_msg(value)
180
+ return msg
181
+
182
+
183
+ class UserPersonalizingPWatchedSerializerV1(SimpleSerializer):
184
+ def __init__(self):
185
+ super().__init__(msg_class=features_pb2_v1.UserPersonalizingPWatched)
186
+
187
+ def serialize(self, value: dict) -> bytes:
188
+ root_msg = self.build_msg(value)
189
+ return root_msg.SerializeToString()
190
+
191
+ def build_msg(self, value) -> features_pb2_v1.UserPersonalizingPWatched:
192
+ root_msg = features_pb2_v1.UserPersonalizingPWatched()
193
+ assert value["version"] == 1, "Wrong version given!"
194
+ root_msg.version = value["version"]
195
+ data = value["data"]
196
+ for personalizing_key, entry_context_pwatched in data.items():
197
+ personalizing_msg = root_msg.data[personalizing_key]
198
+ for entry_context, counts in entry_context_pwatched.items():
199
+ entry_context_msg = getattr(personalizing_msg, entry_context)
200
+ entry_context_msg.attempts = int(counts["attempts"])
201
+ entry_context_msg.watched = int(counts["watched"])
202
+ return root_msg
203
+
204
+
205
+ class UserPersonalizingPSelectSerializerV1(SimpleSerializer):
206
+ def __init__(self):
207
+ super().__init__(msg_class=features_pb2_v1.UserPersonalizingPSelect)
208
+
209
+ def serialize(self, value):
210
+ root_msg = features_pb2_v1.UserPersonalizingPSelect()
211
+ root_msg.version = value["version"]
212
+ data = value["data"]
213
+ for personalizing_key, browsed_debiased_pselecs in data.items():
214
+ personalizing_msg = root_msg.data[personalizing_key]
215
+ for (
216
+ browsed_debias_key,
217
+ position_pselects,
218
+ ) in browsed_debiased_pselecs.items():
219
+ position_pselects_msg = getattr(personalizing_msg, browsed_debias_key)
220
+ for position, select_counts in position_pselects.items():
221
+ select_counts_msg = getattr(position_pselects_msg, position)
222
+ select_counts_msg.total_selects = int(
223
+ select_counts["total_selects"]
224
+ )
225
+ select_counts_msg.total_browsed = int(
226
+ select_counts["total_browsed"]
227
+ )
228
+ select_counts_msg.total_selects_and_watched = int(
229
+ select_counts["total_selects_and_watched"]
230
+ )
231
+ return root_msg.SerializeToString()
232
+
233
+
234
+ class PassThroughSerializer(Serializer):
235
+ def serialize(self, value):
236
+ return value
237
+
238
+ def deserialize(self, value):
239
+ return value
240
+
241
+
242
+ user_personalizing_pwatched_serializer_v1 = UserPersonalizingPWatchedSerializerV1()
243
+ user_pwatched_serializer_v1 = UserPWatchedSerializerV1()
244
+ user_personalizing_pselect_serializer_v1 = UserPersonalizingPSelectSerializerV1()
245
+ user_pselect_serializer_v1 = UserPSelectSerializerV1()
246
+ stream_pwatched_serializer_v0 = StreamPWatchedSerializerV0()
247
+ stream_pwatched_serializer_v1 = StreamPWatchedSerializerV1()
248
+ stream_pselect_serializer_v0 = StreamPSelectSerializerV0()
249
+ stream_pselect_serializer_v1 = StreamPSelectSerializerV1()
250
+ stream_similarity_scores_serializer_v0 = StreamSimilaritySerializerV0()
251
+ stream_similarity_scores_serializer_v1 = StreamSimilaritySerializerV1()
252
+
253
+
254
+ class FeatureRegistryId(_t.NamedTuple):
255
+ entity_type: _t.Literal["STREAM", "USER"]
256
+ feature_id: str
257
+ version: str
258
+
259
+
260
+ stream_pwatched_v0_features: list[FeatureRegistryId] = [
261
+ FeatureRegistryId(entity_type="STREAM", feature_id="PWATCHED#24H", version="v0"),
262
+ FeatureRegistryId(entity_type="STREAM", feature_id="PWATCHED#24H#TV", version="v0"),
263
+ FeatureRegistryId(
264
+ entity_type="STREAM", feature_id="PWATCHED#24H#MOBILE", version="v0"
265
+ ),
266
+ ]
267
+
268
+ stream_pwatched_v1_features: list[FeatureRegistryId] = [
269
+ FeatureRegistryId(entity_type="STREAM", feature_id="PWATCHED#24H", version="v1"),
270
+ FeatureRegistryId(entity_type="STREAM", feature_id="PWATCHED#24H#TV", version="v1"),
271
+ FeatureRegistryId(
272
+ entity_type="STREAM", feature_id="PWATCHED#24H#MOBILE", version="v1"
273
+ ),
274
+ ]
275
+
276
+ stream_pselect_v0_features: list[FeatureRegistryId] = [
277
+ FeatureRegistryId(entity_type="STREAM", feature_id="PSELECT#24H", version="v0"),
278
+ FeatureRegistryId(
279
+ entity_type="STREAM", feature_id="PSELECT#24H#MOBILE", version="v0"
280
+ ),
281
+ FeatureRegistryId(entity_type="STREAM", feature_id="PSELECT#24H#TV", version="v0"),
282
+ ]
283
+
284
+ stream_pselect_v1_features: list[FeatureRegistryId] = [
285
+ FeatureRegistryId(entity_type="STREAM", feature_id="PSELECT#24H", version="v1"),
286
+ FeatureRegistryId(
287
+ entity_type="STREAM", feature_id="PSELECT#24H#MOBILE", version="v1"
288
+ ),
289
+ FeatureRegistryId(entity_type="STREAM", feature_id="PSELECT#24H#TV", version="v1"),
290
+ ]
291
+
292
+ stream_similarity_v0_features: list[FeatureRegistryId] = [
293
+ FeatureRegistryId(entity_type="STREAM", feature_id="SIMILARITY", version="v0"),
294
+ FeatureRegistryId(
295
+ entity_type="STREAM", feature_id="SIMILARITY#WEATHER_ALERT", version="v0"
296
+ ),
297
+ ]
298
+
299
+ stream_similarity_v1_features: list[FeatureRegistryId] = [
300
+ FeatureRegistryId(
301
+ entity_type="STREAM", feature_id="SIMILARITY#GEMINI", version="v1"
302
+ ),
303
+ FeatureRegistryId(
304
+ entity_type="STREAM", feature_id="SIMILARITY#WEATHER_ALERT", version="v1"
305
+ ),
306
+ ]
307
+
308
+ user_personalizing_pwatched_v1_features: list[FeatureRegistryId] = [
309
+ FeatureRegistryId(
310
+ entity_type="USER", feature_id="PWATCHED#6M#CATEGORY", version="v1"
311
+ ),
312
+ FeatureRegistryId(
313
+ entity_type="USER",
314
+ feature_id="PWATCHED#6M#AUTHOR_SHOW",
315
+ version="v1",
316
+ ),
317
+ FeatureRegistryId(
318
+ entity_type="USER",
319
+ feature_id="PWATCHED#6M#GEMINI_CATEGORY",
320
+ version="v1",
321
+ ),
322
+ ]
323
+
324
+ user_personalizing_pselect_v1_features: list[FeatureRegistryId] = [
325
+ FeatureRegistryId(
326
+ entity_type="USER", feature_id="PSELECT#6M#CATEGORY", version="v1"
327
+ ),
328
+ FeatureRegistryId(
329
+ entity_type="USER", feature_id="PSELECT#6M#AUTHOR_SHOW", version="v1"
330
+ ),
331
+ FeatureRegistryId(
332
+ entity_type="USER", feature_id="PSELECT#6M#GEMINI_CATEGORY", version="v1"
333
+ ),
334
+ ]
335
+
336
+ user_bias_pwatched_v1_features: list[FeatureRegistryId] = [
337
+ FeatureRegistryId(entity_type="USER", feature_id="PWATCHED#6M", version="v1")
338
+ ]
339
+
340
+ user_bias_pselect_v1_features: list[FeatureRegistryId] = [
341
+ FeatureRegistryId(entity_type="USER", feature_id="PSELECT#6M", version="v1")
342
+ ]
343
+
344
+ features_serializer_tuples: list[tuple[list[FeatureRegistryId], Serializer]] = [
345
+ (stream_pwatched_v0_features, stream_pwatched_serializer_v0),
346
+ (stream_pwatched_v1_features, stream_pwatched_serializer_v1),
347
+ (stream_pselect_v0_features, stream_pselect_serializer_v0),
348
+ (stream_pselect_v1_features, stream_pselect_serializer_v1),
349
+ (stream_similarity_v0_features, stream_similarity_scores_serializer_v0),
350
+ (stream_similarity_v1_features, stream_similarity_scores_serializer_v1),
351
+ (
352
+ user_personalizing_pwatched_v1_features,
353
+ user_personalizing_pwatched_serializer_v1,
354
+ ),
355
+ (user_bias_pwatched_v1_features, user_pwatched_serializer_v1),
356
+ (user_personalizing_pselect_v1_features, user_personalizing_pselect_serializer_v1),
357
+ (user_bias_pselect_v1_features, user_pselect_serializer_v1),
358
+ ]
359
+
360
+ SerializerRegistry: dict[FeatureRegistryId, Serializer] = {
361
+ FeatureRegistryId(
362
+ entity_type="PASS_THROUGH", feature_id="PASS_THROUGH", version="v1"
363
+ ): PassThroughSerializer()
364
+ }
365
+
366
+ for feature_ids, serializer in features_serializer_tuples:
367
+ for feature_id in feature_ids:
368
+ SerializerRegistry[feature_id] = serializer
@@ -15,9 +15,12 @@ import newrelic.agent
15
15
 
16
16
 
17
17
  from .cache import make_features_cache
18
- from .dynamo import set_stream_features, FeatureRetrievalMeta
18
+ from .dynamo import set_all_features, FeatureRetrievalMeta
19
19
  from .model_store import download_and_load_model
20
20
  from .settings import Settings
21
+ from . import exceptions
22
+ from ._serializers import SerializerRegistry
23
+ from google.protobuf import text_format
21
24
 
22
25
  logging.basicConfig(
23
26
  level=logging.INFO,
@@ -30,6 +33,61 @@ logger = logging.getLogger(__name__)
30
33
  MAX_POOL_CONNECTIONS = int(os.environ.get("MAX_POOL_CONNECTIONS", 50))
31
34
 
32
35
 
36
+ class StreamLoggerProxy:
37
+ def __init__(self, stream, feature_ids):
38
+ self._stream = stream
39
+ self._feature_ids = feature_ids
40
+
41
+ def __repr__(self):
42
+ parts = []
43
+ for k, v in self._stream.items():
44
+ if k in self._feature_ids:
45
+ # Format only when needed for the log output
46
+ formatted_v = text_format.MessageToString(v, as_one_line=True)
47
+ parts.append(f"'{k}': '{formatted_v}'")
48
+ else:
49
+ parts.append(f"'{k}': {repr(v)}")
50
+ return "{" + ", ".join(parts) + "}"
51
+
52
+
53
+ async def load_model(state, cfg: Settings) -> None:
54
+ if not cfg.s3_model_path:
55
+ logger.critical("S3_MODEL_PATH not set; service will be unhealthy.")
56
+ else:
57
+ try:
58
+ # Pass the persistent session/client if needed
59
+ state["model"] = await download_and_load_model(
60
+ cfg.s3_model_path, aio_session=state["session"]
61
+ )
62
+ state["stream_features"] = state["model"].get("stream_features", [])
63
+ state["user_features"] = state["model"].get("user_features", [])
64
+ valid_features = set(
65
+ (entity_type, feature_id)
66
+ for entity_type, feature_id, _ in SerializerRegistry.keys()
67
+ )
68
+ all_features = set(
69
+ [("STREAM", feature_name) for feature_name in state["stream_features"]]
70
+ + [("USER", feature_name) for feature_name in state["user_features"]]
71
+ )
72
+ invalid_features = all_features.difference(valid_features)
73
+ if invalid_features:
74
+ raise exceptions.InvalidFeaturesException(
75
+ f"Received invalid features: {invalid_features}"
76
+ )
77
+ newrelic.agent.add_custom_attribute(
78
+ "total_stream_features", len(state["stream_features"])
79
+ )
80
+ newrelic.agent.add_custom_attribute(
81
+ "total_user_features", len(state["user_features"])
82
+ )
83
+ logger.info("Model loaded successfully.")
84
+ except exceptions.InvalidFeaturesException as e:
85
+ logger.error("%s", e)
86
+ raise e
87
+ except Exception as e:
88
+ logger.critical("Failed to load model: %s", e)
89
+
90
+
33
91
  def create_app(
34
92
  settings: Optional[Settings] = None,
35
93
  *,
@@ -42,10 +100,12 @@ def create_app(
42
100
  cfg = settings or Settings()
43
101
 
44
102
  # Mutable state: cache + model
45
- features_cache = make_features_cache(cfg.cache_maxsize)
103
+ stream_features_cache = make_features_cache(cfg.stream_cache_maxsize)
104
+ user_features_cache = make_features_cache(cfg.user_cache_maxsize)
105
+ aws_session = aiobotocore.session.get_session()
46
106
  state: Dict[str, Any] = {
47
107
  "model": preloaded_model,
48
- "session": aiobotocore.session.get_session(),
108
+ "session": aws_session,
49
109
  "model_name": (
50
110
  os.path.basename(cfg.s3_model_path) if cfg.s3_model_path else None
51
111
  ),
@@ -58,8 +118,11 @@ def create_app(
58
118
  Everything before 'yield' runs on startup.
59
119
  Everything after 'yield' runs on shutdown.
60
120
  """
121
+ # 1. Load ML Model
122
+ if state["model"] is None:
123
+ await load_model(state, cfg)
61
124
  async with AsyncExitStack() as stack:
62
- # 1. Initialize DynamoDB Client (Persistent Pool)
125
+ # 2. Initialize DynamoDB Client (Persistent Pool)
63
126
  session = state["session"]
64
127
  state["dynamo_client"] = await stack.enter_async_context(
65
128
  session.create_client(
@@ -69,29 +132,6 @@ def create_app(
69
132
  )
70
133
  )
71
134
  logger.info("DynamoDB persistent client initialized.")
72
-
73
- # 2. Load ML Model
74
- if state["model"] is None:
75
- if not cfg.s3_model_path:
76
- logger.critical("S3_MODEL_PATH not set; service will be unhealthy.")
77
- else:
78
- try:
79
- # Pass the persistent session/client if needed
80
- state["model"] = await download_and_load_model(
81
- cfg.s3_model_path, aio_session=state["session"]
82
- )
83
- state["stream_features"] = state["model"].get(
84
- "stream_features", []
85
- )
86
- state["user_features"] = state["model"].get("user_features", [])
87
-
88
- newrelic.agent.add_custom_attribute(
89
- "total_stream_features", len(state["stream_features"])
90
- )
91
- logger.info("Model loaded successfully.")
92
- except Exception as e:
93
- logger.critical("Failed to load model: %s", e)
94
-
95
135
  yield
96
136
 
97
137
  # 3. Shutdown Logic
@@ -116,7 +156,8 @@ def create_app(
116
156
  return {
117
157
  "status": "ok",
118
158
  "model_loaded": True,
119
- "cache_size": len(features_cache),
159
+ "stream_cache_size": len(stream_features_cache),
160
+ "user_cache_size": len(user_features_cache),
120
161
  "model_name": state.get("model_name"),
121
162
  "stream_features": state.get("stream_features", []),
122
163
  }
@@ -151,8 +192,11 @@ def create_app(
151
192
  # Feature fetch (optional based on model)
152
193
  model = state["model"]
153
194
  stream_features = model.get("stream_features", []) or []
195
+ user_features = model.get("user_features", []) or []
154
196
  retrieval_meta = FeatureRetrievalMeta(
155
197
  cache_misses=0,
198
+ stream_cache_misses=0,
199
+ user_cache_misses=0,
156
200
  retrieval_ms=0,
157
201
  success=True,
158
202
  cache_delay_minutes=0,
@@ -160,21 +204,40 @@ def create_app(
160
204
  parsing_ms=0,
161
205
  )
162
206
  if stream_features:
163
- retrieval_meta = await set_stream_features(
164
- dynamo_client=state["dynamo_client"],
165
- streams=streams,
166
- stream_features=stream_features,
167
- features_cache=features_cache,
168
- features_table=cfg.features_table,
169
- stream_pk_prefix=cfg.stream_pk_prefix,
170
- cache_sep=cfg.cache_separator,
171
- )
207
+ try:
208
+ retrieval_meta = await set_all_features(
209
+ dynamo_client=state["dynamo_client"],
210
+ user=user,
211
+ streams=streams,
212
+ stream_features=stream_features,
213
+ user_features=user_features,
214
+ stream_features_cache=stream_features_cache,
215
+ user_features_cache=user_features_cache,
216
+ features_table=cfg.features_table,
217
+ cache_sep=cfg.cache_separator,
218
+ )
219
+ except exceptions.InvalidFeaturesException as e:
220
+ logger.error(
221
+ "The following features are not present in the SerializerRegistry %s",
222
+ e,
223
+ )
224
+ raise HTTPException(
225
+ status_code=HTTPStatus.SERVICE_UNAVAILABLE,
226
+ detail=f"Received invalid features from feature store: {e}",
227
+ ) from e
172
228
 
173
229
  random_number = random.random()
174
230
  userid = user.get("userid", "")
175
231
  # Sampling logs
176
232
  if random_number < cfg.logs_fraction:
177
- logger.info("User %s streams: %s", user.get("userid", ""), streams)
233
+ logger.info(
234
+ "User %s streams: %s",
235
+ user.get("userid", ""),
236
+ [
237
+ StreamLoggerProxy(s, stream_features + user_features)
238
+ for s in streams
239
+ ],
240
+ )
178
241
 
179
242
  # Synchronous model execution (user code)
180
243
  try:
@@ -200,6 +263,8 @@ def create_app(
200
263
  "Inference",
201
264
  {
202
265
  "cache_misses": retrieval_meta.cache_misses,
266
+ "user_cache_misses": retrieval_meta.user_cache_misses,
267
+ "stream_cache_misses": retrieval_meta.stream_cache_misses,
203
268
  "retrieval_success": int(retrieval_meta.success),
204
269
  "cache_delay_minutes": retrieval_meta.cache_delay_minutes,
205
270
  "dynamo_ms": retrieval_meta.dynamo_ms,