haystack-ml-stack 0.3.3__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25) hide show
  1. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/PKG-INFO +3 -1
  2. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/pyproject.toml +3 -2
  3. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/__init__.py +1 -1
  4. haystack_ml_stack-0.4.0/src/haystack_ml_stack/_kafka.py +88 -0
  5. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/app.py +66 -18
  6. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/settings.py +8 -2
  7. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/utils.py +1 -1
  8. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack.egg-info/PKG-INFO +3 -1
  9. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack.egg-info/SOURCES.txt +1 -0
  10. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack.egg-info/requires.txt +2 -0
  11. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/tests/test_utils.py +75 -0
  12. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/README.md +0 -0
  13. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/setup.cfg +0 -0
  14. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/_serializers.py +0 -0
  15. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/cache.py +0 -0
  16. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/dynamo.py +0 -0
  17. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/exceptions.py +0 -0
  18. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/generated/__init__.py +0 -0
  19. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/generated/v1/__init__.py +0 -0
  20. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/generated/v1/features_pb2.py +0 -0
  21. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/generated/v1/features_pb2.pyi +0 -0
  22. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack/model_store.py +0 -0
  23. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack.egg-info/dependency_links.txt +0 -0
  24. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/src/haystack_ml_stack.egg-info/top_level.txt +0 -0
  25. {haystack_ml_stack-0.3.3 → haystack_ml_stack-0.4.0}/tests/test_serializers.py +0 -0
@@ -1,12 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: haystack-ml-stack
3
- Version: 0.3.3
3
+ Version: 0.4.0
4
4
  Summary: Functions related to Haystack ML
5
5
  Author-email: Oscar Vega <oscar@haystack.tv>
6
6
  License: MIT
7
7
  Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  Requires-Dist: protobuf==6.33.2
10
+ Requires-Dist: orjson==3.11.7
10
11
  Provides-Extra: server
11
12
  Requires-Dist: pydantic==2.5.0; extra == "server"
12
13
  Requires-Dist: cachetools==5.5.2; extra == "server"
@@ -15,6 +16,7 @@ Requires-Dist: aioboto3==12.0.0; extra == "server"
15
16
  Requires-Dist: fastapi==0.104.1; extra == "server"
16
17
  Requires-Dist: pydantic-settings==2.2; extra == "server"
17
18
  Requires-Dist: newrelic==11.1.0; extra == "server"
19
+ Requires-Dist: confluent-kafka==2.13.0; extra == "server"
18
20
 
19
21
  # Haystack ML Stack
20
22
 
@@ -5,13 +5,13 @@ build-backend = "setuptools.build_meta"
5
5
 
6
6
  [project]
7
7
  name = "haystack-ml-stack"
8
- version = "0.3.3"
8
+ version = "0.4.0"
9
9
  description = "Functions related to Haystack ML"
10
10
  readme = "README.md"
11
11
  authors = [{ name = "Oscar Vega", email = "oscar@haystack.tv" }]
12
12
  requires-python = ">=3.11"
13
13
  dependencies = [
14
- "protobuf==6.33.2",
14
+ "protobuf==6.33.2", "orjson==3.11.7"
15
15
  ]
16
16
  license = { text = "MIT" }
17
17
 
@@ -24,4 +24,5 @@ server = [
24
24
  "fastapi==0.104.1",
25
25
  "pydantic-settings==2.2",
26
26
  "newrelic==11.1.0",
27
+ "confluent-kafka==2.13.0"
27
28
  ]
@@ -11,4 +11,4 @@ from ._serializers import SerializerRegistry, FeatureRegistryId
11
11
 
12
12
  __all__ = [*__all__, "SerializerRegistry", "FeatureRegistryId"]
13
13
 
14
- __version__ = "0.3.3"
14
+ __version__ = "0.4.0"
@@ -0,0 +1,88 @@
1
+ from confluent_kafka.aio import AIOProducer
2
+ import orjson
3
+ from google.protobuf.message import Message
4
+ import base64
5
+ import os
6
+ import logging
7
+ from .settings import Settings
8
+ from . import __version__
9
+ import hashlib
10
+
11
+ logger = logging.getLogger(__name__)
12
+ SECURITY_PROTOCOL = "SASL_SSL"
13
+ SASL_MECHANISM = "SCRAM-SHA-512"
14
+
15
+
16
async def send_to_kafka(
    producer: AIOProducer,
    topic: str,
    user: dict,
    streams: list[dict],
    playlist: dict,
    state: dict,
    model_output: dict,
    monitoring_meta: dict,
) -> None:
    """Serialize one inference request/response pair and publish it to Kafka.

    Silently no-ops when either ``topic`` or ``producer`` is missing, so the
    caller can schedule this unconditionally even when Kafka is disabled.

    Args:
        producer: Async Kafka producer, or None when Kafka is disabled.
        topic: Destination topic name, or None when Kafka is disabled.
        user: User feature payload; only ``userid`` is read directly.
        streams: Candidate stream feature payloads (forwarded verbatim).
        playlist: Request playlist payload; ``clientOs``/``category`` are read.
        state: App state dict; ``model_name`` and the feature lists are read.
        model_output: Model scores to log alongside the input.
        monitoring_meta: Timing/cache metrics gathered by the scoring endpoint.
    """
    if topic is None or producer is None:
        return
    # .get() instead of state["model_name"]: a missing key should not make the
    # background logging task raise KeyError (and it avoids the double lookup).
    model_name = state.get("model_name")
    message = {
        "userid": user.get("userid"),
        "client_os": playlist.get("clientOs"),
        "model_input": {"user": user, "streams": streams, "playlist": playlist},
        "model_output": model_output,
        "model_name": model_name.replace(".pkl", "") if model_name else None,
        "model_type": "streams",
        "meta": {
            "monitoring": monitoring_meta,
            "haystack_ml_stack_version": __version__,
            "playlist_category": playlist.get("category"),
            "user_features": state.get("user_features", []),
            "stream_features": state.get("stream_features", []),
        },
    }
    # produce() returns an awaitable delivery future; awaiting it confirms
    # broker acknowledgement before the task finishes.
    delivery_future = await producer.produce(
        topic, orjson.dumps(message, default=default_serialization)
    )
    await delivery_future
50
+
51
+
52
def default_serialization(obj):
    """orjson ``default`` hook: encode protobuf messages as base64 blobs.

    Args:
        obj: Any value orjson cannot serialize natively.

    Returns:
        dict: ``{"version", "proto"}`` for protobuf ``Message`` instances,
        where ``proto`` is the base64-encoded wire format.

    Raises:
        TypeError: for any other type — the documented contract for an orjson
            ``default`` callback (orjson surfaces it as ``JSONEncodeError``).
    """
    if isinstance(obj, Message):
        return {
            # NOTE(review): assumes every logged message type exposes a
            # `version` field — confirm against the feature protos.
            "version": obj.version,
            "proto": base64.b64encode(obj.SerializeToString()).decode("ascii"),
        }
    # Include the offending type so a failed produce is debuggable from logs.
    raise TypeError(f"Unknown data type to serialize: {type(obj).__name__}")
59
+
60
+
61
def initialize_kafka_producer(app_config: Settings) -> AIOProducer:
    """Build an async Kafka producer from the SECRET_KEYS credentials.

    Reads broker credentials from the ``SECRET_KEYS`` environment variable
    (a JSON object), materializes the broker CA certificate on disk, and
    returns an ``AIOProducer`` configured for SASL/SCRAM over SSL with lz4
    compression.

    Args:
        app_config: Settings carrying bootstrap servers and the target topic.

    Raises:
        ValueError: when ``SECRET_KEYS`` is unset or empty.
    """
    credentials = orjson.loads(os.getenv("SECRET_KEYS") or "{}")
    if not credentials:
        raise ValueError("No Kafka credentials found.")
    # The CA cert is stored base64-encoded in the secret; librdkafka needs a
    # file path, so write it out before building the config.
    ca_path = "/tmp/ca.pem"
    ca_pem = base64.b64decode(credentials["KAFKA_BROKER_CA_CERTIFICATE"]).decode()
    with open(ca_path, "w") as cert_file:
        cert_file.write(ca_pem)
    producer_config = {
        "bootstrap.servers": app_config.kafka_bootstrap_servers,
        "security.protocol": SECURITY_PROTOCOL,
        "sasl.username": credentials["KAFKA_WRITER_USER"],
        "sasl.password": credentials["KAFKA_WRITER_PASSWORD"],
        "sasl.mechanism": SASL_MECHANISM,
        "ssl.ca.location": ca_path,
        "compression.type": "lz4",
    }
    logger.info(
        "Initializing kafka producer pushing to topic %s", app_config.kafka_topic
    )
    producer = AIOProducer(producer_config)
    logger.info("Producer initialized!")
    return producer
82
+
83
+
84
def should_log_user(userid: str, kafka_fraction: float) -> bool:
    """Deterministically decide whether this user is sampled for Kafka logging.

    Maps the SHA-256 digest of ``userid`` onto [0, 1) and samples users whose
    bucket falls below ``kafka_fraction`` — the same user is therefore always
    in (or always out of) the sample for a fixed fraction.

    Args:
        userid: User identifier; falsy values are never sampled.
        kafka_fraction: Target sampling rate in [0, 1].

    Returns:
        bool: True when the user's traffic should be sent to Kafka.
    """
    if not userid:
        return False
    digest = hashlib.sha256(userid.encode()).hexdigest()
    bucket = int(digest, 16) / (2 ** 256)
    return bucket < kafka_fraction
@@ -7,20 +7,22 @@ from typing import Any, Dict, List, Optional
7
7
  import time
8
8
  from contextlib import asynccontextmanager, AsyncExitStack
9
9
  import traceback
10
+ import json
11
+ import asyncio
10
12
 
11
13
  import aiobotocore.session
12
14
  from aiobotocore.config import AioConfig
13
- from fastapi import FastAPI, HTTPException, Request, Response
15
+ from fastapi import FastAPI, HTTPException, Request, Response, BackgroundTasks
14
16
  from fastapi.encoders import jsonable_encoder
15
17
  import newrelic.agent
16
18
 
17
-
18
19
  from .cache import make_features_cache
19
20
  from .dynamo import set_all_features, FeatureRetrievalMeta
20
21
  from .model_store import download_and_load_model
21
22
  from .settings import Settings
22
23
  from . import exceptions
23
24
  from ._serializers import SerializerRegistry
25
+ from ._kafka import send_to_kafka, initialize_kafka_producer, should_log_user
24
26
  from google.protobuf import text_format
25
27
 
26
28
  logging.basicConfig(
@@ -122,6 +124,10 @@ def create_app(
122
124
  # 1. Load ML Model
123
125
  if state["model"] is None:
124
126
  await load_model(state, cfg)
127
+ kafka_producer = None
128
+ if cfg.kafka_bootstrap_servers is not None:
129
+ kafka_producer = initialize_kafka_producer(app_config=cfg)
130
+ state["kafka_producer"] = kafka_producer
125
131
  async with AsyncExitStack() as stack:
126
132
  # 2. Initialize DynamoDB Client (Persistent Pool)
127
133
  session = state["session"]
@@ -138,6 +144,17 @@ def create_app(
138
144
  # 3. Shutdown Logic
139
145
  # The AsyncExitStack automatically closes the DynamoDB client pool here
140
146
  logger.info("Shutting down: Connection pools closed.")
147
+ logger.info("Shutting down: Flushing Kafka queue.")
148
+ if kafka_producer is not None:
149
+ try:
150
+ await kafka_producer.flush()
151
+ except Exception:
152
+ logger.error(
153
+ "Unknown exception while flushing kafka queue, shutting down producer.\n%s",
154
+ traceback.format_exc(),
155
+ )
156
+ finally:
157
+ await kafka_producer.close()
141
158
 
142
159
  app = FastAPI(
143
160
  title="ML Stream Scorer",
@@ -161,10 +178,13 @@ def create_app(
161
178
  "user_cache_size": len(user_features_cache),
162
179
  "model_name": state.get("model_name"),
163
180
  "stream_features": state.get("stream_features", []),
181
+ "user_features": state.get("user_features", []),
164
182
  }
165
183
 
166
184
  @app.post("/score", status_code=HTTPStatus.OK)
167
- async def score_stream(request: Request, response: Response):
185
+ async def score_stream(
186
+ request: Request, response: Response, background_tasks: BackgroundTasks
187
+ ):
168
188
  if state["model"] is None:
169
189
  raise HTTPException(
170
190
  status_code=HTTPStatus.SERVICE_UNAVAILABLE,
@@ -173,10 +193,24 @@ def create_app(
173
193
 
174
194
  try:
175
195
  data = await request.json()
176
- except Exception as e:
196
+ except json.JSONDecodeError as e:
197
+ body = await request.body()
198
+ logger.error(
199
+ "Received malformed json. Raw body: %s\n%s",
200
+ body.decode(errors="replace"),
201
+ traceback.format_exc(),
202
+ )
177
203
  raise HTTPException(
178
204
  status_code=HTTPStatus.BAD_REQUEST, detail="Invalid JSON payload"
179
205
  ) from e
206
+ except Exception as e:
207
+ logger.error(
208
+ "Unexpected exception when parsing request.\n %s",
209
+ traceback.format_exc(),
210
+ )
211
+ raise HTTPException(
212
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Unknown exception"
213
+ ) from e
180
214
  query_params = {}
181
215
  for k in request.query_params.keys():
182
216
  values = request.query_params.getlist(k)
@@ -259,22 +293,24 @@ def create_app(
259
293
  status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
260
294
  detail="Model prediction failed",
261
295
  ) from e
262
-
296
+ monitoring_meta = {
297
+ "cache_misses": retrieval_meta.cache_misses,
298
+ "user_cache_misses": retrieval_meta.user_cache_misses,
299
+ "stream_cache_misses": retrieval_meta.stream_cache_misses,
300
+ "user_cache_size": len(user_features_cache),
301
+ "stream_cache_size": len(stream_features_cache),
302
+ "retrieval_success": int(retrieval_meta.success),
303
+ "cache_delay_minutes": retrieval_meta.cache_delay_minutes,
304
+ "dynamo_ms": retrieval_meta.dynamo_ms,
305
+ "dynamo_parse_ms": retrieval_meta.parsing_ms,
306
+ "retrieval_ms": retrieval_meta.retrieval_ms,
307
+ "preprocess_ms": (predict_start - preprocess_start) * 1e-6,
308
+ "predict_ms": (predict_end - predict_start) * 1e-6,
309
+ "total_streams": len(model_output),
310
+ }
263
311
  newrelic.agent.record_custom_event(
264
312
  "Inference",
265
- {
266
- "cache_misses": retrieval_meta.cache_misses,
267
- "user_cache_misses": retrieval_meta.user_cache_misses,
268
- "stream_cache_misses": retrieval_meta.stream_cache_misses,
269
- "retrieval_success": int(retrieval_meta.success),
270
- "cache_delay_minutes": retrieval_meta.cache_delay_minutes,
271
- "dynamo_ms": retrieval_meta.dynamo_ms,
272
- "dynamo_parse_ms": retrieval_meta.parsing_ms,
273
- "retrieval_ms": retrieval_meta.retrieval_ms,
274
- "preprocess_ms": (predict_start - preprocess_start) * 1e-6,
275
- "predict_ms": (predict_end - predict_start) * 1e-6,
276
- "total_streams": len(model_output),
277
- },
313
+ monitoring_meta,
278
314
  )
279
315
  if model_output:
280
316
  if random_number < cfg.logs_fraction:
@@ -283,6 +319,18 @@ def create_app(
283
319
  userid,
284
320
  model_output,
285
321
  )
322
+ if should_log_user(userid=userid, kafka_fraction=cfg.kafka_fraction):
323
+ background_tasks.add_task(
324
+ send_to_kafka,
325
+ producer=state["kafka_producer"],
326
+ topic=cfg.kafka_topic,
327
+ user=user,
328
+ streams=streams,
329
+ playlist=playlist,
330
+ state=state,
331
+ model_output=model_output,
332
+ monitoring_meta=monitoring_meta,
333
+ )
286
334
  return jsonable_encoder(model_output)
287
335
 
288
336
  raise HTTPException(
@@ -1,9 +1,15 @@
1
1
  from pydantic_settings import BaseSettings
2
2
  from pydantic import Field
3
3
 
4
+
4
5
  class Settings(BaseSettings):
5
6
  # Logging
6
7
  logs_fraction: float = Field(0.01, alias="LOGS_FRACTION")
8
+ kafka_bootstrap_servers: str | None = Field(
9
+ default=None, alias="KAFKA_BOOTSTRAP_SERVERS"
10
+ )
11
+ kafka_fraction: float = Field(0.01, alias="KAFKA_FRACTION")
12
+ kafka_topic: str = Field(default=None, alias="KAFKA_TOPIC")
7
13
 
8
14
  # Model (S3)
9
15
  s3_model_path: str | None = Field(default=None, alias="S3_MODEL_PATH")
@@ -14,10 +20,10 @@ class Settings(BaseSettings):
14
20
 
15
21
  # Cache
16
22
  stream_cache_maxsize: int = 50_000
17
- user_cache_maxsize: int = 500_000
23
+ user_cache_maxsize: int = 80_000
18
24
  cache_separator: str = "--"
19
25
 
20
26
  class Config:
21
27
  env_file = ".env"
22
28
  env_file_encoding = "utf-8"
23
- extra = "ignore"
29
+ extra = "ignore"
@@ -336,7 +336,7 @@ def user_pwatched_cleanup(
336
336
  "launch_first_in_session",
337
337
  ]
338
338
  _validate_pwatched_entry_context(entry_contexts)
339
- counts_obj = user.get("PWATCHED#6M", UserPWatched())
339
+ counts_obj = user.get("PWATCHED#6M", UserPWatched()).data
340
340
  out = _cleanup_entry_context_counts(
341
341
  counts_obj=counts_obj,
342
342
  entry_contexts=entry_contexts,
@@ -1,12 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: haystack-ml-stack
3
- Version: 0.3.3
3
+ Version: 0.4.0
4
4
  Summary: Functions related to Haystack ML
5
5
  Author-email: Oscar Vega <oscar@haystack.tv>
6
6
  License: MIT
7
7
  Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  Requires-Dist: protobuf==6.33.2
10
+ Requires-Dist: orjson==3.11.7
10
11
  Provides-Extra: server
11
12
  Requires-Dist: pydantic==2.5.0; extra == "server"
12
13
  Requires-Dist: cachetools==5.5.2; extra == "server"
@@ -15,6 +16,7 @@ Requires-Dist: aioboto3==12.0.0; extra == "server"
15
16
  Requires-Dist: fastapi==0.104.1; extra == "server"
16
17
  Requires-Dist: pydantic-settings==2.2; extra == "server"
17
18
  Requires-Dist: newrelic==11.1.0; extra == "server"
19
+ Requires-Dist: confluent-kafka==2.13.0; extra == "server"
18
20
 
19
21
  # Haystack ML Stack
20
22
 
@@ -1,6 +1,7 @@
1
1
  README.md
2
2
  pyproject.toml
3
3
  src/haystack_ml_stack/__init__.py
4
+ src/haystack_ml_stack/_kafka.py
4
5
  src/haystack_ml_stack/_serializers.py
5
6
  src/haystack_ml_stack/app.py
6
7
  src/haystack_ml_stack/cache.py
@@ -1,4 +1,5 @@
1
1
  protobuf==6.33.2
2
+ orjson==3.11.7
2
3
 
3
4
  [server]
4
5
  pydantic==2.5.0
@@ -8,3 +9,4 @@ aioboto3==12.0.0
8
9
  fastapi==0.104.1
9
10
  pydantic-settings==2.2
10
11
  newrelic==11.1.0
12
+ confluent-kafka==2.13.0
@@ -560,3 +560,78 @@ def test_stream_similarity_top_category_functions():
560
560
  assert all(
561
561
  actual_key == expected_key for actual_key, expected_key in zip(actual, expected)
562
562
  )
563
+
564
+
565
def test_user_pwatched_cleanup():
    """PWATCHED#6M counts flow through cleanup; absent contexts default to 0."""
    raw_counts = {
        "version": 1,
        "data": {
            "sel_thumb": {"attempts": 1, "watched": 1},
            "ch_swtch": {"attempts": 2, "watched": 0},
        },
    }
    msg = features_pb2_v1.UserPWatched()
    ProtoParseDict(js_dict=raw_counts, message=msg)
    out = {}
    utils.user_pwatched_cleanup(
        user={"PWATCHED#6M": msg},
        entry_contexts=["autoplay", "sel_thumb", "ch_swtch"],
        out=out,
    )
    expected = pd.Series(
        {
            "USER_AUTOPLAY_6M_TOTAL_ATTEMPTS": 0,
            "USER_AUTOPLAY_6M_TOTAL_WATCHED": 0,
            "USER_SEL_THUMB_6M_TOTAL_ATTEMPTS": 1,
            "USER_SEL_THUMB_6M_TOTAL_WATCHED": 1,
            "USER_CH_SWTCH_6M_TOTAL_ATTEMPTS": 2,
            "USER_CH_SWTCH_6M_TOTAL_WATCHED": 0,
        }
    )
    assert (pd.Series(out).loc[expected.index] == expected).all()
592
+
593
+
594
def test_user_pselect_cleanup():
    """Position-debiased PSELECT#6M totals are emitted for the chosen bucket."""
    raw_counts = {
        "version": 1,
        "data": {
            "all_browsed": {
                "first_pos": {
                    "total_selects": 0,
                    "total_selects_and_watched": 0,
                    "total_browsed": 1,
                },
                "rest_pos": {
                    "total_selects": 2,
                    "total_selects_and_watched": 2,
                    "total_browsed": 1,
                },
            },
            "up_to_4_browsed": {
                "first_pos": {
                    "total_selects": 0,
                    "total_selects_and_watched": 0,
                    "total_browsed": 1,
                },
                "rest_pos": {
                    "total_selects": 2,
                    "total_selects_and_watched": 2,
                    "total_browsed": 0,
                },
            },
        },
    }
    msg = features_pb2_v1.UserPSelect()
    ProtoParseDict(js_dict=raw_counts, message=msg)
    out = {}
    utils.user_pselect_cleanup(
        user={"PSELECT#6M": msg}, position_debiasing="up_to_4_browsed", out=out
    )
    expected = pd.Series(
        {
            "USER_6M_TOTAL_BROWSED_UP_TO_4_BROWSED": 1,
            "USER_6M_TOTAL_SELECTS_UP_TO_4_BROWSED": 2,
            "USER_6M_TOTAL_SELECTS_AND_WATCHED_UP_TO_4_BROWSED": 2,
        }
    )
    assert (pd.Series(out).loc[expected.index] == expected).all()