haystack-ml-stack 0.4.3__tar.gz → 0.4.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/PKG-INFO +1 -1
  2. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/pyproject.toml +1 -1
  3. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/__init__.py +2 -2
  4. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/_kafka.py +40 -0
  5. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/_serializers.py +112 -1
  6. haystack_ml_stack-0.4.5/src/haystack_ml_stack/_version.py +1 -0
  7. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/app.py +242 -2
  8. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/dynamo.py +296 -8
  9. haystack_ml_stack-0.4.5/src/haystack_ml_stack/generated/v1/features_pb2.py +90 -0
  10. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/generated/v1/features_pb2.pyi +79 -1
  11. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/settings.py +2 -1
  12. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/utils.py +55 -0
  13. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack.egg-info/PKG-INFO +1 -1
  14. haystack_ml_stack-0.4.3/src/haystack_ml_stack/_version.py +0 -1
  15. haystack_ml_stack-0.4.3/src/haystack_ml_stack/generated/v1/features_pb2.py +0 -70
  16. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/README.md +0 -0
  17. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/setup.cfg +0 -0
  18. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/cache.py +0 -0
  19. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/exceptions.py +0 -0
  20. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/generated/__init__.py +0 -0
  21. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/generated/v1/__init__.py +0 -0
  22. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack/model_store.py +0 -0
  23. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack.egg-info/SOURCES.txt +0 -0
  24. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack.egg-info/dependency_links.txt +0 -0
  25. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack.egg-info/requires.txt +0 -0
  26. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/src/haystack_ml_stack.egg-info/top_level.txt +0 -0
  27. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/tests/test_serializers.py +0 -0
  28. {haystack_ml_stack-0.4.3 → haystack_ml_stack-0.4.5}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: haystack-ml-stack
3
- Version: 0.4.3
3
+ Version: 0.4.5
4
4
  Summary: Functions related to Haystack ML
5
5
  Author-email: Oscar Vega <oscar@haystack.tv>
6
6
  License: MIT
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
5
5
 
6
6
  [project]
7
7
  name = "haystack-ml-stack"
8
- version = "0.4.3"
8
+ version = "0.4.5"
9
9
  description = "Functions related to Haystack ML"
10
10
  readme = "README.md"
11
11
  authors = [{ name = "Oscar Vega", email = "oscar@haystack.tv" }]
@@ -1,9 +1,9 @@
1
1
  __all__ = []
2
2
 
3
3
  try:
4
- from .app import create_app
4
+ from .app import create_app, create_stream_app, create_channel_app
5
5
 
6
- __all__ = ["create_app"]
6
+ __all__ = ["create_app", "create_stream_app", "create_channel_app"]
7
7
  except ImportError as e:
8
8
  pass
9
9
 
@@ -23,6 +23,7 @@ async def send_to_kafka(
23
23
  state: dict,
24
24
  model_output: dict,
25
25
  monitoring_meta: dict,
26
+ query_params: dict,
26
27
  processed_at: datetime.datetime,
27
28
  ) -> None:
28
29
  if topic is None or producer is None:
@@ -43,6 +44,45 @@ async def send_to_kafka(
43
44
  "playlist_category": playlist.get("category"),
44
45
  "user_features": state.get("user_features", []),
45
46
  "stream_features": state.get("stream_features", []),
47
+ "query_params": query_params,
48
+ },
49
+ "processed_at": processed_at.isoformat(),
50
+ }
51
+ delivery_future = await producer.produce(
52
+ topic, orjson.dumps(message, default=default_serialization)
53
+ )
54
+ await delivery_future
55
+ return
56
+
57
+
58
+ async def send_channels_to_kafka(
59
+ producer: AIOProducer,
60
+ topic: str,
61
+ user: dict,
62
+ channels: list[dict],
63
+ state: dict,
64
+ model_output: dict,
65
+ monitoring_meta: dict,
66
+ query_params: dict,
67
+ processed_at: datetime.datetime,
68
+ ) -> None:
69
+ if topic is None or producer is None:
70
+ return
71
+ message = {
72
+ "userid": user.get("userid"),
73
+ "client_os": user.get("clientOs"),
74
+ "model_input": {"user": user, "channels": channels},
75
+ "model_output": model_output,
76
+ "model_name": state["model_name"].replace(".pkl", "")
77
+ if state["model_name"]
78
+ else None,
79
+ "model_type": "channels",
80
+ "meta": {
81
+ "monitoring": monitoring_meta,
82
+ "haystack_ml_stack_version": __version__,
83
+ "user_features": state.get("user_features", []),
84
+ "global_features": state.get("global_features", []),
85
+ "query_params": query_params,
46
86
  },
47
87
  "processed_at": processed_at.isoformat(),
48
88
  }
@@ -231,6 +231,73 @@ class UserPersonalizingPSelectSerializerV1(SimpleSerializer):
231
231
  return root_msg.SerializeToString()
232
232
 
233
233
 
234
class GlobalPlaylistStatsSerializerV1(SimpleSerializer):
    """Serializer for v1 global playlist-category statistics.

    The input payload maps each playlist category to six aggregate fields and
    is encoded as a ``features_pb2_v1.GlobalPlaylistStats`` message.
    """

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.GlobalPlaylistStats)

    def serialize(self, value: dict) -> bytes:
        """Encode ``value`` (see :meth:`build_msg`) to protobuf wire bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.GlobalPlaylistStats:
        """Build a ``GlobalPlaylistStats`` message from a plain dict.

        Args:
            value: ``{"version": 1, "data": {category: stats}}`` where every
                ``stats`` mapping provides ``watched_count``,
                ``not_watched_count``, ``capped_watched_secs``,
                ``capped_not_watched_secs``, ``watched_secs`` and
                ``not_watched_secs``.

        Returns:
            The populated protobuf message.

        Raises:
            ValueError: if ``value["version"]`` is not 1.
            KeyError: if a required key is missing from ``value`` or a
                ``stats`` mapping.
        """
        version = value["version"]
        # Explicit check rather than ``assert``: asserts are stripped when
        # Python runs with -O, which would silently accept wrong versions.
        if version != 1:
            raise ValueError("Wrong version given!")
        root_msg = features_pb2_v1.GlobalPlaylistStats()
        root_msg.version = version
        for category, stats in value["data"].items():
            # Indexing the message map creates the per-category entry.
            category_msg = root_msg.data[category]
            category_msg.watched_count = int(stats["watched_count"])
            category_msg.not_watched_count = int(stats["not_watched_count"])
            category_msg.capped_watched_secs = float(stats["capped_watched_secs"])
            category_msg.capped_not_watched_secs = float(
                stats["capped_not_watched_secs"]
            )
            category_msg.watched_secs = float(stats["watched_secs"])
            category_msg.not_watched_secs = float(stats["not_watched_secs"])
        return root_msg
256
+
257
+
258
class UserPlaylistStatsSerializerV1(SimpleSerializer):
    """Serializer for v1 per-user playlist-category statistics.

    The input payload maps each playlist category to five per-user fields and
    is encoded as a ``features_pb2_v1.UserPlaylistStats`` message.
    """

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.UserPlaylistStats)

    def serialize(self, value: dict) -> bytes:
        """Encode ``value`` (see :meth:`build_msg`) to protobuf wire bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.UserPlaylistStats:
        """Build a ``UserPlaylistStats`` message from a plain dict.

        Args:
            value: ``{"version": 1, "data": {category: stats}}`` where every
                ``stats`` mapping provides ``total_days``, ``start_days``,
                ``active_days``, ``total_watched`` and
                ``capped_total_watched``.

        Returns:
            The populated protobuf message.

        Raises:
            ValueError: if ``value["version"]`` is not 1.
            KeyError: if a required key is missing from ``value`` or a
                ``stats`` mapping.
        """
        version = value["version"]
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if version != 1:
            raise ValueError("Wrong version given!")
        root_msg = features_pb2_v1.UserPlaylistStats()
        root_msg.version = version
        for category, stats in value["data"].items():
            # Indexing the message map creates the per-category entry.
            category_msg = root_msg.data[category]
            category_msg.total_days = int(stats["total_days"])
            category_msg.start_days = int(stats["start_days"])
            category_msg.active_days = int(stats["active_days"])
            category_msg.total_watched = float(stats["total_watched"])
            category_msg.capped_total_watched = float(stats["capped_total_watched"])
        return root_msg
279
+
280
+
281
class GlobalChannelsSerializerV1(SimpleSerializer):
    """Serializer for the v1 global list of channel candidates.

    The input payload carries a list of channel records and is encoded as a
    ``features_pb2_v1.GlobalChannels`` message.
    """

    def __init__(self):
        super().__init__(msg_class=features_pb2_v1.GlobalChannels)

    def serialize(self, value: dict) -> bytes:
        """Encode ``value`` (see :meth:`build_msg`) to protobuf wire bytes."""
        return self.build_msg(value).SerializeToString()

    def build_msg(self, value) -> features_pb2_v1.GlobalChannels:
        """Build a ``GlobalChannels`` message from a plain dict.

        Args:
            value: ``{"version": 1, "data": [channel, ...]}`` where every
                ``channel`` mapping provides ``name``, ``category_group`` and
                ``start_date``.

        Returns:
            The populated protobuf message; channels keep their input order.

        Raises:
            ValueError: if ``value["version"]`` is not 1.
            KeyError: if a required key is missing from ``value`` or a
                channel mapping.
        """
        version = value["version"]
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if version != 1:
            raise ValueError("Wrong version given!")
        root_msg = features_pb2_v1.GlobalChannels()
        root_msg.version = version
        for channel in value["data"]:
            channel_msg = root_msg.data.add()
            channel_msg.name = str(channel["name"])
            channel_msg.category_group = str(channel["category_group"])
            channel_msg.start_date = int(channel["start_date"])
        return root_msg
299
+
300
+
234
301
  class PassThroughSerializer(Serializer):
235
302
  def serialize(self, value):
236
303
  return value
@@ -249,10 +316,13 @@ stream_pselect_serializer_v0 = StreamPSelectSerializerV0()
249
316
  stream_pselect_serializer_v1 = StreamPSelectSerializerV1()
250
317
  stream_similarity_scores_serializer_v0 = StreamSimilaritySerializerV0()
251
318
  stream_similarity_scores_serializer_v1 = StreamSimilaritySerializerV1()
319
# Shared serializer instances for the new v1 playlist-stats and channel
# features; they are paired with their feature ids in
# ``features_serializer_tuples``.
global_playlist_stats_serializer_v1 = GlobalPlaylistStatsSerializerV1()
user_playlist_stats_serializer_v1 = UserPlaylistStatsSerializerV1()
global_channels_serializer_v1 = GlobalChannelsSerializerV1()
252
322
 
253
323
 
254
324
  class FeatureRegistryId(_t.NamedTuple):
255
- entity_type: _t.Literal["STREAM", "USER"]
325
+ entity_type: _t.Literal["STREAM", "USER", "GLOBAL"]
256
326
  feature_id: str
257
327
  version: str
258
328
 
@@ -341,6 +411,35 @@ user_bias_pselect_v1_features: list[FeatureRegistryId] = [
341
411
  FeatureRegistryId(entity_type="USER", feature_id="PSELECT#6M", version="v1")
342
412
  ]
343
413
 
414
# Global playlist-category stats: one feature per OS bucket, MOBILE first,
# then TV.
global_playlist_stats_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId(
        entity_type="GLOBAL",
        feature_id=f"PLAYLIST_CATEGORY_STATS#1D#{os_bucket}",
        version="v1",
    )
    for os_bucket in ("MOBILE", "TV")
]

# Per-user playlist-category stats.
user_playlist_stats_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId("USER", "PLAYLIST_CATEGORY_STATS#3M", "v1"),
]

# Global list of channel candidates.
global_channels_v1_features: list[FeatureRegistryId] = [
    FeatureRegistryId("GLOBAL", "CHANNEL_CANDIDATES", "v1"),
]
442
+
344
443
  features_serializer_tuples: list[tuple[list[FeatureRegistryId], Serializer]] = [
345
444
  (stream_pwatched_v0_features, stream_pwatched_serializer_v0),
346
445
  (stream_pwatched_v1_features, stream_pwatched_serializer_v1),
@@ -355,6 +454,18 @@ features_serializer_tuples: list[tuple[list[FeatureRegistryId], Serializer]] = [
355
454
  (user_bias_pwatched_v1_features, user_pwatched_serializer_v1),
356
455
  (user_personalizing_pselect_v1_features, user_personalizing_pselect_serializer_v1),
357
456
  (user_bias_pselect_v1_features, user_pselect_serializer_v1),
457
+ (
458
+ global_playlist_stats_v1_features,
459
+ global_playlist_stats_serializer_v1,
460
+ ),
461
+ (
462
+ user_playlist_stats_v1_features,
463
+ user_playlist_stats_serializer_v1,
464
+ ),
465
+ (
466
+ global_channels_v1_features,
467
+ global_channels_serializer_v1,
468
+ ),
358
469
  ]
359
470
 
360
471
  SerializerRegistry: dict[FeatureRegistryId, Serializer] = {
@@ -0,0 +1 @@
1
+ __version__ = "0.4.5"
@@ -18,12 +18,12 @@ from fastapi.encoders import jsonable_encoder
18
18
  import newrelic.agent
19
19
 
20
20
  from .cache import make_features_cache
21
- from .dynamo import set_all_features, FeatureRetrievalMeta
21
+ from .dynamo import set_all_features, create_channel_candidates, FeatureRetrievalMeta
22
22
  from .model_store import download_and_load_model
23
23
  from .settings import Settings
24
24
  from . import exceptions
25
25
  from ._serializers import SerializerRegistry
26
- from ._kafka import send_to_kafka, initialize_kafka_producer, should_log_user
26
+ from ._kafka import send_to_kafka, send_channels_to_kafka, initialize_kafka_producer, should_log_user
27
27
  from google.protobuf import text_format
28
28
 
29
29
  logging.basicConfig(
@@ -65,6 +65,7 @@ async def load_model(state, cfg: Settings) -> None:
65
65
  )
66
66
  state["stream_features"] = state["model"].get("stream_features", [])
67
67
  state["user_features"] = state["model"].get("user_features", [])
68
+ state["global_features"] = state["model"].get("global_features", [])
68
69
  valid_features = set(
69
70
  (entity_type, feature_id)
70
71
  for entity_type, feature_id, _ in SerializerRegistry.keys()
@@ -332,6 +333,7 @@ def create_app(
332
333
  state=state,
333
334
  model_output=model_output,
334
335
  monitoring_meta=monitoring_meta,
336
+ query_params=query_params,
335
337
  processed_at=processed_at,
336
338
  )
337
339
  return jsonable_encoder(model_output)
@@ -348,3 +350,241 @@ def create_app(
348
350
  }
349
351
 
350
352
  return app
353
+
354
+
355
create_stream_app = create_app


def create_channel_app(
    settings: Optional[Settings] = None,
    *,
    preloaded_model: Optional[Dict[str, Any]] = None,
) -> FastAPI:
    """Build the FastAPI application that scores channel candidates.

    Routes:
        ``GET /health``: 503 until a model is loaded, otherwise model and
        cache info.
        ``POST /score``: the JSON body is the user object. Candidate channels
        are built from global/user features, then scored by the model.
        ``GET /``: liveness message.

    Args:
        settings: Application settings; ``Settings()`` is used when omitted.
        preloaded_model: Optional model dict. When given, ``load_model`` is
            not called at startup.

    Returns:
        The configured ``FastAPI`` instance.
    """
    cfg = settings or Settings()

    global_features_cache = make_features_cache(cfg.global_cache_maxsize)
    user_features_cache = make_features_cache(cfg.user_cache_maxsize)
    aws_session = aiobotocore.session.get_session()
    state: Dict[str, Any] = {
        "model": preloaded_model,
        "session": aws_session,
        "model_name": (
            os.path.basename(cfg.s3_model_path) if cfg.s3_model_path else None
        ),
    }

    @asynccontextmanager
    async def lifespan(app_server: FastAPI):
        # Startup: load the model (unless preloaded), create the optional
        # Kafka producer, and open one persistent DynamoDB client.
        if state["model"] is None:
            await load_model(state, cfg)
        kafka_producer = None
        if cfg.kafka_bootstrap_servers is not None:
            kafka_producer = initialize_kafka_producer(app_config=cfg)
        state["kafka_producer"] = kafka_producer
        async with AsyncExitStack() as stack:
            session = state["session"]
            state["dynamo_client"] = await stack.enter_async_context(
                session.create_client(
                    "dynamodb",
                    config=AioConfig(max_pool_connections=MAX_POOL_CONNECTIONS),
                )
            )
            logger.info("DynamoDB persistent client initialized.")
            yield
        # Shutdown: the exit stack has closed the DynamoDB client; now drain
        # and always close the Kafka producer.
        logger.info("Shutting down: Connection pools closed.")
        logger.info("Shutting down: Flushing Kafka queue.")
        if kafka_producer is not None:
            try:
                await kafka_producer.flush()
            except Exception:
                logger.error(
                    "Unknown exception while flushing kafka queue, shutting down producer.\n%s",
                    traceback.format_exc(),
                )
            finally:
                await kafka_producer.close()

    app = FastAPI(
        title="ML Channel Scorer",
        description="Scores channels using a pre-trained ML model and DynamoDB features.",
        version="1.0.0",
        lifespan=lifespan,
    )

    @app.get("/health", status_code=HTTPStatus.OK)
    async def health():
        model_ok = state["model"] is not None
        if not model_ok:
            raise HTTPException(
                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                detail="ML Model not loaded",
            )
        return {
            "status": "ok",
            "model_loaded": True,
            "global_cache_size": len(global_features_cache),
            "user_cache_size": len(user_features_cache),
            "global_features": state.get("global_features", []),
            "user_features": state.get("user_features", []),
            "model_name": state.get("model_name"),
        }

    @app.post("/score", status_code=HTTPStatus.OK)
    async def score_channels(
        request: Request, response: Response, background_tasks: BackgroundTasks
    ):
        if state["model"] is None:
            raise HTTPException(
                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                detail="ML Model not loaded",
            )

        try:
            data = await request.json()
        except json.JSONDecodeError as e:
            body = await request.body()
            logger.error(
                "Received malformed json. Raw body: %s\n%s",
                body.decode(errors="replace"),
                traceback.format_exc(),
            )
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST, detail="Invalid JSON payload"
            ) from e
        except Exception as e:
            logger.error(
                "Unexpected exception when parsing request.\n %s",
                traceback.format_exc(),
            )
            raise HTTPException(
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Unknown exception"
            ) from e

        # The body is used directly as the user dict (``.get`` is called on it
        # below and during retrieval). Reject other JSON types up front instead
        # of failing later with an unlogged AttributeError / 500.
        if not isinstance(data, dict):
            logger.error(
                "Expected a JSON object as request body, received %s",
                type(data).__name__,
            )
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST,
                detail="Request body must be a JSON object",
            )

        # Repeated query keys become lists; single occurrences stay scalars.
        query_params = {}
        for k in request.query_params.keys():
            values = request.query_params.getlist(k)
            query_params[k] = values[0] if len(values) == 1 else values

        user = data
        # Filled in place by ``create_channel_candidates``.
        channels: List[Dict[str, Any]] = []

        try:
            retrieval_meta = await create_channel_candidates(
                dynamo_client=state["dynamo_client"],
                user=user,
                channels=channels,
                global_features=state.get("global_features", []),
                user_features=state.get("user_features", []),
                global_features_cache=global_features_cache,
                user_features_cache=user_features_cache,
                features_table=cfg.features_table,
                cache_sep=cfg.cache_separator,
            )
        except exceptions.InvalidFeaturesException as e:
            logger.error(
                "The following features are not present in the SerializerRegistry %s",
                e,
            )
            raise HTTPException(
                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                detail=f"Received invalid features from feature store: {e}",
            ) from e

        random_number = random.random()
        userid = user.get("userid", "")

        if random_number < cfg.logs_fraction:
            logger.info(
                "User %s data: %s",
                userid,
                user,
            )
            logger.info(
                "User %s channels: %s",
                userid,
                channels,
            )

        model = state["model"]
        try:
            preprocess_start = time.perf_counter_ns()
            # Shared params are mutated per request. There is no ``await``
            # between this assignment and ``predict``, so concurrent requests
            # on the event loop cannot interleave here.
            model["params"]["query_params"] = query_params
            model_input = model["preprocess"](
                user,
                channels,
                model["params"],
            )
            predict_start = time.perf_counter_ns()
            model_output = model["predict"](model_input, model["params"])
            predict_end = time.perf_counter_ns()
        except Exception as e:
            logger.error("Model prediction failed: \n%s", traceback.format_exc())
            raise HTTPException(
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                detail="Model prediction failed",
            ) from e

        monitoring_meta = {
            "cache_misses": retrieval_meta.cache_misses,
            "user_cache_misses": retrieval_meta.user_cache_misses,
            # FeatureRetrievalMeta has no global field; channel retrieval
            # reports its global-feature misses in ``stream_cache_misses``.
            "global_cache_misses": retrieval_meta.stream_cache_misses,
            "user_cache_size": len(user_features_cache),
            "global_cache_size": len(global_features_cache),
            "retrieval_success": int(retrieval_meta.success),
            "cache_delay_minutes": retrieval_meta.cache_delay_minutes,
            "dynamo_ms": retrieval_meta.dynamo_ms,
            "dynamo_parse_ms": retrieval_meta.parsing_ms,
            "retrieval_ms": retrieval_meta.retrieval_ms,
            "preprocess_ms": (predict_start - preprocess_start) * 1e-6,
            "predict_ms": (predict_end - predict_start) * 1e-6,
            "total_channels": len(model_output),
        }
        newrelic.agent.record_custom_event(
            "ChannelInference",
            monitoring_meta,
        )
        if model_output:
            if random_number < cfg.logs_fraction:
                logger.info(
                    "User %s - model output %s",
                    userid,
                    model_output,
                )
            if should_log_user(userid=userid, kafka_fraction=cfg.kafka_fraction):
                processed_at = datetime.datetime.now(tz=datetime.UTC)
                background_tasks.add_task(
                    send_channels_to_kafka,
                    producer=state.get("kafka_producer"),
                    topic=cfg.kafka_topic,
                    user=user,
                    channels=channels,
                    state=state,
                    model_output=model_output,
                    monitoring_meta=monitoring_meta,
                    query_params=query_params,
                    processed_at=processed_at,
                )
            return jsonable_encoder(model_output)

        raise HTTPException(
            status_code=HTTPStatus.NOT_FOUND, detail="No model output generated"
        )

    @app.get("/", status_code=HTTPStatus.OK)
    async def root():
        return {
            "message": "ML Channel Scoring Service is running.",
            "model_name": state.get("model_name"),
        }

    return app
@@ -1,13 +1,15 @@
1
- from typing import Any, Dict, List, NamedTuple, Literal
1
+ import asyncio
2
+ import datetime
2
3
  import logging
3
4
  import time
4
- import datetime
5
- from boto3.dynamodb.types import TypeDeserializer
5
+ from typing import Any, Dict, List, Literal, NamedTuple
6
+
6
7
  import newrelic.agent
7
- import asyncio
8
- from ._serializers import SerializerRegistry, FeatureRegistryId
9
- from . import exceptions
8
+ from boto3.dynamodb.types import TypeDeserializer
10
9
 
10
+ from . import exceptions
11
+ from ._serializers import FeatureRegistryId, SerializerRegistry
12
+ from .utils import _complete_features_for_channels
11
13
 
12
14
  logger = logging.getLogger(__name__)
13
15
 
@@ -21,7 +23,7 @@ class FloatDeserializer(TypeDeserializer):
21
23
 
22
24
 
23
25
  _deser = FloatDeserializer()
24
- IdType = Literal["STREAM", "USER"]
26
+ IdType = Literal["STREAM", "USER", "GLOBAL"]
25
27
 
26
28
 
27
29
  class FeatureRetrievalMeta(NamedTuple):
@@ -268,9 +270,295 @@ async def set_all_features(
268
270
  )
269
271
 
270
272
 
273
+ _MOBILE_OS = {"ios", "android", "iphone", "galaxy"}
274
+
275
+
276
+ def _get_os_cat(client_os: str) -> str:
277
+ normalized = client_os.lower().replace("debug","") if client_os else ""
278
+ return "MOBILE" if normalized in _MOBILE_OS else "TV"
279
+
280
+
281
@newrelic.agent.function_trace()
async def create_channel_candidates(
    *,
    user: Dict[str, Any],
    channels: List[Dict[str, Any]],
    global_features,
    user_features,
    global_features_cache,
    user_features_cache,
    features_table: str,
    cache_sep: str,
    dynamo_client,
) -> FeatureRetrievalMeta:
    """Resolve global/user features and append channel candidates to ``channels``.

    Global features are looked up under the id ``"GLOBAL"``. User features
    are looked up under ``user["userid"]``. Both are served from their caches
    via ``_check_cache`` when possible. Misses are fetched from
    ``features_table`` with one ``async_batch_get``, deserialized through
    ``SerializerRegistry``, and written back to the caches. Keys that DynamoDB
    did not return are cached as ``None`` (300 s for global, 6 h for user).
    Finally ``_process_channels`` appends the candidates to ``channels`` in
    place.

    Args:
        user: Request user dict. Resolved user features are stored on it.
        channels: Output list, appended to in place.
        global_features / user_features: Feature names to resolve.
        global_features_cache / user_features_cache: Per-entity caches.
        features_table: DynamoDB table name.
        cache_sep: Separator used inside cache keys.
        dynamo_client: aiobotocore DynamoDB client.

    Returns:
        Retrieval metadata. ``stream_cache_misses`` carries the GLOBAL miss
        count, because ``FeatureRetrievalMeta`` has no dedicated field. If the
        batch get fails, ``success`` is False and ``channels`` is left
        untouched.

    Raises:
        exceptions.InvalidFeaturesException: a returned item has no serializer.
        exceptions.DeserializationException: a value failed to deserialize.
        ValueError: a returned item has an unexpected partition-key type.
        KeyError: ``user`` has no ``"userid"`` while user features are
            requested.
    """
    time_start = time.perf_counter_ns()

    os_cat = _get_os_cat(user.get("clientOs", ""))

    global_holder: Dict[str, Any] = {}
    cache_miss: Dict[str, Dict[str, Any]] = {}
    all_feature_keys = [*global_features, *user_features]
    cache_delay_obj: dict[str, float] = {f: 0 for f in all_feature_keys}
    # Naive UTC on purpose: it must match the ``inserted_at`` values stored
    # in the caches below.
    now = datetime.datetime.utcnow()

    for f in global_features:
        cache_miss, cache_delay_obj = _check_cache(
            obj=global_holder,
            id_type="GLOBAL",
            id_key="GLOBAL",
            feature_key=f,
            cache_sep=cache_sep,
            features_cache=global_features_cache,
            cache_miss=cache_miss,
            cache_delay=cache_delay_obj,
            now=now,
        )
    global_cache_misses = len(cache_miss)

    for f in user_features:
        cache_miss, cache_delay_obj = _check_cache(
            obj=user,
            id_type="USER",
            id_key=user["userid"],
            feature_key=f,
            cache_sep=cache_sep,
            features_cache=user_features_cache,
            cache_miss=cache_miss,
            cache_delay=cache_delay_obj,
            now=now,
        )
    # ``cache_miss`` accumulates across both loops.
    user_cache_misses = len(cache_miss) - global_cache_misses
    total_cache_misses = len(cache_miss)

    valid_cache_delays = [v for v in cache_delay_obj.values() if v > 0]
    cache_delay = min(valid_cache_delays) if valid_cache_delays else 0

    if not cache_miss:
        _process_channels(
            channels, global_features, user_features, global_holder, user, os_cat
        )
        return FeatureRetrievalMeta(
            user_cache_misses=0,
            stream_cache_misses=0,
            cache_misses=0,
            retrieval_ms=(time.perf_counter_ns() - time_start) * 1e-6,
            success=True,
            cache_delay_minutes=cache_delay / 60,
            dynamo_ms=0,
            parsing_ms=0,
        )

    logger.info(
        "Channel candidates cache miss for %d items (%d global, %d user)",
        total_cache_misses,
        global_cache_misses,
        user_cache_misses,
    )

    # Cache keys are "<id_type><sep><id_key><sep><feature>". GLOBAL rows use
    # the bare "GLOBAL" partition key; others use "<id_type>#<id_key>".
    keys = []
    for k in cache_miss:
        id_type, id_key, sk = k.split(cache_sep, 2)
        pk = "GLOBAL" if id_type == "GLOBAL" else f"{id_type}#{id_key}"
        keys.append({"pk": {"S": pk}, "sk": {"S": sk}})

    dynamo_start = time.perf_counter_ns()
    try:
        items = await async_batch_get(dynamo_client, features_table, keys)
    except Exception as e:
        # Keep the traceback; the request continues with success=False.
        logger.exception("DynamoDB batch_get failed for channel candidates: %s", e)
        end_time = time.perf_counter_ns()
        return FeatureRetrievalMeta(
            user_cache_misses=user_cache_misses,
            stream_cache_misses=global_cache_misses,
            # Report the real total so it agrees with the two counts above
            # (previously hard-coded to 0).
            cache_misses=total_cache_misses,
            retrieval_ms=(end_time - time_start) * 1e-6,
            success=False,
            cache_delay_minutes=cache_delay / 60,
            dynamo_ms=(end_time - dynamo_start) * 1e-6,
            parsing_ms=0,
        )

    dynamo_end = time.perf_counter_ns()
    updated_keys = set()
    for item in items:
        full_id = item["pk"]["S"]
        if "#" in full_id:
            id_type, id_key = full_id.split("#", 1)
        else:
            # Bare partition key (e.g. "GLOBAL") doubles as its own id key.
            id_type = full_id
            id_key = full_id
        feature_name = item["sk"]["S"]

        if id_type == "GLOBAL":
            cache_to_use = global_features_cache
        elif id_type == "USER":
            cache_to_use = user_features_cache
        else:
            raise ValueError(
                f"Unexpected id type in channel candidates. "
                f"Expected 'GLOBAL' or 'USER', received {id_type}"
            )

        cache_key = _build_cache_key(
            id_type=id_type,
            id_key=id_key,
            feature_key=feature_name,
            cache_sep=cache_sep,
        )
        parsed = parse_dynamo_item(item)
        feature_version = parsed.get("version", "v0")
        feature_id = FeatureRegistryId(
            entity_type=id_type, feature_id=feature_name, version=feature_version
        )
        try:
            serializer = SerializerRegistry[feature_id]
        except KeyError as e:
            raise exceptions.InvalidFeaturesException(
                f"Could not find '{feature_id}' in serializer registry"
            ) from e
        raw_value = parsed.get("value")
        try:
            value = serializer.deserialize(raw_value) if raw_value else None
        except TypeError as e:
            raise exceptions.DeserializationException(
                f"Ran into an error while deserializing {feature_id}. Error: {e}"
            ) from e

        cache_to_use[cache_key] = {
            "value": value,
            "cache_ttl_in_seconds": int(parsed.get("cache_ttl_in_seconds", -1)),
            "inserted_at": datetime.datetime.utcnow(),
        }

        if cache_key in cache_miss:
            # ``cache_miss[cache_key]`` is the holder passed to _check_cache
            # (global_holder or user); store the feature on it.
            cache_miss[cache_key][feature_name] = value
            updated_keys.add(cache_key)

    parsing_end = time.perf_counter_ns()

    # Negatively cache keys DynamoDB did not return, so they are not
    # re-queried on every request. These entries carry ``inserted_at`` like
    # the positive entries above, so readers can treat both shapes uniformly.
    if len(updated_keys) < len(cache_miss):
        missing_keys = set(cache_miss) - updated_keys
        inserted_at = datetime.datetime.utcnow()
        for k in missing_keys:
            id_type = _get_id_type_from_partition_key(k, sep=cache_sep)
            if id_type == "GLOBAL":
                global_features_cache[k] = {
                    "value": None,
                    "cache_ttl_in_seconds": 300,
                    "inserted_at": inserted_at,
                }
            elif id_type == "USER":
                user_features_cache[k] = {
                    "value": None,
                    "cache_ttl_in_seconds": 6 * 3600,
                    "inserted_at": inserted_at,
                }

    _process_channels(
        channels, global_features, user_features, global_holder, user, os_cat
    )

    end_time = time.perf_counter_ns()
    return FeatureRetrievalMeta(
        cache_misses=global_cache_misses + user_cache_misses,
        user_cache_misses=user_cache_misses,
        stream_cache_misses=global_cache_misses,
        retrieval_ms=_perf_counter_ns_delta_in_ms(time_start, end_time),
        success=True,
        cache_delay_minutes=cache_delay / 60,
        dynamo_ms=_perf_counter_ns_delta_in_ms(dynamo_start, dynamo_end),
        parsing_ms=_perf_counter_ns_delta_in_ms(dynamo_end, parsing_end),
    )
475
+
476
+
477
# Group-label "channels": never shown as real channels. Candidates whose
# category_group is one of these are dropped unless the user prefers them.
_FAVORITE_GROUP_LABELS = ("national_favorite", "local_favorite")

# Channel names always offered in addition to those found in features.
_DEFAULT_CHANNELS = (
    "local news",
    "science & technology",
    "business & finance",
    "entertainment news",
    "live",
    "live_es",
    "weather",
    "politics",
    "international",
    "top videos",
    "editor picks",
)


def _process_channels(
    channels: List[Dict[str, Any]],
    global_features: List[str],
    user_features: List[str],
    global_holder: Dict[str, Any],
    user: Dict[str, Any],
    os_cat: str,
) -> None:
    """Append channel candidates to ``channels`` and attach their features.

    Order of appended channels:
      1. Entries of the global ``CHANNEL_CANDIDATES`` feature, minus
         favourite-group channels the user has not listed in
         ``preferredChannels``.
      2. Any other channel name found in the playlist-stats features or in
         ``_DEFAULT_CHANNELS``, in sorted order. These get an empty
         ``category_group`` and the current timestamp as ``start_date``.

    Then ``_complete_features_for_channels`` is called with the OS-specific
    global stats (keys prefixed ``OS_CAT_``) and the user stats (keys
    prefixed ``USER_``).

    Args:
        channels: Output list, appended to in place.
        global_features: Global feature names. Only those ending in
            ``#<os_cat>`` are used as playlist stats.
        user_features: User feature names to read from ``user``.
        global_holder: Resolved global feature values keyed by feature name.
        user: Request user dict holding resolved user feature values.
        os_cat: ``"MOBILE"`` or ``"TV"`` bucket for the requesting client.
    """
    channel_candidates = global_holder.get("CHANNEL_CANDIDATES")

    os_suffix = f"#{os_cat}"
    global_features_os_cat = [f for f in global_features if f.endswith(os_suffix)]

    playlist_stats_global = {
        "OS_CAT_" + k.replace(os_suffix, "").replace("#", "_"): v
        for k, v in global_holder.items()
        if k in global_features_os_cat
    }
    playlist_stats_user = {
        "USER_" + k.replace("#", "_"): v for k, v in user.items() if k in user_features
    }

    # Feature values can be None (an empty value deserializes to None, and
    # missing rows are cached as None). Skip those rather than crashing on
    # ``None.data``.
    all_channels = set()
    for stat in (*playlist_stats_global.values(), *playlist_stats_user.values()):
        if stat is not None:
            all_channels.update(stat.data.keys())

    # Group labels are never real channels; favourite-group channels the user
    # did not prefer are dropped too.
    ignore_channels = set(_FAVORITE_GROUP_LABELS)
    inserted_channel_names = set()
    if channel_candidates is not None:
        preferred = set(user.get("preferredChannels", []) or [])
        ignore_channels.update(
            ch.name
            for ch in channel_candidates.data
            if ch.category_group in _FAVORITE_GROUP_LABELS and ch.name not in preferred
        )

        for ch in channel_candidates.data:
            if ch.name not in ignore_channels:
                channels.append(
                    {
                        "name": ch.name,
                        "category_group": ch.category_group,
                        "start_date": ch.start_date,
                    }
                )
                inserted_channel_names.add(ch.name)

    # Sorted so the output order is deterministic across processes: set
    # iteration order of str depends on per-process hash randomization.
    # One timestamp is shared by all fallback channels of this call.
    now_ts = int(datetime.datetime.now().timestamp())
    for name in sorted(all_channels.union(_DEFAULT_CHANNELS)):
        if name not in ignore_channels and name not in inserted_channel_names:
            channels.append(
                {
                    "name": name,
                    "category_group": "",
                    "start_date": now_ts,
                }
            )

    _complete_features_for_channels(
        channels=channels,
        user_features=playlist_stats_user,
        global_features=playlist_stats_global,
    )
557
+
558
+
271
559
  def _check_cache(
272
560
  obj: dict,
273
- id_type: Literal["STREAM", "USER"],
561
+ id_type: IdType,
274
562
  id_key: str,
275
563
  feature_key: str,
276
564
  cache_sep: str,
@@ -0,0 +1,90 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
4
+ # source: features.proto
5
+ # Protobuf Python Version: 6.33.2
6
+ """Generated protocol buffer code."""
7
+ from google.protobuf import descriptor as _descriptor
8
+ from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
10
+ from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 6,
15
+ 33,
16
+ 2,
17
+ '',
18
+ 'features.proto'
19
+ )
20
+ # @@protoc_insertion_point(imports)
21
+
22
+ _sym_db = _symbol_database.Default()
23
+
24
+
25
+
26
+
27
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x66\x65\x61tures.proto\x12\x1ahaystack_ml_stack.features\"7\n\x12\x45ntryContextCounts\x12\x10\n\x08\x61ttempts\x18\x01 \x01(\x05\x12\x0f\n\x07watched\x18\x02 \x01(\x05\"_\n\x0cSelectCounts\x12\x15\n\rtotal_selects\x18\x01 \x01(\x05\x12!\n\x19total_selects_and_watched\x18\x02 \x01(\x05\x12\x15\n\rtotal_browsed\x18\x03 \x01(\x05\"\xf3\x02\n\x14\x45ntryContextPWatched\x12@\n\x08\x61utoplay\x18\x01 \x01(\x0b\x32..haystack_ml_stack.features.EntryContextCounts\x12\x41\n\tsel_thumb\x18\x02 \x01(\x0b\x32..haystack_ml_stack.features.EntryContextCounts\x12\x43\n\x0b\x63hoose_next\x18\x03 \x01(\x0b\x32..haystack_ml_stack.features.EntryContextCounts\x12@\n\x08\x63h_swtch\x18\x04 \x01(\x0b\x32..haystack_ml_stack.features.EntryContextCounts\x12O\n\x17launch_first_in_session\x18\x05 \x01(\x0b\x32..haystack_ml_stack.features.EntryContextCounts\"\x85\x02\n\x0fPositionPSelect\x12;\n\tfirst_pos\x18\x01 \x01(\x0b\x32(.haystack_ml_stack.features.SelectCounts\x12<\n\nsecond_pos\x18\x02 \x01(\x0b\x32(.haystack_ml_stack.features.SelectCounts\x12;\n\tthird_pos\x18\x03 \x01(\x0b\x32(.haystack_ml_stack.features.SelectCounts\x12:\n\x08rest_pos\x18\x04 \x01(\x0b\x32(.haystack_ml_stack.features.SelectCounts\"\xa9\x01\n\x1f\x42rowsedDebiasedPositionPSelects\x12\x44\n\x0fup_to_4_browsed\x18\x01 \x01(\x0b\x32+.haystack_ml_stack.features.PositionPSelect\x12@\n\x0b\x61ll_browsed\x18\x02 \x01(\x0b\x32+.haystack_ml_stack.features.PositionPSelect\"\xb8\x01\n\x16PlaylistStatsForGlobal\x12\x15\n\rwatched_count\x18\x01 \x01(\x05\x12\x19\n\x11not_watched_count\x18\x02 \x01(\x05\x12\x1b\n\x13\x63\x61pped_watched_secs\x18\x03 \x01(\x02\x12\x1f\n\x17\x63\x61pped_not_watched_secs\x18\x04 \x01(\x02\x12\x14\n\x0cwatched_secs\x18\x05 \x01(\x02\x12\x18\n\x10not_watched_secs\x18\x06 \x01(\x02\"\x88\x01\n\x14PlaylistStatsForUser\x12\x12\n\ntotal_days\x18\x01 \x01(\x05\x12\x12\n\nstart_days\x18\x02 \x01(\x05\x12\x13\n\x0b\x61\x63tive_days\x18\x03 
\x01(\x05\x12\x15\n\rtotal_watched\x18\x04 \x01(\x02\x12\x1c\n\x14\x63\x61pped_total_watched\x18\x05 \x01(\x02\"C\n\x07\x43hannel\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x16\n\x0e\x63\x61tegory_group\x18\x02 \x01(\t\x12\x12\n\nstart_date\x18\x03 \x01(\x05\"k\n\rStreamPSelect\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12I\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32;.haystack_ml_stack.features.BrowsedDebiasedPositionPSelects\"a\n\x0eStreamPWatched\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12>\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.haystack_ml_stack.features.EntryContextPWatched\"_\n\x0cUserPWatched\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12>\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.haystack_ml_stack.features.EntryContextPWatched\"\xda\x01\n\x19UserPersonalizingPWatched\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12M\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32?.haystack_ml_stack.features.UserPersonalizingPWatched.DataEntry\x1a]\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12?\n\x05value\x18\x02 \x01(\x0b\x32\x30.haystack_ml_stack.features.EntryContextPWatched:\x02\x38\x01\"i\n\x0bUserPSelect\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12I\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32;.haystack_ml_stack.features.BrowsedDebiasedPositionPSelects\"\xe3\x01\n\x18UserPersonalizingPSelect\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12L\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32>.haystack_ml_stack.features.UserPersonalizingPSelect.DataEntry\x1ah\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12J\n\x05value\x18\x02 \x01(\x0b\x32;.haystack_ml_stack.features.BrowsedDebiasedPositionPSelects:\x02\x38\x01\"\xa2\x01\n\x16StreamSimilarityScores\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12J\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32<.haystack_ml_stack.features.StreamSimilarityScores.DataEntry\x1a+\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x01:\x02\x38\x01\"\xd0\x01\n\x13GlobalPlaylistStats\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12G\n\x04\x64\x61ta\x18\x02 
\x03(\x0b\x32\x39.haystack_ml_stack.features.GlobalPlaylistStats.DataEntry\x1a_\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32\x32.haystack_ml_stack.features.PlaylistStatsForGlobal:\x02\x38\x01\"\xca\x01\n\x11UserPlaylistStats\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x45\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32\x37.haystack_ml_stack.features.UserPlaylistStats.DataEntry\x1a]\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12?\n\x05value\x18\x02 \x01(\x0b\x32\x30.haystack_ml_stack.features.PlaylistStatsForUser:\x02\x38\x01\"T\n\x0eGlobalChannels\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x31\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32#.haystack_ml_stack.features.Channelb\x06proto3')
28
+
29
+ _globals = globals()
30
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
31
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'features_pb2', _globals)
32
+ if not _descriptor._USE_C_DESCRIPTORS:
33
+ DESCRIPTOR._loaded_options = None
34
+ _globals['_USERPERSONALIZINGPWATCHED_DATAENTRY']._loaded_options = None
35
+ _globals['_USERPERSONALIZINGPWATCHED_DATAENTRY']._serialized_options = b'8\001'
36
+ _globals['_USERPERSONALIZINGPSELECT_DATAENTRY']._loaded_options = None
37
+ _globals['_USERPERSONALIZINGPSELECT_DATAENTRY']._serialized_options = b'8\001'
38
+ _globals['_STREAMSIMILARITYSCORES_DATAENTRY']._loaded_options = None
39
+ _globals['_STREAMSIMILARITYSCORES_DATAENTRY']._serialized_options = b'8\001'
40
+ _globals['_GLOBALPLAYLISTSTATS_DATAENTRY']._loaded_options = None
41
+ _globals['_GLOBALPLAYLISTSTATS_DATAENTRY']._serialized_options = b'8\001'
42
+ _globals['_USERPLAYLISTSTATS_DATAENTRY']._loaded_options = None
43
+ _globals['_USERPLAYLISTSTATS_DATAENTRY']._serialized_options = b'8\001'
44
+ _globals['_ENTRYCONTEXTCOUNTS']._serialized_start=46
45
+ _globals['_ENTRYCONTEXTCOUNTS']._serialized_end=101
46
+ _globals['_SELECTCOUNTS']._serialized_start=103
47
+ _globals['_SELECTCOUNTS']._serialized_end=198
48
+ _globals['_ENTRYCONTEXTPWATCHED']._serialized_start=201
49
+ _globals['_ENTRYCONTEXTPWATCHED']._serialized_end=572
50
+ _globals['_POSITIONPSELECT']._serialized_start=575
51
+ _globals['_POSITIONPSELECT']._serialized_end=836
52
+ _globals['_BROWSEDDEBIASEDPOSITIONPSELECTS']._serialized_start=839
53
+ _globals['_BROWSEDDEBIASEDPOSITIONPSELECTS']._serialized_end=1008
54
+ _globals['_PLAYLISTSTATSFORGLOBAL']._serialized_start=1011
55
+ _globals['_PLAYLISTSTATSFORGLOBAL']._serialized_end=1195
56
+ _globals['_PLAYLISTSTATSFORUSER']._serialized_start=1198
57
+ _globals['_PLAYLISTSTATSFORUSER']._serialized_end=1334
58
+ _globals['_CHANNEL']._serialized_start=1336
59
+ _globals['_CHANNEL']._serialized_end=1403
60
+ _globals['_STREAMPSELECT']._serialized_start=1405
61
+ _globals['_STREAMPSELECT']._serialized_end=1512
62
+ _globals['_STREAMPWATCHED']._serialized_start=1514
63
+ _globals['_STREAMPWATCHED']._serialized_end=1611
64
+ _globals['_USERPWATCHED']._serialized_start=1613
65
+ _globals['_USERPWATCHED']._serialized_end=1708
66
+ _globals['_USERPERSONALIZINGPWATCHED']._serialized_start=1711
67
+ _globals['_USERPERSONALIZINGPWATCHED']._serialized_end=1929
68
+ _globals['_USERPERSONALIZINGPWATCHED_DATAENTRY']._serialized_start=1836
69
+ _globals['_USERPERSONALIZINGPWATCHED_DATAENTRY']._serialized_end=1929
70
+ _globals['_USERPSELECT']._serialized_start=1931
71
+ _globals['_USERPSELECT']._serialized_end=2036
72
+ _globals['_USERPERSONALIZINGPSELECT']._serialized_start=2039
73
+ _globals['_USERPERSONALIZINGPSELECT']._serialized_end=2266
74
+ _globals['_USERPERSONALIZINGPSELECT_DATAENTRY']._serialized_start=2162
75
+ _globals['_USERPERSONALIZINGPSELECT_DATAENTRY']._serialized_end=2266
76
+ _globals['_STREAMSIMILARITYSCORES']._serialized_start=2269
77
+ _globals['_STREAMSIMILARITYSCORES']._serialized_end=2431
78
+ _globals['_STREAMSIMILARITYSCORES_DATAENTRY']._serialized_start=2388
79
+ _globals['_STREAMSIMILARITYSCORES_DATAENTRY']._serialized_end=2431
80
+ _globals['_GLOBALPLAYLISTSTATS']._serialized_start=2434
81
+ _globals['_GLOBALPLAYLISTSTATS']._serialized_end=2642
82
+ _globals['_GLOBALPLAYLISTSTATS_DATAENTRY']._serialized_start=2547
83
+ _globals['_GLOBALPLAYLISTSTATS_DATAENTRY']._serialized_end=2642
84
+ _globals['_USERPLAYLISTSTATS']._serialized_start=2645
85
+ _globals['_USERPLAYLISTSTATS']._serialized_end=2847
86
+ _globals['_USERPLAYLISTSTATS_DATAENTRY']._serialized_start=2754
87
+ _globals['_USERPLAYLISTSTATS_DATAENTRY']._serialized_end=2847
88
+ _globals['_GLOBALCHANNELS']._serialized_start=2849
89
+ _globals['_GLOBALCHANNELS']._serialized_end=2933
90
+ # @@protoc_insertion_point(module_scope)
@@ -1,7 +1,7 @@
1
1
  from google.protobuf.internal import containers as _containers
2
2
  from google.protobuf import descriptor as _descriptor
3
3
  from google.protobuf import message as _message
4
- from collections.abc import Mapping as _Mapping
4
+ from collections.abc import Iterable as _Iterable, Mapping as _Mapping
5
5
  from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
6
6
 
7
7
  DESCRIPTOR: _descriptor.FileDescriptor
@@ -58,6 +58,46 @@ class BrowsedDebiasedPositionPSelects(_message.Message):
58
58
  all_browsed: PositionPSelect
59
59
  def __init__(self, up_to_4_browsed: _Optional[_Union[PositionPSelect, _Mapping]] = ..., all_browsed: _Optional[_Union[PositionPSelect, _Mapping]] = ...) -> None: ...
60
60
 
61
+ class PlaylistStatsForGlobal(_message.Message):
62
+ __slots__ = ()
63
+ WATCHED_COUNT_FIELD_NUMBER: _ClassVar[int]
64
+ NOT_WATCHED_COUNT_FIELD_NUMBER: _ClassVar[int]
65
+ CAPPED_WATCHED_SECS_FIELD_NUMBER: _ClassVar[int]
66
+ CAPPED_NOT_WATCHED_SECS_FIELD_NUMBER: _ClassVar[int]
67
+ WATCHED_SECS_FIELD_NUMBER: _ClassVar[int]
68
+ NOT_WATCHED_SECS_FIELD_NUMBER: _ClassVar[int]
69
+ watched_count: int
70
+ not_watched_count: int
71
+ capped_watched_secs: float
72
+ capped_not_watched_secs: float
73
+ watched_secs: float
74
+ not_watched_secs: float
75
+ def __init__(self, watched_count: _Optional[int] = ..., not_watched_count: _Optional[int] = ..., capped_watched_secs: _Optional[float] = ..., capped_not_watched_secs: _Optional[float] = ..., watched_secs: _Optional[float] = ..., not_watched_secs: _Optional[float] = ...) -> None: ...
76
+
77
+ class PlaylistStatsForUser(_message.Message):
78
+ __slots__ = ()
79
+ TOTAL_DAYS_FIELD_NUMBER: _ClassVar[int]
80
+ START_DAYS_FIELD_NUMBER: _ClassVar[int]
81
+ ACTIVE_DAYS_FIELD_NUMBER: _ClassVar[int]
82
+ TOTAL_WATCHED_FIELD_NUMBER: _ClassVar[int]
83
+ CAPPED_TOTAL_WATCHED_FIELD_NUMBER: _ClassVar[int]
84
+ total_days: int
85
+ start_days: int
86
+ active_days: int
87
+ total_watched: float
88
+ capped_total_watched: float
89
+ def __init__(self, total_days: _Optional[int] = ..., start_days: _Optional[int] = ..., active_days: _Optional[int] = ..., total_watched: _Optional[float] = ..., capped_total_watched: _Optional[float] = ...) -> None: ...
90
+
91
+ class Channel(_message.Message):
92
+ __slots__ = ()
93
+ NAME_FIELD_NUMBER: _ClassVar[int]
94
+ CATEGORY_GROUP_FIELD_NUMBER: _ClassVar[int]
95
+ START_DATE_FIELD_NUMBER: _ClassVar[int]
96
+ name: str
97
+ category_group: str
98
+ start_date: int
99
+ def __init__(self, name: _Optional[str] = ..., category_group: _Optional[str] = ..., start_date: _Optional[int] = ...) -> None: ...
100
+
61
101
  class StreamPSelect(_message.Message):
62
102
  __slots__ = ()
63
103
  VERSION_FIELD_NUMBER: _ClassVar[int]
@@ -134,3 +174,41 @@ class StreamSimilarityScores(_message.Message):
134
174
  version: int
135
175
  data: _containers.ScalarMap[str, float]
136
176
  def __init__(self, version: _Optional[int] = ..., data: _Optional[_Mapping[str, float]] = ...) -> None: ...
177
+
178
+ class GlobalPlaylistStats(_message.Message):
179
+ __slots__ = ()
180
+ class DataEntry(_message.Message):
181
+ __slots__ = ()
182
+ KEY_FIELD_NUMBER: _ClassVar[int]
183
+ VALUE_FIELD_NUMBER: _ClassVar[int]
184
+ key: str
185
+ value: PlaylistStatsForGlobal
186
+ def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[PlaylistStatsForGlobal, _Mapping]] = ...) -> None: ...
187
+ VERSION_FIELD_NUMBER: _ClassVar[int]
188
+ DATA_FIELD_NUMBER: _ClassVar[int]
189
+ version: int
190
+ data: _containers.MessageMap[str, PlaylistStatsForGlobal]
191
+ def __init__(self, version: _Optional[int] = ..., data: _Optional[_Mapping[str, PlaylistStatsForGlobal]] = ...) -> None: ...
192
+
193
+ class UserPlaylistStats(_message.Message):
194
+ __slots__ = ()
195
+ class DataEntry(_message.Message):
196
+ __slots__ = ()
197
+ KEY_FIELD_NUMBER: _ClassVar[int]
198
+ VALUE_FIELD_NUMBER: _ClassVar[int]
199
+ key: str
200
+ value: PlaylistStatsForUser
201
+ def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[PlaylistStatsForUser, _Mapping]] = ...) -> None: ...
202
+ VERSION_FIELD_NUMBER: _ClassVar[int]
203
+ DATA_FIELD_NUMBER: _ClassVar[int]
204
+ version: int
205
+ data: _containers.MessageMap[str, PlaylistStatsForUser]
206
+ def __init__(self, version: _Optional[int] = ..., data: _Optional[_Mapping[str, PlaylistStatsForUser]] = ...) -> None: ...
207
+
208
+ class GlobalChannels(_message.Message):
209
+ __slots__ = ()
210
+ VERSION_FIELD_NUMBER: _ClassVar[int]
211
+ DATA_FIELD_NUMBER: _ClassVar[int]
212
+ version: int
213
+ data: _containers.RepeatedCompositeFieldContainer[Channel]
214
+ def __init__(self, version: _Optional[int] = ..., data: _Optional[_Iterable[_Union[Channel, _Mapping]]] = ...) -> None: ...
@@ -9,7 +9,7 @@ class Settings(BaseSettings):
9
9
  default=None, alias="KAFKA_BOOTSTRAP_SERVERS"
10
10
  )
11
11
  kafka_fraction: float = Field(0.01, alias="KAFKA_FRACTION")
12
- kafka_topic: str = Field(default=None, alias="KAFKA_TOPIC")
12
+ kafka_topic: str | None = Field(default=None, alias="KAFKA_TOPIC")
13
13
 
14
14
  # Model (S3)
15
15
  s3_model_path: str | None = Field(default=None, alias="S3_MODEL_PATH")
@@ -21,6 +21,7 @@ class Settings(BaseSettings):
21
21
  # Cache
22
22
  stream_cache_maxsize: int = 50_000
23
23
  user_cache_maxsize: int = 80_000
24
+ global_cache_maxsize: int = 1_000
24
25
  cache_separator: str = "--"
25
26
 
26
27
  class Config:
@@ -13,6 +13,10 @@ from .generated.v1.features_pb2 import (
13
13
  UserPersonalizingPSelect,
14
14
  UserPSelect,
15
15
  EntryContextPWatched,
16
+ PlaylistStatsForGlobal,
17
+ PlaylistStatsForUser,
18
+ UserPlaylistStats,
19
+ GlobalPlaylistStats,
16
20
  )
17
21
  from ._serializers import SerializerRegistry
18
22
  from . import exceptions
@@ -673,3 +677,54 @@ def _validate_pwatched_entry_context(entry_contexts: list[str]):
673
677
  invalid_contexts = [c for c in entry_contexts if c not in valid_contexts]
674
678
  if invalid_contexts:
675
679
  raise ValueError(f"Invalid entry contexts found: {invalid_contexts}")
680
+
681
+
682
+ def _complete_features_for_channels(
683
+ channels: list[dict],
684
+ user_features: dict[str, UserPlaylistStats],
685
+ global_features: dict[str, GlobalPlaylistStats],
686
+ ) -> None:
687
+
688
+ GLOBAL_FEATURES = [
689
+ "watched_count",
690
+ "not_watched_count",
691
+ "capped_watched_secs",
692
+ "capped_not_watched_secs",
693
+ "watched_secs",
694
+ "not_watched_secs",
695
+ ]
696
+
697
+ USER_FEATURES = [
698
+ "total_days",
699
+ "start_days",
700
+ "active_days",
701
+ "total_watched",
702
+ "capped_total_watched",
703
+ ]
704
+
705
+ for ch in channels:
706
+ name = ch.get("name", "")
707
+ category_group = ch.get("category_group", "")
708
+ for prefix, global_feature in global_features.items():
709
+ global_feature = global_feature or GlobalPlaylistStats()
710
+ ch[prefix] = {}
711
+
712
+ if name in global_feature.data:
713
+ features = global_feature.data.get(name, PlaylistStatsForGlobal())
714
+ for feature in GLOBAL_FEATURES:
715
+ ch[prefix][feature] = getattr(features, feature, 0)
716
+ else:
717
+ # Some global features are at the category group level instead of the channel level
718
+ features = global_feature.data.get(
719
+ category_group, PlaylistStatsForGlobal()
720
+ )
721
+ for feature in GLOBAL_FEATURES:
722
+ ch[prefix][feature] = getattr(features, feature, 0)
723
+
724
+ for prefix, user_feature in user_features.items():
725
+ user_feature = user_feature or UserPlaylistStats()
726
+ ch[prefix] = {}
727
+ for feature in USER_FEATURES:
728
+ ch[prefix][feature] = getattr(
729
+ user_feature.data.get(name, PlaylistStatsForUser()), feature, 0
730
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: haystack-ml-stack
3
- Version: 0.4.3
3
+ Version: 0.4.5
4
4
  Summary: Functions related to Haystack ML
5
5
  Author-email: Oscar Vega <oscar@haystack.tv>
6
6
  License: MIT
@@ -1 +0,0 @@
1
- __version__ = "0.4.3"
@@ -1,70 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Generated by the protocol buffer compiler. DO NOT EDIT!
3
- # NO CHECKED-IN PROTOBUF GENCODE
4
- # source: features.proto
5
- # Protobuf Python Version: 6.33.2
6
- """Generated protocol buffer code."""
7
- from google.protobuf import descriptor as _descriptor
8
- from google.protobuf import descriptor_pool as _descriptor_pool
9
- from google.protobuf import runtime_version as _runtime_version
10
- from google.protobuf import symbol_database as _symbol_database
11
- from google.protobuf.internal import builder as _builder
12
- _runtime_version.ValidateProtobufRuntimeVersion(
13
- _runtime_version.Domain.PUBLIC,
14
- 6,
15
- 33,
16
- 2,
17
- '',
18
- 'features.proto'
19
- )
20
- # @@protoc_insertion_point(imports)
21
-
22
- _sym_db = _symbol_database.Default()
23
-
24
-
25
-
26
-
27
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x66\x65\x61tures.proto\x12\x1ahaystack_ml_stack.features\"7\n\x12\x45ntryContextCounts\x12\x10\n\x08\x61ttempts\x18\x01 \x01(\x05\x12\x0f\n\x07watched\x18\x02 \x01(\x05\"_\n\x0cSelectCounts\x12\x15\n\rtotal_selects\x18\x01 \x01(\x05\x12!\n\x19total_selects_and_watched\x18\x02 \x01(\x05\x12\x15\n\rtotal_browsed\x18\x03 \x01(\x05\"\xf3\x02\n\x14\x45ntryContextPWatched\x12@\n\x08\x61utoplay\x18\x01 \x01(\x0b\x32..haystack_ml_stack.features.EntryContextCounts\x12\x41\n\tsel_thumb\x18\x02 \x01(\x0b\x32..haystack_ml_stack.features.EntryContextCounts\x12\x43\n\x0b\x63hoose_next\x18\x03 \x01(\x0b\x32..haystack_ml_stack.features.EntryContextCounts\x12@\n\x08\x63h_swtch\x18\x04 \x01(\x0b\x32..haystack_ml_stack.features.EntryContextCounts\x12O\n\x17launch_first_in_session\x18\x05 \x01(\x0b\x32..haystack_ml_stack.features.EntryContextCounts\"\x85\x02\n\x0fPositionPSelect\x12;\n\tfirst_pos\x18\x01 \x01(\x0b\x32(.haystack_ml_stack.features.SelectCounts\x12<\n\nsecond_pos\x18\x02 \x01(\x0b\x32(.haystack_ml_stack.features.SelectCounts\x12;\n\tthird_pos\x18\x03 \x01(\x0b\x32(.haystack_ml_stack.features.SelectCounts\x12:\n\x08rest_pos\x18\x04 \x01(\x0b\x32(.haystack_ml_stack.features.SelectCounts\"\xa9\x01\n\x1f\x42rowsedDebiasedPositionPSelects\x12\x44\n\x0fup_to_4_browsed\x18\x01 \x01(\x0b\x32+.haystack_ml_stack.features.PositionPSelect\x12@\n\x0b\x61ll_browsed\x18\x02 \x01(\x0b\x32+.haystack_ml_stack.features.PositionPSelect\"k\n\rStreamPSelect\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12I\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32;.haystack_ml_stack.features.BrowsedDebiasedPositionPSelects\"a\n\x0eStreamPWatched\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12>\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.haystack_ml_stack.features.EntryContextPWatched\"_\n\x0cUserPWatched\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12>\n\x04\x64\x61ta\x18\x02 
\x01(\x0b\x32\x30.haystack_ml_stack.features.EntryContextPWatched\"\xda\x01\n\x19UserPersonalizingPWatched\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12M\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32?.haystack_ml_stack.features.UserPersonalizingPWatched.DataEntry\x1a]\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12?\n\x05value\x18\x02 \x01(\x0b\x32\x30.haystack_ml_stack.features.EntryContextPWatched:\x02\x38\x01\"i\n\x0bUserPSelect\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12I\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32;.haystack_ml_stack.features.BrowsedDebiasedPositionPSelects\"\xe3\x01\n\x18UserPersonalizingPSelect\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12L\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32>.haystack_ml_stack.features.UserPersonalizingPSelect.DataEntry\x1ah\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12J\n\x05value\x18\x02 \x01(\x0b\x32;.haystack_ml_stack.features.BrowsedDebiasedPositionPSelects:\x02\x38\x01\"\xa2\x01\n\x16StreamSimilarityScores\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12J\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32<.haystack_ml_stack.features.StreamSimilarityScores.DataEntry\x1a+\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x01:\x02\x38\x01\x62\x06proto3')
28
-
29
- _globals = globals()
30
- _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
31
- _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'features_pb2', _globals)
32
- if not _descriptor._USE_C_DESCRIPTORS:
33
- DESCRIPTOR._loaded_options = None
34
- _globals['_USERPERSONALIZINGPWATCHED_DATAENTRY']._loaded_options = None
35
- _globals['_USERPERSONALIZINGPWATCHED_DATAENTRY']._serialized_options = b'8\001'
36
- _globals['_USERPERSONALIZINGPSELECT_DATAENTRY']._loaded_options = None
37
- _globals['_USERPERSONALIZINGPSELECT_DATAENTRY']._serialized_options = b'8\001'
38
- _globals['_STREAMSIMILARITYSCORES_DATAENTRY']._loaded_options = None
39
- _globals['_STREAMSIMILARITYSCORES_DATAENTRY']._serialized_options = b'8\001'
40
- _globals['_ENTRYCONTEXTCOUNTS']._serialized_start=46
41
- _globals['_ENTRYCONTEXTCOUNTS']._serialized_end=101
42
- _globals['_SELECTCOUNTS']._serialized_start=103
43
- _globals['_SELECTCOUNTS']._serialized_end=198
44
- _globals['_ENTRYCONTEXTPWATCHED']._serialized_start=201
45
- _globals['_ENTRYCONTEXTPWATCHED']._serialized_end=572
46
- _globals['_POSITIONPSELECT']._serialized_start=575
47
- _globals['_POSITIONPSELECT']._serialized_end=836
48
- _globals['_BROWSEDDEBIASEDPOSITIONPSELECTS']._serialized_start=839
49
- _globals['_BROWSEDDEBIASEDPOSITIONPSELECTS']._serialized_end=1008
50
- _globals['_STREAMPSELECT']._serialized_start=1010
51
- _globals['_STREAMPSELECT']._serialized_end=1117
52
- _globals['_STREAMPWATCHED']._serialized_start=1119
53
- _globals['_STREAMPWATCHED']._serialized_end=1216
54
- _globals['_USERPWATCHED']._serialized_start=1218
55
- _globals['_USERPWATCHED']._serialized_end=1313
56
- _globals['_USERPERSONALIZINGPWATCHED']._serialized_start=1316
57
- _globals['_USERPERSONALIZINGPWATCHED']._serialized_end=1534
58
- _globals['_USERPERSONALIZINGPWATCHED_DATAENTRY']._serialized_start=1441
59
- _globals['_USERPERSONALIZINGPWATCHED_DATAENTRY']._serialized_end=1534
60
- _globals['_USERPSELECT']._serialized_start=1536
61
- _globals['_USERPSELECT']._serialized_end=1641
62
- _globals['_USERPERSONALIZINGPSELECT']._serialized_start=1644
63
- _globals['_USERPERSONALIZINGPSELECT']._serialized_end=1871
64
- _globals['_USERPERSONALIZINGPSELECT_DATAENTRY']._serialized_start=1767
65
- _globals['_USERPERSONALIZINGPSELECT_DATAENTRY']._serialized_end=1871
66
- _globals['_STREAMSIMILARITYSCORES']._serialized_start=1874
67
- _globals['_STREAMSIMILARITYSCORES']._serialized_end=2036
68
- _globals['_STREAMSIMILARITYSCORES_DATAENTRY']._serialized_start=1993
69
- _globals['_STREAMSIMILARITYSCORES_DATAENTRY']._serialized_end=2036
70
- # @@protoc_insertion_point(module_scope)