haystack-ml-stack 0.4.4__tar.gz → 0.4.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/PKG-INFO +1 -1
  2. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/pyproject.toml +1 -1
  3. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/__init__.py +2 -2
  4. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/_kafka.py +40 -0
  5. haystack_ml_stack-0.4.5/src/haystack_ml_stack/_version.py +1 -0
  6. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/app.py +242 -2
  7. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/dynamo.py +296 -8
  8. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/settings.py +2 -1
  9. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/utils.py +55 -0
  10. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack.egg-info/PKG-INFO +1 -1
  11. haystack_ml_stack-0.4.4/src/haystack_ml_stack/_version.py +0 -1
  12. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/README.md +0 -0
  13. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/setup.cfg +0 -0
  14. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/_serializers.py +0 -0
  15. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/cache.py +0 -0
  16. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/exceptions.py +0 -0
  17. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/generated/__init__.py +0 -0
  18. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/generated/v1/__init__.py +0 -0
  19. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/generated/v1/features_pb2.py +0 -0
  20. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/generated/v1/features_pb2.pyi +0 -0
  21. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/model_store.py +0 -0
  22. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack.egg-info/SOURCES.txt +0 -0
  23. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack.egg-info/dependency_links.txt +0 -0
  24. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack.egg-info/requires.txt +0 -0
  25. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack.egg-info/top_level.txt +0 -0
  26. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/tests/test_serializers.py +0 -0
  27. {haystack_ml_stack-0.4.4 → haystack_ml_stack-0.4.5}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: haystack-ml-stack
3
- Version: 0.4.4
3
+ Version: 0.4.5
4
4
  Summary: Functions related to Haystack ML
5
5
  Author-email: Oscar Vega <oscar@haystack.tv>
6
6
  License: MIT
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
5
5
 
6
6
  [project]
7
7
  name = "haystack-ml-stack"
8
- version = "0.4.4"
8
+ version = "0.4.5"
9
9
  description = "Functions related to Haystack ML"
10
10
  readme = "README.md"
11
11
  authors = [{ name = "Oscar Vega", email = "oscar@haystack.tv" }]
@@ -1,9 +1,9 @@
1
1
  __all__ = []
2
2
 
3
3
  try:
4
- from .app import create_app
4
+ from .app import create_app, create_stream_app, create_channel_app
5
5
 
6
- __all__ = ["create_app"]
6
+ __all__ = ["create_app", "create_stream_app", "create_channel_app"]
7
7
  except ImportError as e:
8
8
  pass
9
9
 
@@ -23,6 +23,7 @@ async def send_to_kafka(
23
23
  state: dict,
24
24
  model_output: dict,
25
25
  monitoring_meta: dict,
26
+ query_params: dict,
26
27
  processed_at: datetime.datetime,
27
28
  ) -> None:
28
29
  if topic is None or producer is None:
@@ -43,6 +44,45 @@ async def send_to_kafka(
43
44
  "playlist_category": playlist.get("category"),
44
45
  "user_features": state.get("user_features", []),
45
46
  "stream_features": state.get("stream_features", []),
47
+ "query_params": query_params,
48
+ },
49
+ "processed_at": processed_at.isoformat(),
50
+ }
51
+ delivery_future = await producer.produce(
52
+ topic, orjson.dumps(message, default=default_serialization)
53
+ )
54
+ await delivery_future
55
+ return
56
+
57
+
58
+ async def send_channels_to_kafka(
59
+ producer: AIOProducer,
60
+ topic: str,
61
+ user: dict,
62
+ channels: list[dict],
63
+ state: dict,
64
+ model_output: dict,
65
+ monitoring_meta: dict,
66
+ query_params: dict,
67
+ processed_at: datetime.datetime,
68
+ ) -> None:
69
+ if topic is None or producer is None:
70
+ return
71
+ message = {
72
+ "userid": user.get("userid"),
73
+ "client_os": user.get("clientOs"),
74
+ "model_input": {"user": user, "channels": channels},
75
+ "model_output": model_output,
76
+ "model_name": state["model_name"].replace(".pkl", "")
77
+ if state["model_name"]
78
+ else None,
79
+ "model_type": "channels",
80
+ "meta": {
81
+ "monitoring": monitoring_meta,
82
+ "haystack_ml_stack_version": __version__,
83
+ "user_features": state.get("user_features", []),
84
+ "global_features": state.get("global_features", []),
85
+ "query_params": query_params,
46
86
  },
47
87
  "processed_at": processed_at.isoformat(),
48
88
  }
@@ -0,0 +1 @@
1
+ __version__ = "0.4.5"
@@ -18,12 +18,12 @@ from fastapi.encoders import jsonable_encoder
18
18
  import newrelic.agent
19
19
 
20
20
  from .cache import make_features_cache
21
- from .dynamo import set_all_features, FeatureRetrievalMeta
21
+ from .dynamo import set_all_features, create_channel_candidates, FeatureRetrievalMeta
22
22
  from .model_store import download_and_load_model
23
23
  from .settings import Settings
24
24
  from . import exceptions
25
25
  from ._serializers import SerializerRegistry
26
- from ._kafka import send_to_kafka, initialize_kafka_producer, should_log_user
26
+ from ._kafka import send_to_kafka, send_channels_to_kafka, initialize_kafka_producer, should_log_user
27
27
  from google.protobuf import text_format
28
28
 
29
29
  logging.basicConfig(
@@ -65,6 +65,7 @@ async def load_model(state, cfg: Settings) -> None:
65
65
  )
66
66
  state["stream_features"] = state["model"].get("stream_features", [])
67
67
  state["user_features"] = state["model"].get("user_features", [])
68
+ state["global_features"] = state["model"].get("global_features", [])
68
69
  valid_features = set(
69
70
  (entity_type, feature_id)
70
71
  for entity_type, feature_id, _ in SerializerRegistry.keys()
@@ -332,6 +333,7 @@ def create_app(
332
333
  state=state,
333
334
  model_output=model_output,
334
335
  monitoring_meta=monitoring_meta,
336
+ query_params=query_params,
335
337
  processed_at=processed_at,
336
338
  )
337
339
  return jsonable_encoder(model_output)
@@ -348,3 +350,241 @@ def create_app(
348
350
  }
349
351
 
350
352
  return app
353
+
354
+
355
+ create_stream_app = create_app
356
+
357
+
358
def create_channel_app(
    settings: Optional[Settings] = None,
    *,
    preloaded_model: Optional[Dict[str, Any]] = None,
) -> FastAPI:
    """Build the FastAPI application that scores channels for a user.

    The returned app exposes three routes: ``GET /health``, ``POST /score``
    and ``GET /``.

    Args:
        settings: Application configuration. A fresh ``Settings()`` is
            created when omitted.
        preloaded_model: Optional model dictionary. When provided it is stored
            as ``state["model"]`` and ``load_model`` is not called at startup.

    Returns:
        The configured ``FastAPI`` instance.
    """
    cfg = settings or Settings()

    global_features_cache = make_features_cache(cfg.global_cache_maxsize)
    user_features_cache = make_features_cache(cfg.user_cache_maxsize)
    aws_session = aiobotocore.session.get_session()
    state: Dict[str, Any] = {
        "model": preloaded_model,
        "session": aws_session,
        "model_name": (
            os.path.basename(cfg.s3_model_path) if cfg.s3_model_path else None
        ),
    }

    @asynccontextmanager
    async def lifespan(app_server: FastAPI):
        """Load the model, open the DynamoDB client and manage the Kafka producer."""
        if state["model"] is None:
            await load_model(state, cfg)
        kafka_producer = None
        if cfg.kafka_bootstrap_servers is not None:
            kafka_producer = initialize_kafka_producer(app_config=cfg)
        state["kafka_producer"] = kafka_producer
        async with AsyncExitStack() as stack:
            session = state["session"]
            state["dynamo_client"] = await stack.enter_async_context(
                session.create_client(
                    "dynamodb",
                    config=AioConfig(max_pool_connections=MAX_POOL_CONNECTIONS),
                )
            )
            logger.info("DynamoDB persistent client initialized.")
            yield
        # The exit stack has been unwound here, so this message is accurate.
        logger.info("Shutting down: Connection pools closed.")
        logger.info("Shutting down: Flushing Kafka queue.")
        if kafka_producer is not None:
            try:
                await kafka_producer.flush()
            except Exception:
                logger.error(
                    "Unknown exception while flushing kafka queue, shutting down producer.\n%s",
                    traceback.format_exc(),
                )
            finally:
                await kafka_producer.close()

    app = FastAPI(
        title="ML Channel Scorer",
        description="Scores channels using a pre-trained ML model and DynamoDB features.",
        version="1.0.0",
        lifespan=lifespan,
    )

    @app.get("/health", status_code=HTTPStatus.OK)
    async def health():
        """Report service health; responds 503 while no model is loaded."""
        if state["model"] is None:
            raise HTTPException(
                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                detail="ML Model not loaded",
            )
        return {
            "status": "ok",
            "model_loaded": True,
            "global_cache_size": len(global_features_cache),
            "user_cache_size": len(user_features_cache),
            "global_features": state.get("global_features", []),
            "user_features": state.get("user_features", []),
            "model_name": state.get("model_name"),
        }

    @app.post("/score", status_code=HTTPStatus.OK)
    async def score_channels(
        request: Request, response: Response, background_tasks: BackgroundTasks
    ):
        """Score channels for the user described by the JSON request body.

        Raises:
            HTTPException: 503 when no model is loaded or the feature store
                returned features unknown to the serializer registry; 400 for
                malformed or non-object JSON; 500 when parsing or prediction
                fails unexpectedly; 404 when the model produced no output.
        """
        if state["model"] is None:
            raise HTTPException(
                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                detail="ML Model not loaded",
            )

        try:
            data = await request.json()
        except json.JSONDecodeError as e:
            body = await request.body()
            logger.error(
                "Received malformed json. Raw body: %s\n%s",
                body.decode(errors="replace"),
                traceback.format_exc(),
            )
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST, detail="Invalid JSON payload"
            ) from e
        except Exception as e:
            logger.error(
                "Unexpected exception when parsing request.\n %s",
                traceback.format_exc(),
            )
            raise HTTPException(
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Unknown exception"
            ) from e

        # The payload is used as a mapping below (``user.get(...)``); reject
        # valid JSON that is not an object instead of failing with a 500.
        if not isinstance(data, dict):
            logger.error(
                "Expected a JSON object as request body, received %s",
                type(data).__name__,
            )
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST,
                detail="JSON payload must be an object",
            )

        # Single-valued query parameters are flattened; repeated ones stay lists.
        query_params = {}
        for key in request.query_params.keys():
            values = request.query_params.getlist(key)
            query_params[key] = values[0] if len(values) == 1 else values

        user = data
        # Filled in place by ``create_channel_candidates``.
        channels: List[Dict[str, Any]] = []

        try:
            retrieval_meta = await create_channel_candidates(
                dynamo_client=state["dynamo_client"],
                user=user,
                channels=channels,
                global_features=state.get("global_features", []),
                user_features=state.get("user_features", []),
                global_features_cache=global_features_cache,
                user_features_cache=user_features_cache,
                features_table=cfg.features_table,
                cache_sep=cfg.cache_separator,
            )
        except exceptions.InvalidFeaturesException as e:
            logger.error(
                "The following features are not present in the SerializerRegistry %s",
                e,
            )
            raise HTTPException(
                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                detail=f"Received invalid features from feature store: {e}",
            ) from e

        # One draw per request so input and output logging are sampled together.
        random_number = random.random()
        userid = user.get("userid", "")

        if random_number < cfg.logs_fraction:
            logger.info(
                "User %s data: %s",
                userid,
                user,
            )
            logger.info(
                "User %s channels: %s",
                userid,
                channels,
            )

        model = state["model"]
        try:
            preprocess_start = time.perf_counter_ns()
            # NOTE(review): this mutates the shared model params. There is no
            # ``await`` between this assignment and the predict call, so
            # requests cannot interleave here on a single event loop.
            model["params"]["query_params"] = query_params
            model_input = model["preprocess"](
                user,
                channels,
                model["params"],
            )
            predict_start = time.perf_counter_ns()
            model_output = model["predict"](model_input, model["params"])
            predict_end = time.perf_counter_ns()
        except Exception as e:
            logger.error("Model prediction failed: \n%s", traceback.format_exc())
            raise HTTPException(
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                detail="Model prediction failed",
            ) from e

        monitoring_meta = {
            "cache_misses": retrieval_meta.cache_misses,
            "user_cache_misses": retrieval_meta.user_cache_misses,
            # Global misses are carried in the ``stream_cache_misses`` field.
            "global_cache_misses": retrieval_meta.stream_cache_misses,
            "user_cache_size": len(user_features_cache),
            "global_cache_size": len(global_features_cache),
            "retrieval_success": int(retrieval_meta.success),
            "cache_delay_minutes": retrieval_meta.cache_delay_minutes,
            "dynamo_ms": retrieval_meta.dynamo_ms,
            "dynamo_parse_ms": retrieval_meta.parsing_ms,
            "retrieval_ms": retrieval_meta.retrieval_ms,
            "preprocess_ms": (predict_start - preprocess_start) * 1e-6,
            "predict_ms": (predict_end - predict_start) * 1e-6,
            # A ``None`` output must fall through to the 404 below rather than
            # raising ``TypeError`` from ``len``.
            "total_channels": len(model_output) if model_output is not None else 0,
        }
        newrelic.agent.record_custom_event(
            "ChannelInference",
            monitoring_meta,
        )

        if not model_output:
            raise HTTPException(
                status_code=HTTPStatus.NOT_FOUND, detail="No model output generated"
            )

        if random_number < cfg.logs_fraction:
            logger.info(
                "User %s - model output %s",
                userid,
                model_output,
            )
        if should_log_user(userid=userid, kafka_fraction=cfg.kafka_fraction):
            processed_at = datetime.datetime.now(tz=datetime.UTC)
            background_tasks.add_task(
                send_channels_to_kafka,
                producer=state["kafka_producer"],
                topic=cfg.kafka_topic,
                user=user,
                channels=channels,
                state=state,
                model_output=model_output,
                monitoring_meta=monitoring_meta,
                query_params=query_params,
                processed_at=processed_at,
            )
        return jsonable_encoder(model_output)

    @app.get("/", status_code=HTTPStatus.OK)
    async def root():
        """Return a liveness message together with the configured model name."""
        return {
            "message": "ML Channel Scoring Service is running.",
            "model_name": state.get("model_name"),
        }

    return app
@@ -1,13 +1,15 @@
1
- from typing import Any, Dict, List, NamedTuple, Literal
1
+ import asyncio
2
+ import datetime
2
3
  import logging
3
4
  import time
4
- import datetime
5
- from boto3.dynamodb.types import TypeDeserializer
5
+ from typing import Any, Dict, List, Literal, NamedTuple
6
+
6
7
  import newrelic.agent
7
- import asyncio
8
- from ._serializers import SerializerRegistry, FeatureRegistryId
9
- from . import exceptions
8
+ from boto3.dynamodb.types import TypeDeserializer
10
9
 
10
+ from . import exceptions
11
+ from ._serializers import FeatureRegistryId, SerializerRegistry
12
+ from .utils import _complete_features_for_channels
11
13
 
12
14
  logger = logging.getLogger(__name__)
13
15
 
@@ -21,7 +23,7 @@ class FloatDeserializer(TypeDeserializer):
21
23
 
22
24
 
23
25
  _deser = FloatDeserializer()
24
- IdType = Literal["STREAM", "USER"]
26
+ IdType = Literal["STREAM", "USER", "GLOBAL"]
25
27
 
26
28
 
27
29
  class FeatureRetrievalMeta(NamedTuple):
@@ -268,9 +270,295 @@ async def set_all_features(
268
270
  )
269
271
 
270
272
 
273
+ _MOBILE_OS = {"ios", "android", "iphone", "galaxy"}
274
+
275
+
276
+ def _get_os_cat(client_os: str) -> str:
277
+ normalized = client_os.lower().replace("debug","") if client_os else ""
278
+ return "MOBILE" if normalized in _MOBILE_OS else "TV"
279
+
280
+
281
@newrelic.agent.function_trace()
async def create_channel_candidates(
    *,
    user: Dict[str, Any],
    channels: List[Dict[str, Any]],
    global_features,
    user_features,
    global_features_cache,
    user_features_cache,
    features_table: str,
    cache_sep: str,
    dynamo_client,
) -> FeatureRetrievalMeta:
    """Resolve global and user features and populate ``channels`` in place.

    Features are looked up through ``_check_cache`` first; anything reported
    as a miss is fetched from DynamoDB with ``async_batch_get``, deserialized
    through ``SerializerRegistry`` and written back to the matching cache.
    ``_process_channels`` then appends the channel candidates to ``channels``.

    Args:
        user: Request user payload; reads ``clientOs`` and ``userid``.
        channels: List that receives the generated channel dictionaries.
        global_features: Feature keys stored under the ``GLOBAL`` partition.
        user_features: Feature keys stored under ``USER#<userid>``.
        global_features_cache: Mutable mapping caching global features.
        user_features_cache: Mutable mapping caching user features.
        features_table: DynamoDB table name.
        cache_sep: Separator used inside cache keys.
        dynamo_client: Client handed to ``async_batch_get``.

    Returns:
        Retrieval metrics. Global misses are reported through the
        ``stream_cache_misses`` field. ``success`` is ``False`` when the
        DynamoDB request failed; ``channels`` is left untouched in that case.

    Raises:
        KeyError: If ``user_features`` is non-empty and ``user`` lacks ``userid``.
        ValueError: If a fetched item belongs to neither ``GLOBAL`` nor ``USER``.
        exceptions.InvalidFeaturesException: If a fetched feature has no
            registered serializer.
        exceptions.DeserializationException: If deserialization raises
            ``TypeError``.
    """
    time_start = time.perf_counter_ns()

    os_cat = _get_os_cat(user.get("clientOs", ""))

    global_holder: Dict[str, Any] = {}
    cache_miss: Dict[str, Dict[str, Any]] = {}
    cache_delay_obj: dict[str, float] = {
        f: 0 for f in (*global_features, *user_features)
    }
    # NOTE(review): naive UTC timestamp, kept as-is so it stays comparable
    # with whatever ``_check_cache`` and other cache writers store.
    now = datetime.datetime.utcnow()

    for f in global_features:
        cache_miss, cache_delay_obj = _check_cache(
            obj=global_holder,
            id_type="GLOBAL",
            id_key="GLOBAL",
            feature_key=f,
            cache_sep=cache_sep,
            features_cache=global_features_cache,
            cache_miss=cache_miss,
            cache_delay=cache_delay_obj,
            now=now,
        )
    global_cache_misses = len(cache_miss)

    for f in user_features:
        cache_miss, cache_delay_obj = _check_cache(
            obj=user,
            id_type="USER",
            id_key=user["userid"],
            feature_key=f,
            cache_sep=cache_sep,
            features_cache=user_features_cache,
            cache_miss=cache_miss,
            cache_delay=cache_delay_obj,
            now=now,
        )
    user_cache_misses = len(cache_miss) - global_cache_misses

    # Smallest strictly positive delay, or 0 when there is none.
    cache_delay = min((v for v in cache_delay_obj.values() if v > 0), default=0)

    if not cache_miss:
        _process_channels(
            channels, global_features, user_features, global_holder, user, os_cat
        )
        return FeatureRetrievalMeta(
            user_cache_misses=0,
            stream_cache_misses=0,
            cache_misses=0,
            retrieval_ms=(time.perf_counter_ns() - time_start) * 1e-6,
            success=True,
            cache_delay_minutes=cache_delay / 60,
            dynamo_ms=0,
            parsing_ms=0,
        )

    cache_misses = len(cache_miss)
    logger.info(
        "Channel candidates cache miss for %d items (%d global, %d user)",
        cache_misses,
        global_cache_misses,
        user_cache_misses,
    )

    # Cache keys have the form <id_type><sep><id_key><sep><sort key>.
    keys = []
    for k in cache_miss:
        id_type, id_key, sk = k.split(cache_sep, 2)
        pk = "GLOBAL" if id_type == "GLOBAL" else f"{id_type}#{id_key}"
        keys.append({"pk": {"S": pk}, "sk": {"S": sk}})

    dynamo_start = time.perf_counter_ns()
    try:
        items = await async_batch_get(dynamo_client, features_table, keys)
    except Exception as e:
        logger.error("DynamoDB batch_get failed for channel candidates: %s", e)
        end_time = time.perf_counter_ns()
        return FeatureRetrievalMeta(
            user_cache_misses=user_cache_misses,
            stream_cache_misses=global_cache_misses,
            # Report the total so it agrees with the per-type counts above.
            cache_misses=cache_misses,
            retrieval_ms=(end_time - time_start) * 1e-6,
            success=False,
            cache_delay_minutes=cache_delay / 60,
            dynamo_ms=(end_time - dynamo_start) * 1e-6,
            parsing_ms=0,
        )

    dynamo_end = time.perf_counter_ns()
    updated_keys = set()
    for item in items:
        full_id = item["pk"]["S"]
        if "#" in full_id:
            id_type, id_key = full_id.split("#", 1)
        else:
            # The global partition key has no "#"; it is both type and key.
            id_type = full_id
            id_key = full_id
        feature_name = item["sk"]["S"]

        if id_type == "GLOBAL":
            cache_to_use = global_features_cache
        elif id_type == "USER":
            cache_to_use = user_features_cache
        else:
            raise ValueError(
                f"Unexpected id type in channel candidates. "
                f"Expected 'GLOBAL' or 'USER', received {id_type}"
            )

        cache_key = _build_cache_key(
            id_type=id_type,
            id_key=id_key,
            feature_key=feature_name,
            cache_sep=cache_sep,
        )
        parsed = parse_dynamo_item(item)
        feature_version = parsed.get("version", "v0")
        feature_id = FeatureRegistryId(
            entity_type=id_type, feature_id=feature_name, version=feature_version
        )
        try:
            serializer = SerializerRegistry[feature_id]
        except KeyError as e:
            raise exceptions.InvalidFeaturesException(
                f"Could not find '{feature_id}' in serializer registry"
            ) from e
        try:
            raw_value = parsed.get("value")
            value = serializer.deserialize(raw_value) if raw_value else None
        except TypeError as e:
            raise exceptions.DeserializationException(
                f"Ran into an error while deserializing {feature_id}. Error: {e}"
            ) from e

        cache_to_use[cache_key] = {
            "value": value,
            "cache_ttl_in_seconds": int(parsed.get("cache_ttl_in_seconds", -1)),
            "inserted_at": datetime.datetime.utcnow(),
        }

        if cache_key in cache_miss:
            # ``cache_miss`` maps each key to the object awaiting the value.
            cache_miss[cache_key][feature_name] = value
            updated_keys.add(cache_key)

    parsing_end = time.perf_counter_ns()

    # Negative-cache every requested key DynamoDB did not return.
    # NOTE(review): these entries carry no "inserted_at"; confirm that
    # ``_check_cache`` tolerates its absence.
    for k in cache_miss.keys() - updated_keys:
        id_type = _get_id_type_from_partition_key(k, sep=cache_sep)
        if id_type == "GLOBAL":
            global_features_cache[k] = {
                "value": None,
                "cache_ttl_in_seconds": 300,
            }
        elif id_type == "USER":
            user_features_cache[k] = {
                "value": None,
                "cache_ttl_in_seconds": 6 * 3600,
            }

    _process_channels(
        channels, global_features, user_features, global_holder, user, os_cat
    )

    end_time = time.perf_counter_ns()
    return FeatureRetrievalMeta(
        cache_misses=cache_misses,
        user_cache_misses=user_cache_misses,
        stream_cache_misses=global_cache_misses,
        retrieval_ms=_perf_counter_ns_delta_in_ms(time_start, end_time),
        success=True,
        cache_delay_minutes=cache_delay / 60,
        dynamo_ms=_perf_counter_ns_delta_in_ms(dynamo_start, dynamo_end),
        parsing_ms=_perf_counter_ns_delta_in_ms(dynamo_end, parsing_end),
    )
475
+
476
+
477
def _process_channels(
    channels: List[Dict[str, Any]],
    global_features: List[str],
    user_features: List[str],
    global_holder: Dict[str, Any],
    user: Dict[str, Any],
    os_cat: str,
) -> None:
    """Append channel candidates to ``channels`` and attach their features.

    Channels come from three sources: the ``CHANNEL_CANDIDATES`` entry of
    ``global_holder``, every channel name that appears in the selected
    playlist statistics, and a fixed list of default channels. Favorite-group
    candidates the user has not listed in ``preferredChannels`` are skipped,
    as are the two group labels themselves. Finally
    ``_complete_features_for_channels`` is called to fill per-channel features.

    Args:
        channels: Output list, extended in place.
        global_features: Global feature keys; only those ending in
            ``#<os_cat>`` are used.
        user_features: Feature keys to pick from ``user``.
        global_holder: Resolved global features keyed by feature key.
        user: User payload, including resolved user features.
        os_cat: OS category suffix used to select global features.
    """
    channel_candidates = global_holder.get("CHANNEL_CANDIDATES")

    os_suffix = f"#{os_cat}"
    os_cat_feature_names = {f for f in global_features if f.endswith(os_suffix)}
    user_feature_names = set(user_features)

    playlist_stats_global = {
        "OS_CAT_" + k.replace(os_suffix, "").replace("#", "_"): v
        for k, v in global_holder.items()
        if k in os_cat_feature_names
    }
    playlist_stats_user = {
        "USER_" + k.replace("#", "_"): v
        for k, v in user.items()
        if k in user_feature_names
    }

    # Feature values may be ``None`` (e.g. an empty value in the store), so
    # skip those instead of failing on ``.data``.
    all_channels = set()
    for stat in (*playlist_stats_global.values(), *playlist_stats_user.values()):
        if stat is not None:
            all_channels.update(stat.data.keys())

    # Get not preferred channels to be ignored
    # The group labels are also ignored, not real channels for UI
    favorite_groups = ("national_favorite", "local_favorite")
    ignore_channels = set(favorite_groups)
    inserted_channel_names = set()
    if channel_candidates is not None:
        preferred = set(user.get("preferredChannels", []) or [])
        # First pass: the ignore set must be complete before anything is added.
        for ch in channel_candidates.data:
            if ch.category_group in favorite_groups and ch.name not in preferred:
                ignore_channels.add(ch.name)

        for ch in channel_candidates.data:
            if ch.name not in ignore_channels:
                channels.append(
                    {
                        "name": ch.name,
                        "category_group": ch.category_group,
                        "start_date": ch.start_date,
                    }
                )
                inserted_channel_names.add(ch.name)

    default_channels = (
        "local news",
        "science & technology",
        "business & finance",
        "entertainment news",
        "live",
        "live_es",
        "weather",
        "politics",
        "international",
        "top videos",
        "editor picks",
    )

    # One timestamp for all fallback channels; sorted so the resulting order
    # does not depend on set iteration order.
    fallback_start_date = int(datetime.datetime.now().timestamp())
    for name in sorted(all_channels.union(default_channels)):
        if name not in ignore_channels and name not in inserted_channel_names:
            channels.append(
                {
                    "name": name,
                    "category_group": "",
                    "start_date": fallback_start_date,
                }
            )

    _complete_features_for_channels(
        channels=channels,
        user_features=playlist_stats_user,
        global_features=playlist_stats_global,
    )
557
+
558
+
271
559
  def _check_cache(
272
560
  obj: dict,
273
- id_type: Literal["STREAM", "USER"],
561
+ id_type: IdType,
274
562
  id_key: str,
275
563
  feature_key: str,
276
564
  cache_sep: str,
@@ -9,7 +9,7 @@ class Settings(BaseSettings):
9
9
  default=None, alias="KAFKA_BOOTSTRAP_SERVERS"
10
10
  )
11
11
  kafka_fraction: float = Field(0.01, alias="KAFKA_FRACTION")
12
- kafka_topic: str = Field(default=None, alias="KAFKA_TOPIC")
12
+ kafka_topic: str | None = Field(default=None, alias="KAFKA_TOPIC")
13
13
 
14
14
  # Model (S3)
15
15
  s3_model_path: str | None = Field(default=None, alias="S3_MODEL_PATH")
@@ -21,6 +21,7 @@ class Settings(BaseSettings):
21
21
  # Cache
22
22
  stream_cache_maxsize: int = 50_000
23
23
  user_cache_maxsize: int = 80_000
24
+ global_cache_maxsize: int = 1_000
24
25
  cache_separator: str = "--"
25
26
 
26
27
  class Config:
@@ -13,6 +13,10 @@ from .generated.v1.features_pb2 import (
13
13
  UserPersonalizingPSelect,
14
14
  UserPSelect,
15
15
  EntryContextPWatched,
16
+ PlaylistStatsForGlobal,
17
+ PlaylistStatsForUser,
18
+ UserPlaylistStats,
19
+ GlobalPlaylistStats,
16
20
  )
17
21
  from ._serializers import SerializerRegistry
18
22
  from . import exceptions
@@ -673,3 +677,54 @@ def _validate_pwatched_entry_context(entry_contexts: list[str]):
673
677
  invalid_contexts = [c for c in entry_contexts if c not in valid_contexts]
674
678
  if invalid_contexts:
675
679
  raise ValueError(f"Invalid entry contexts found: {invalid_contexts}")
680
+
681
+
682
+ def _complete_features_for_channels(
683
+ channels: list[dict],
684
+ user_features: dict[str, UserPlaylistStats],
685
+ global_features: dict[str, GlobalPlaylistStats],
686
+ ) -> None:
687
+
688
+ GLOBAL_FEATURES = [
689
+ "watched_count",
690
+ "not_watched_count",
691
+ "capped_watched_secs",
692
+ "capped_not_watched_secs",
693
+ "watched_secs",
694
+ "not_watched_secs",
695
+ ]
696
+
697
+ USER_FEATURES = [
698
+ "total_days",
699
+ "start_days",
700
+ "active_days",
701
+ "total_watched",
702
+ "capped_total_watched",
703
+ ]
704
+
705
+ for ch in channels:
706
+ name = ch.get("name", "")
707
+ category_group = ch.get("category_group", "")
708
+ for prefix, global_feature in global_features.items():
709
+ global_feature = global_feature or GlobalPlaylistStats()
710
+ ch[prefix] = {}
711
+
712
+ if name in global_feature.data:
713
+ features = global_feature.data.get(name, PlaylistStatsForGlobal())
714
+ for feature in GLOBAL_FEATURES:
715
+ ch[prefix][feature] = getattr(features, feature, 0)
716
+ else:
717
+ # Some global features are at the category group level instead of the channel level
718
+ features = global_feature.data.get(
719
+ category_group, PlaylistStatsForGlobal()
720
+ )
721
+ for feature in GLOBAL_FEATURES:
722
+ ch[prefix][feature] = getattr(features, feature, 0)
723
+
724
+ for prefix, user_feature in user_features.items():
725
+ user_feature = user_feature or UserPlaylistStats()
726
+ ch[prefix] = {}
727
+ for feature in USER_FEATURES:
728
+ ch[prefix][feature] = getattr(
729
+ user_feature.data.get(name, PlaylistStatsForUser()), feature, 0
730
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: haystack-ml-stack
3
- Version: 0.4.4
3
+ Version: 0.4.5
4
4
  Summary: Functions related to Haystack ML
5
5
  Author-email: Oscar Vega <oscar@haystack.tv>
6
6
  License: MIT
@@ -1 +0,0 @@
1
- __version__ = "0.4.4"