haystack-ml-stack 0.2.4__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/PKG-INFO +10 -8
  2. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/pyproject.toml +9 -4
  3. haystack_ml_stack-0.3.0/src/haystack_ml_stack/__init__.py +14 -0
  4. haystack_ml_stack-0.3.0/src/haystack_ml_stack/_serializers.py +368 -0
  5. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/app.py +133 -38
  6. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/cache.py +2 -2
  7. haystack_ml_stack-0.3.0/src/haystack_ml_stack/dynamo.py +326 -0
  8. haystack_ml_stack-0.3.0/src/haystack_ml_stack/exceptions.py +5 -0
  9. haystack_ml_stack-0.3.0/src/haystack_ml_stack/generated/__init__.py +0 -0
  10. haystack_ml_stack-0.3.0/src/haystack_ml_stack/generated/v1/__init__.py +0 -0
  11. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/settings.py +2 -1
  12. haystack_ml_stack-0.3.0/src/haystack_ml_stack/utils.py +675 -0
  13. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/PKG-INFO +10 -8
  14. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/SOURCES.txt +5 -0
  15. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/requires.txt +3 -0
  16. haystack_ml_stack-0.3.0/tests/test_serializers.py +152 -0
  17. haystack_ml_stack-0.3.0/tests/test_utils.py +510 -0
  18. haystack_ml_stack-0.2.4/src/haystack_ml_stack/__init__.py +0 -4
  19. haystack_ml_stack-0.2.4/src/haystack_ml_stack/dynamo.py +0 -194
  20. haystack_ml_stack-0.2.4/src/haystack_ml_stack/utils.py +0 -312
  21. haystack_ml_stack-0.2.4/tests/test_utils.py +0 -76
  22. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/README.md +0 -0
  23. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/setup.cfg +0 -0
  24. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack/model_store.py +0 -0
  25. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/dependency_links.txt +0 -0
  26. {haystack_ml_stack-0.2.4 → haystack_ml_stack-0.3.0}/src/haystack_ml_stack.egg-info/top_level.txt +0 -0
@@ -1,18 +1,20 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: haystack-ml-stack
3
- Version: 0.2.4
3
+ Version: 0.3.0
4
4
  Summary: Functions related to Haystack ML
5
5
  Author-email: Oscar Vega <oscar@haystack.tv>
6
6
  License: MIT
7
7
  Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
- Requires-Dist: pydantic==2.5.0
10
- Requires-Dist: cachetools==5.5.2
11
- Requires-Dist: cloudpickle==2.2.1
12
- Requires-Dist: aioboto3==12.0.0
13
- Requires-Dist: fastapi==0.104.1
14
- Requires-Dist: pydantic-settings==2.2
15
- Requires-Dist: newrelic==11.1.0
9
+ Requires-Dist: protobuf==6.33.2
10
+ Provides-Extra: server
11
+ Requires-Dist: pydantic==2.5.0; extra == "server"
12
+ Requires-Dist: cachetools==5.5.2; extra == "server"
13
+ Requires-Dist: cloudpickle==2.2.1; extra == "server"
14
+ Requires-Dist: aioboto3==12.0.0; extra == "server"
15
+ Requires-Dist: fastapi==0.104.1; extra == "server"
16
+ Requires-Dist: pydantic-settings==2.2; extra == "server"
17
+ Requires-Dist: newrelic==11.1.0; extra == "server"
16
18
 
17
19
  # Haystack ML Stack
18
20
 
@@ -5,18 +5,23 @@ build-backend = "setuptools.build_meta"
5
5
 
6
6
  [project]
7
7
  name = "haystack-ml-stack"
8
- version = "0.2.4"
8
+ version = "0.3.0"
9
9
  description = "Functions related to Haystack ML"
10
10
  readme = "README.md"
11
11
  authors = [{ name = "Oscar Vega", email = "oscar@haystack.tv" }]
12
12
  requires-python = ">=3.11"
13
13
  dependencies = [
14
+ "protobuf==6.33.2",
15
+ ]
16
+ license = { text = "MIT" }
17
+
18
+ [project.optional-dependencies]
19
+ server = [
14
20
  "pydantic==2.5.0",
15
21
  "cachetools==5.5.2",
16
22
  "cloudpickle==2.2.1",
17
23
  "aioboto3==12.0.0",
18
24
  "fastapi==0.104.1",
19
25
  "pydantic-settings==2.2",
20
- "newrelic==11.1.0"
21
- ]
22
- license = { text = "MIT" }
26
+ "newrelic==11.1.0",
27
+ ]
@@ -0,0 +1,14 @@
1
"""Public API for ``haystack_ml_stack``.

``create_app`` depends on the optional ``server`` extra (fastapi, aioboto3,
...), so its import is attempted and silently skipped when those packages
are absent; the serializer registry is always exported.
"""

__all__ = []

try:
    from .app import create_app

    __all__.append("create_app")
except ImportError:
    # Server-only dependencies are not installed; core API remains usable.
    pass

from ._serializers import FeatureRegistryId, SerializerRegistry

__all__ += ["SerializerRegistry", "FeatureRegistryId"]

__version__ = "0.3.0"
@@ -0,0 +1,368 @@
1
+ from .generated.v1 import features_pb2 as features_pb2_v1
2
+ from google.protobuf.message import Message
3
+ from google.protobuf.json_format import ParseDict as ProtoParseDict
4
+ import typing as _t
5
+ from abc import ABC, abstractmethod
6
+
7
+ MessageType = _t.TypeVar("MessageType", bound=Message)
8
+
9
+
10
class Serializer(ABC):
    """Contract for converting feature payloads to and from bytes."""

    @abstractmethod
    def serialize(self, value) -> bytes:
        """Encode *value* into its binary wire representation."""
        ...

    @abstractmethod
    def deserialize(self, value: bytes) -> _t.Any:
        """Decode a binary payload back into a message/object."""
        ...
16
+
17
+
18
class SimpleSerializer(Serializer, _t.Generic[MessageType]):
    """Generic serializer built on google's ``ParseDict``.

    ``ParseDict`` keeps the code short but is very slow, so this class is
    intended for PoCs only; production serializers should subclass it and set
    proto fields directly (early tests showed ~10x speedup from manual
    serialization). Deserialization is unaffected: the binary payload is
    parsed straight into the message with no intermediate dictionary.
    """

    def __init__(self, msg_class: type[MessageType]):
        # Concrete protobuf message type produced/consumed by this serializer.
        self.msg_class = msg_class

    def serialize(self, value) -> bytes:
        """Convert a plain dict into ``msg_class`` bytes via ``ParseDict``."""
        return ProtoParseDict(value, message=self.msg_class()).SerializeToString()

    def deserialize(self, value) -> MessageType:
        """Parse a binary payload into a fresh ``msg_class`` instance."""
        message = self.msg_class()
        message.ParseFromString(value)
        return message
40
+
41
+
42
class StreamPWatchedSerializerV1(SimpleSerializer):
    """Hand-rolled v1 serializer for per-stream p(watched) counters."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.StreamPWatched)

    def serialize(self, value):
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.StreamPWatched:
        """Populate a StreamPWatched message field-by-field (fast path)."""
        msg = self.msg_class()
        assert value["version"] == 1, "Wrong version given!"
        msg.version = value["version"]
        for context_name, context_counts in value["data"].items():
            counts_msg: features_pb2_v1.EntryContextCounts = getattr(
                msg.data, context_name
            )
            counts_msg.attempts = int(context_counts["attempts"])
            counts_msg.watched = int(context_counts["watched"])
        return msg


# User-level biased p(watched) payloads share the exact same layout/proto.
UserPWatchedSerializerV1 = StreamPWatchedSerializerV1
64
+
65
+
66
class StreamPWatchedSerializerV0(Serializer):
    """Read-only adapter that upgrades legacy v0 dicts to v1 messages."""

    serializer_v1 = StreamPWatchedSerializerV1()

    def serialize(self, value) -> bytes:
        raise NotImplementedError(
            "This serializer should never be used for serialization!"
        )

    def deserialize(self, value) -> features_pb2_v1.StreamPWatched:
        # v0 stored entry-context keys with spaces; v1 proto field names use
        # underscores, so rewrite the keys before building the v1 envelope.
        upgraded = {
            "version": 1,
            "data": {
                context.replace(" ", "_"): counts
                for context, counts in value.items()
            },
        }
        return self.serializer_v1.build_msg(upgraded)
83
+
84
+
85
class StreamPSelectSerializerV1(SimpleSerializer):
    """Hand-rolled v1 serializer for per-stream p(select) counters."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.StreamPSelect)

    def serialize(self, value) -> bytes:
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.StreamPSelect:
        """Populate a StreamPSelect message field-by-field (fast path)."""
        msg: features_pb2_v1.StreamPSelect = self.msg_class()
        assert value["version"] == 1, "Wrong version given!"
        msg.version = 1
        for debias_key, per_position in value["data"].items():
            per_position_msg: features_pb2_v1.PositionPSelect = getattr(
                msg.data, debias_key
            )
            for position_key, counts in per_position.items():
                counts_msg = getattr(per_position_msg, position_key)
                counts_msg.total_selects = int(counts["total_selects"])
                counts_msg.total_browsed = int(counts["total_browsed"])
                counts_msg.total_selects_and_watched = int(
                    counts["total_selects_and_watched"]
                )
        return msg


# User-level biased p(select) payloads share the exact same layout/proto.
UserPSelectSerializerV1 = StreamPSelectSerializerV1
117
+
118
+
119
class StreamPSelectSerializerV0(Serializer):
    """Read-only adapter that upgrades legacy v0 p(select) dicts to v1.

    Fix over the original implementation: the key renaming is now done on
    copies. The original popped/renamed keys directly inside the caller's
    nested dicts, mutating the input argument as a side effect.
    """

    serializer_v1 = StreamPSelectSerializerV1()

    # Legacy position keys -> v1 proto field names.
    _POSITION_KEY_MAPPING = {
        "0": "first_pos",
        "1": "second_pos",
        "2": "third_pos",
        "3+": "rest_pos",
    }

    def serialize(self, value) -> bytes:
        raise NotImplementedError(
            "This serializer should never be used for serialization!"
        )

    def deserialize(self, value):
        """Build a v1 StreamPSelect message from a legacy v0 dict.

        Expects ``value`` to contain the "4_browsed" and "all_browsed"
        debiasing buckets; position keys not in the mapping are kept as-is.
        The input dict is never modified.
        """

        def _renamed(position_counts: dict) -> dict:
            # Copy with legacy position keys translated to proto field names.
            return {
                self._POSITION_KEY_MAPPING.get(position, position): counts
                for position, counts in position_counts.items()
            }

        out = {
            "data": {
                "up_to_4_browsed": _renamed(value["4_browsed"]),
                "all_browsed": _renamed(value["all_browsed"]),
            },
            "version": 1,
        }
        return self.serializer_v1.build_msg(value=out)
150
+
151
+
152
class StreamSimilaritySerializerV1(SimpleSerializer):
    """Hand-rolled v1 serializer for stream similarity score maps."""

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.StreamSimilarityScores)

    def serialize(self, value):
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.StreamSimilarityScores:
        """Copy the ``value["data"]`` key->score mapping into the proto map."""
        msg = self.msg_class()
        assert value["version"] == 1, "Wrong version given!"
        msg.version = value["version"]
        for stream_key, score in value["data"].items():
            msg.data[stream_key] = score
        return msg
167
+
168
+
169
class StreamSimilaritySerializerV0(Serializer):
    """Read-only adapter wrapping legacy v0 score maps into v1 messages."""

    serializer_v1 = StreamSimilaritySerializerV1()

    def serialize(self, value):
        raise NotImplementedError(
            "This serializer should never be used for serialization!"
        )

    def deserialize(self, value):
        # v0 payloads are the bare score map; wrap them in the v1 envelope.
        return self.serializer_v1.build_msg({"data": value, "version": 1})
181
+
182
+
183
class UserPersonalizingPWatchedSerializerV1(SimpleSerializer):
    """Hand-rolled v1 serializer for per-user personalizing p(watched) counters.

    The payload is a two-level dict: personalizing key -> entry context ->
    {"attempts": ..., "watched": ...}.
    """

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.UserPersonalizingPWatched)

    def serialize(self, value: dict) -> bytes:
        root_msg = self.build_msg(value)
        return root_msg.SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.UserPersonalizingPWatched:
        """Populate the message field-by-field (avoids the slow ParseDict path)."""
        # Consistency fix: instantiate via self.msg_class() like every sibling
        # serializer, instead of naming the proto class directly (the original
        # could silently diverge from the msg_class passed to __init__).
        root_msg = self.msg_class()
        assert value["version"] == 1, "Wrong version given!"
        root_msg.version = value["version"]
        data = value["data"]
        for personalizing_key, entry_context_pwatched in data.items():
            personalizing_msg = root_msg.data[personalizing_key]
            for entry_context, counts in entry_context_pwatched.items():
                entry_context_msg = getattr(personalizing_msg, entry_context)
                entry_context_msg.attempts = int(counts["attempts"])
                entry_context_msg.watched = int(counts["watched"])
        return root_msg
203
+
204
+
205
class UserPersonalizingPSelectSerializerV1(SimpleSerializer):
    """Hand-rolled v1 serializer for per-user personalizing p(select) counters.

    The payload is a three-level dict: personalizing key -> browsed-debias
    key -> position -> select counts.
    """

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.UserPersonalizingPSelect)

    def serialize(self, value):
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.UserPersonalizingPSelect:
        """Populate the message field-by-field.

        Consistency fixes versus the original: a ``build_msg`` helper is
        provided (every sibling serializer exposes one) and the version
        assertion performed by every other serializer is now applied here too.
        """
        root_msg = self.msg_class()
        assert value["version"] == 1, "Wrong version given!"
        root_msg.version = value["version"]
        data = value["data"]
        for personalizing_key, browsed_debiased_pselects in data.items():
            personalizing_msg = root_msg.data[personalizing_key]
            for debias_key, position_pselects in browsed_debiased_pselects.items():
                position_msg = getattr(personalizing_msg, debias_key)
                for position, select_counts in position_pselects.items():
                    counts_msg = getattr(position_msg, position)
                    counts_msg.total_selects = int(select_counts["total_selects"])
                    counts_msg.total_browsed = int(select_counts["total_browsed"])
                    counts_msg.total_selects_and_watched = int(
                        select_counts["total_selects_and_watched"]
                    )
        return root_msg
232
+
233
+
234
class PassThroughSerializer(Serializer):
    """Identity serializer: returns the payload untouched in both directions."""

    def serialize(self, value):
        return value

    def deserialize(self, value):
        return value
240
+
241
+
242
# Module-level singleton serializer instances. The registry below maps
# feature ids onto these shared objects; the serializers hold no per-call
# state (only their msg_class), so sharing a single instance is fine.
user_personalizing_pwatched_serializer_v1 = UserPersonalizingPWatchedSerializerV1()
user_pwatched_serializer_v1 = UserPWatchedSerializerV1()
user_personalizing_pselect_serializer_v1 = UserPersonalizingPSelectSerializerV1()
user_pselect_serializer_v1 = UserPSelectSerializerV1()
stream_pwatched_serializer_v0 = StreamPWatchedSerializerV0()
stream_pwatched_serializer_v1 = StreamPWatchedSerializerV1()
stream_pselect_serializer_v0 = StreamPSelectSerializerV0()
stream_pselect_serializer_v1 = StreamPSelectSerializerV1()
stream_similarity_scores_serializer_v0 = StreamSimilaritySerializerV0()
stream_similarity_scores_serializer_v1 = StreamSimilaritySerializerV1()
252
+
253
+
254
class FeatureRegistryId(_t.NamedTuple):
    """Key identifying one (entity, feature, version) entry in the registry."""

    # Fix: the original annotation was Literal["STREAM", "USER"], but this
    # module itself builds a SerializerRegistry key with
    # entity_type="PASS_THROUGH", so the annotation did not admit a value
    # actually in use.
    entity_type: _t.Literal["STREAM", "USER", "PASS_THROUGH"]
    feature_id: str
    version: str
258
+
259
+
260
# Feature ids grouped by (entity, payload schema, version). Each group is
# paired with exactly one serializer instance in features_serializer_tuples
# below; ids in the same group share the same wire format.
stream_pwatched_v0_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="STREAM", feature_id="PWATCHED#24H", version="v0"),
    FeatureRegistryId(entity_type="STREAM", feature_id="PWATCHED#24H#TV", version="v0"),
    FeatureRegistryId(
        entity_type="STREAM", feature_id="PWATCHED#24H#MOBILE", version="v0"
    ),
]

stream_pwatched_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="STREAM", feature_id="PWATCHED#24H", version="v1"),
    FeatureRegistryId(entity_type="STREAM", feature_id="PWATCHED#24H#TV", version="v1"),
    FeatureRegistryId(
        entity_type="STREAM", feature_id="PWATCHED#24H#MOBILE", version="v1"
    ),
]

stream_pselect_v0_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="STREAM", feature_id="PSELECT#24H", version="v0"),
    FeatureRegistryId(
        entity_type="STREAM", feature_id="PSELECT#24H#MOBILE", version="v0"
    ),
    FeatureRegistryId(entity_type="STREAM", feature_id="PSELECT#24H#TV", version="v0"),
]

stream_pselect_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="STREAM", feature_id="PSELECT#24H", version="v1"),
    FeatureRegistryId(
        entity_type="STREAM", feature_id="PSELECT#24H#MOBILE", version="v1"
    ),
    FeatureRegistryId(entity_type="STREAM", feature_id="PSELECT#24H#TV", version="v1"),
]

stream_similarity_v0_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="STREAM", feature_id="SIMILARITY", version="v0"),
    FeatureRegistryId(
        entity_type="STREAM", feature_id="SIMILARITY#WEATHER_ALERT", version="v0"
    ),
]

# NOTE(review): the v1 similarity list drops the bare "SIMILARITY" id in favor
# of "SIMILARITY#GEMINI" — presumably intentional; confirm against producers.
stream_similarity_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(
        entity_type="STREAM", feature_id="SIMILARITY#GEMINI", version="v1"
    ),
    FeatureRegistryId(
        entity_type="STREAM", feature_id="SIMILARITY#WEATHER_ALERT", version="v1"
    ),
]

user_personalizing_pwatched_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(
        entity_type="USER", feature_id="PWATCHED#6M#CATEGORY", version="v1"
    ),
    FeatureRegistryId(
        entity_type="USER",
        feature_id="PWATCHED#6M#AUTHOR_SHOW",
        version="v1",
    ),
    FeatureRegistryId(
        entity_type="USER",
        feature_id="PWATCHED#6M#GEMINI_CATEGORY",
        version="v1",
    ),
]

user_personalizing_pselect_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(
        entity_type="USER", feature_id="PSELECT#6M#CATEGORY", version="v1"
    ),
    FeatureRegistryId(
        entity_type="USER", feature_id="PSELECT#6M#AUTHOR_SHOW", version="v1"
    ),
    FeatureRegistryId(
        entity_type="USER", feature_id="PSELECT#6M#GEMINI_CATEGORY", version="v1"
    ),
]

user_bias_pwatched_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="USER", feature_id="PWATCHED#6M", version="v1")
]

user_bias_pselect_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(entity_type="USER", feature_id="PSELECT#6M", version="v1")
]
343
+
344
# Pairs each feature-id group above with the serializer handling its payloads.
features_serializer_tuples: list[tuple[list[FeatureRegistryId], Serializer]] = [
    (stream_pwatched_v0_features, stream_pwatched_serializer_v0),
    (stream_pwatched_v1_features, stream_pwatched_serializer_v1),
    (stream_pselect_v0_features, stream_pselect_serializer_v0),
    (stream_pselect_v1_features, stream_pselect_serializer_v1),
    (stream_similarity_v0_features, stream_similarity_scores_serializer_v0),
    (stream_similarity_v1_features, stream_similarity_scores_serializer_v1),
    (
        user_personalizing_pwatched_v1_features,
        user_personalizing_pwatched_serializer_v1,
    ),
    (user_bias_pwatched_v1_features, user_pwatched_serializer_v1),
    (user_personalizing_pselect_v1_features, user_personalizing_pselect_serializer_v1),
    (user_bias_pselect_v1_features, user_pselect_serializer_v1),
]

# Lookup table consumed elsewhere in the package (see app.load_model):
# FeatureRegistryId -> shared serializer instance. Seeded with a PASS_THROUGH
# entry, presumably for payloads stored raw — confirm against callers.
SerializerRegistry: dict[FeatureRegistryId, Serializer] = {
    FeatureRegistryId(
        entity_type="PASS_THROUGH", feature_id="PASS_THROUGH", version="v1"
    ): PassThroughSerializer()
}

# Flatten the (group, serializer) pairs into the registry; later groups would
# overwrite earlier ones on duplicate ids.
for feature_ids, serializer in features_serializer_tuples:
    for feature_id in feature_ids:
        SerializerRegistry[feature_id] = serializer
@@ -5,17 +5,22 @@ import sys
5
5
  from http import HTTPStatus
6
6
  from typing import Any, Dict, List, Optional
7
7
  import time
8
+ from contextlib import asynccontextmanager, AsyncExitStack
8
9
 
9
10
  import aiobotocore.session
11
+ from aiobotocore.config import AioConfig
10
12
  from fastapi import FastAPI, HTTPException, Request, Response
11
13
  from fastapi.encoders import jsonable_encoder
12
14
  import newrelic.agent
13
15
 
14
16
 
15
17
  from .cache import make_features_cache
16
- from .dynamo import set_stream_features, FeatureRetrievalMeta
18
+ from .dynamo import set_all_features, FeatureRetrievalMeta
17
19
  from .model_store import download_and_load_model
18
20
  from .settings import Settings
21
+ from . import exceptions
22
+ from ._serializers import SerializerRegistry
23
+ from google.protobuf import text_format
19
24
 
20
25
  logging.basicConfig(
21
26
  level=logging.INFO,
@@ -25,7 +30,62 @@ logging.basicConfig(
25
30
  )
26
31
 
27
32
  logger = logging.getLogger(__name__)
28
- APP_NAME = os.environ.get("NEW_RELIC_APP_NAME", None)
33
+ MAX_POOL_CONNECTIONS = int(os.environ.get("MAX_POOL_CONNECTIONS", 50))
34
+
35
+
36
class StreamLoggerProxy:
    """Lazy repr wrapper for a stream dict containing protobuf feature values.

    Protobuf messages are rendered to one-line text only when the proxy is
    actually formatted into a log record, so skipped (sampled-out) log calls
    pay no formatting cost.
    """

    def __init__(self, stream, feature_ids):
        self._stream = stream
        self._feature_ids = feature_ids

    def _render(self, key, value):
        # Feature entries hold proto messages -> compact one-line text format;
        # every other entry falls back to its normal repr.
        if key in self._feature_ids:
            one_line = text_format.MessageToString(value, as_one_line=True)
            return f"'{key}': '{one_line}'"
        return f"'{key}': {repr(value)}"

    def __repr__(self):
        rendered = (self._render(k, v) for k, v in self._stream.items())
        return "{" + ", ".join(rendered) + "}"
51
+
52
+
53
async def load_model(state, cfg: Settings) -> None:
    """Download the model, cache its feature lists in ``state``, and validate
    that every declared feature has a serializer in SerializerRegistry.

    Raises InvalidFeaturesException when the model declares unknown features;
    any other load failure is logged as critical and swallowed so the service
    stays up (but unhealthy).
    """
    if not cfg.s3_model_path:
        logger.critical("S3_MODEL_PATH not set; service will be unhealthy.")
        return
    try:
        # Reuse the persistent aiobotocore session held in app state.
        model = await download_and_load_model(
            cfg.s3_model_path, aio_session=state["session"]
        )
        state["model"] = model
        state["stream_features"] = model.get("stream_features", [])
        state["user_features"] = model.get("user_features", [])

        # (entity_type, feature_id) pairs the registry can serialize.
        known = {
            (entity_type, feature_id)
            for entity_type, feature_id, _ in SerializerRegistry
        }
        # Pairs the freshly loaded model declares it needs.
        declared = {("STREAM", name) for name in state["stream_features"]}
        declared |= {("USER", name) for name in state["user_features"]}

        invalid_features = declared - known
        if invalid_features:
            raise exceptions.InvalidFeaturesException(
                f"Received invalid features: {invalid_features}"
            )
        newrelic.agent.add_custom_attribute(
            "total_stream_features", len(state["stream_features"])
        )
        newrelic.agent.add_custom_attribute(
            "total_user_features", len(state["user_features"])
        )
        logger.info("Model loaded successfully.")
    except exceptions.InvalidFeaturesException as e:
        logger.error("%s", e)
        raise e
    except Exception as e:
        # Best-effort by design: keep the process alive, report unhealthy.
        logger.critical("Failed to load model: %s", e)
29
89
 
30
90
 
31
91
  def create_app(
@@ -39,40 +99,51 @@ def create_app(
39
99
  """
40
100
  cfg = settings or Settings()
41
101
 
42
- app = FastAPI(
43
- title="ML Stream Scorer",
44
- description="Scores video streams using a pre-trained ML model and DynamoDB features.",
45
- version="1.0.0",
46
- )
47
-
48
102
  # Mutable state: cache + model
49
- features_cache = make_features_cache(cfg.cache_maxsize)
103
+ stream_features_cache = make_features_cache(cfg.stream_cache_maxsize)
104
+ user_features_cache = make_features_cache(cfg.user_cache_maxsize)
105
+ aws_session = aiobotocore.session.get_session()
50
106
  state: Dict[str, Any] = {
51
107
  "model": preloaded_model,
52
- "session": aiobotocore.session.get_session(),
108
+ "session": aws_session,
53
109
  "model_name": (
54
110
  os.path.basename(cfg.s3_model_path) if cfg.s3_model_path else None
55
111
  ),
56
112
  }
57
113
 
58
- @app.on_event("startup")
59
- async def _startup() -> None:
60
- if state["model"] is not None:
61
- logger.info("Using preloaded model.")
62
- return
114
+ @asynccontextmanager
115
+ async def lifespan(app_server: FastAPI):
116
+ """
117
+ Handles startup and shutdown logic.
118
+ Everything before 'yield' runs on startup.
119
+ Everything after 'yield' runs on shutdown.
120
+ """
121
+ # 1. Load ML Model
122
+ if state["model"] is None:
123
+ await load_model(state, cfg)
124
+ async with AsyncExitStack() as stack:
125
+ # 2. Initialize DynamoDB Client (Persistent Pool)
126
+ session = state["session"]
127
+ state["dynamo_client"] = await stack.enter_async_context(
128
+ session.create_client(
129
+ "dynamodb",
130
+ # Ensure the pool is large enough for ML concurrency
131
+ config=AioConfig(max_pool_connections=MAX_POOL_CONNECTIONS),
132
+ )
133
+ )
134
+ logger.info("DynamoDB persistent client initialized.")
135
+ yield
63
136
 
64
- if not cfg.s3_model_path:
65
- logger.critical("S3_MODEL_PATH not set; service will be unhealthy.")
66
- return
137
+ # 3. Shutdown Logic
138
+ # The AsyncExitStack automatically closes the DynamoDB client pool here
139
+ logger.info("Shutting down: Connection pools closed.")
67
140
 
68
- try:
69
- state["model"] = await download_and_load_model(
70
- cfg.s3_model_path, aio_session=state["session"]
71
- )
72
- state["stream_features"] = state["model"].get("stream_features", [])
73
- logger.info("Model loaded on startup.")
74
- except Exception as e:
75
- logger.critical("Failed to load model: %s", e)
141
+ app = FastAPI(
142
+ title="ML Stream Scorer",
143
+ description="Scores video streams using a pre-trained ML model and DynamoDB features.",
144
+ version="1.0.0",
145
+ lifespan=lifespan,
146
+ )
76
147
 
77
148
  @app.get("/health", status_code=HTTPStatus.OK)
78
149
  async def health():
@@ -85,7 +156,8 @@ def create_app(
85
156
  return {
86
157
  "status": "ok",
87
158
  "model_loaded": True,
88
- "cache_size": len(features_cache),
159
+ "stream_cache_size": len(stream_features_cache),
160
+ "user_cache_size": len(user_features_cache),
89
161
  "model_name": state.get("model_name"),
90
162
  "stream_features": state.get("stream_features", []),
91
163
  }
@@ -120,8 +192,11 @@ def create_app(
120
192
  # Feature fetch (optional based on model)
121
193
  model = state["model"]
122
194
  stream_features = model.get("stream_features", []) or []
195
+ user_features = model.get("user_features", []) or []
123
196
  retrieval_meta = FeatureRetrievalMeta(
124
197
  cache_misses=0,
198
+ stream_cache_misses=0,
199
+ user_cache_misses=0,
125
200
  retrieval_ms=0,
126
201
  success=True,
127
202
  cache_delay_minutes=0,
@@ -129,21 +204,40 @@ def create_app(
129
204
  parsing_ms=0,
130
205
  )
131
206
  if stream_features:
132
- retrieval_meta = await set_stream_features(
133
- aio_session=state["session"],
134
- streams=streams,
135
- stream_features=stream_features,
136
- features_cache=features_cache,
137
- features_table=cfg.features_table,
138
- stream_pk_prefix=cfg.stream_pk_prefix,
139
- cache_sep=cfg.cache_separator,
140
- )
207
+ try:
208
+ retrieval_meta = await set_all_features(
209
+ dynamo_client=state["dynamo_client"],
210
+ user=user,
211
+ streams=streams,
212
+ stream_features=stream_features,
213
+ user_features=user_features,
214
+ stream_features_cache=stream_features_cache,
215
+ user_features_cache=user_features_cache,
216
+ features_table=cfg.features_table,
217
+ cache_sep=cfg.cache_separator,
218
+ )
219
+ except exceptions.InvalidFeaturesException as e:
220
+ logger.error(
221
+ "The following features are not present in the SerializerRegistry %s",
222
+ e,
223
+ )
224
+ raise HTTPException(
225
+ status_code=HTTPStatus.SERVICE_UNAVAILABLE,
226
+ detail=f"Received invalid features from feature store: {e}",
227
+ ) from e
141
228
 
142
229
  random_number = random.random()
143
230
  userid = user.get("userid", "")
144
231
  # Sampling logs
145
232
  if random_number < cfg.logs_fraction:
146
- logger.info("User %s streams: %s", user.get("userid", ""), streams)
233
+ logger.info(
234
+ "User %s streams: %s",
235
+ user.get("userid", ""),
236
+ [
237
+ StreamLoggerProxy(s, stream_features + user_features)
238
+ for s in streams
239
+ ],
240
+ )
147
241
 
148
242
  # Synchronous model execution (user code)
149
243
  try:
@@ -168,8 +262,9 @@ def create_app(
168
262
  newrelic.agent.record_custom_event(
169
263
  "Inference",
170
264
  {
171
- "app_name": APP_NAME,
172
265
  "cache_misses": retrieval_meta.cache_misses,
266
+ "user_cache_misses": retrieval_meta.user_cache_misses,
267
+ "stream_cache_misses": retrieval_meta.stream_cache_misses,
173
268
  "retrieval_success": int(retrieval_meta.success),
174
269
  "cache_delay_minutes": retrieval_meta.cache_delay_minutes,
175
270
  "dynamo_ms": retrieval_meta.dynamo_ms,