haystack-ml-stack 0.2.3__tar.gz → 0.2.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17)
  1. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/PKG-INFO +1 -1
  2. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/pyproject.toml +1 -1
  3. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/src/haystack_ml_stack/__init__.py +1 -1
  4. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/src/haystack_ml_stack/app.py +66 -28
  5. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/src/haystack_ml_stack/cache.py +2 -2
  6. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/src/haystack_ml_stack/dynamo.py +79 -75
  7. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/src/haystack_ml_stack/utils.py +125 -85
  8. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/src/haystack_ml_stack.egg-info/PKG-INFO +1 -1
  9. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/src/haystack_ml_stack.egg-info/SOURCES.txt +2 -1
  10. haystack_ml_stack-0.2.5/tests/test_utils.py +76 -0
  11. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/README.md +0 -0
  12. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/setup.cfg +0 -0
  13. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/src/haystack_ml_stack/model_store.py +0 -0
  14. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/src/haystack_ml_stack/settings.py +0 -0
  15. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/src/haystack_ml_stack.egg-info/dependency_links.txt +0 -0
  16. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/src/haystack_ml_stack.egg-info/requires.txt +0 -0
  17. {haystack_ml_stack-0.2.3 → haystack_ml_stack-0.2.5}/src/haystack_ml_stack.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: haystack-ml-stack
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: Functions related to Haystack ML
5
5
  Author-email: Oscar Vega <oscar@haystack.tv>
6
6
  License: MIT
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
5
5
 
6
6
  [project]
7
7
  name = "haystack-ml-stack"
8
- version = "0.2.3"
8
+ version = "0.2.5"
9
9
  description = "Functions related to Haystack ML"
10
10
  readme = "README.md"
11
11
  authors = [{ name = "Oscar Vega", email = "oscar@haystack.tv" }]
@@ -1,4 +1,4 @@
1
1
  from .app import create_app
2
2
 
3
3
  __all__ = ["create_app"]
4
- __version__ = "0.2.3"
4
+ __version__ = "0.2.5"
@@ -5,10 +5,13 @@ import sys
5
5
  from http import HTTPStatus
6
6
  from typing import Any, Dict, List, Optional
7
7
  import time
8
+ from contextlib import asynccontextmanager, AsyncExitStack
8
9
 
9
10
  import aiobotocore.session
11
+ from aiobotocore.config import AioConfig
10
12
  from fastapi import FastAPI, HTTPException, Request, Response
11
13
  from fastapi.encoders import jsonable_encoder
14
+ import newrelic.agent
12
15
 
13
16
 
14
17
  from .cache import make_features_cache
@@ -24,8 +27,7 @@ logging.basicConfig(
24
27
  )
25
28
 
26
29
  logger = logging.getLogger(__name__)
27
-
28
- import newrelic.agent
30
+ MAX_POOL_CONNECTIONS = int(os.environ.get("MAX_POOL_CONNECTIONS", 50))
29
31
 
30
32
 
31
33
  def create_app(
@@ -39,12 +41,6 @@ def create_app(
39
41
  """
40
42
  cfg = settings or Settings()
41
43
 
42
- app = FastAPI(
43
- title="ML Stream Scorer",
44
- description="Scores video streams using a pre-trained ML model and DynamoDB features.",
45
- version="1.0.0",
46
- )
47
-
48
44
  # Mutable state: cache + model
49
45
  features_cache = make_features_cache(cfg.cache_maxsize)
50
46
  state: Dict[str, Any] = {
@@ -55,24 +51,59 @@ def create_app(
55
51
  ),
56
52
  }
57
53
 
58
- @app.on_event("startup")
59
- async def _startup() -> None:
60
- if state["model"] is not None:
61
- logger.info("Using preloaded model.")
62
- return
63
-
64
- if not cfg.s3_model_path:
65
- logger.critical("S3_MODEL_PATH not set; service will be unhealthy.")
66
- return
67
-
68
- try:
69
- state["model"] = await download_and_load_model(
70
- cfg.s3_model_path, aio_session=state["session"]
54
+ @asynccontextmanager
55
+ async def lifespan(app_server: FastAPI):
56
+ """
57
+ Handles startup and shutdown logic.
58
+ Everything before 'yield' runs on startup.
59
+ Everything after 'yield' runs on shutdown.
60
+ """
61
+ async with AsyncExitStack() as stack:
62
+ # 1. Initialize DynamoDB Client (Persistent Pool)
63
+ session = state["session"]
64
+ state["dynamo_client"] = await stack.enter_async_context(
65
+ session.create_client(
66
+ "dynamodb",
67
+ # Ensure the pool is large enough for ML concurrency
68
+ config=AioConfig(max_pool_connections=MAX_POOL_CONNECTIONS),
69
+ )
71
70
  )
72
- state["stream_features"] = state["model"].get("stream_features", [])
73
- logger.info("Model loaded on startup.")
74
- except Exception as e:
75
- logger.critical("Failed to load model: %s", e)
71
+ logger.info("DynamoDB persistent client initialized.")
72
+
73
+ # 2. Load ML Model
74
+ if state["model"] is None:
75
+ if not cfg.s3_model_path:
76
+ logger.critical("S3_MODEL_PATH not set; service will be unhealthy.")
77
+ else:
78
+ try:
79
+ # Pass the persistent session/client if needed
80
+ state["model"] = await download_and_load_model(
81
+ cfg.s3_model_path, aio_session=state["session"]
82
+ )
83
+ state["stream_features"] = state["model"].get(
84
+ "stream_features", []
85
+ )
86
+ state["user_features"] = state["model"].get("user_features", [])
87
+
88
+ newrelic.agent.add_custom_attribute(
89
+ "total_stream_features", len(state["stream_features"])
90
+ )
91
+ logger.info("Model loaded successfully.")
92
+ except Exception as e:
93
+ logger.critical("Failed to load model: %s", e)
94
+
95
+ yield
96
+
97
+ # 3. Shutdown Logic
98
+ # The AsyncExitStack automatically closes the DynamoDB client pool here
99
+ logger.info("Shutting down: Connection pools closed.")
100
+
101
+ app = FastAPI(
102
+ title="ML Stream Scorer",
103
+ description="Scores video streams using a pre-trained ML model and DynamoDB features.",
104
+ version="1.0.0",
105
+ lifespan=lifespan,
106
+ )
76
107
 
77
108
  @app.get("/health", status_code=HTTPStatus.OK)
78
109
  async def health():
@@ -121,11 +152,16 @@ def create_app(
121
152
  model = state["model"]
122
153
  stream_features = model.get("stream_features", []) or []
123
154
  retrieval_meta = FeatureRetrievalMeta(
124
- cache_misses=0, retrieval_ms=0, success=True, cache_delay_minutes=0
155
+ cache_misses=0,
156
+ retrieval_ms=0,
157
+ success=True,
158
+ cache_delay_minutes=0,
159
+ dynamo_ms=0,
160
+ parsing_ms=0,
125
161
  )
126
162
  if stream_features:
127
163
  retrieval_meta = await set_stream_features(
128
- aio_session=state["session"],
164
+ dynamo_client=state["dynamo_client"],
129
165
  streams=streams,
130
166
  stream_features=stream_features,
131
167
  features_cache=features_cache,
@@ -166,10 +202,12 @@ def create_app(
166
202
  "cache_misses": retrieval_meta.cache_misses,
167
203
  "retrieval_success": int(retrieval_meta.success),
168
204
  "cache_delay_minutes": retrieval_meta.cache_delay_minutes,
205
+ "dynamo_ms": retrieval_meta.dynamo_ms,
206
+ "dynamo_parse_ms": retrieval_meta.parsing_ms,
169
207
  "retrieval_ms": retrieval_meta.retrieval_ms,
170
208
  "preprocess_ms": (predict_start - preprocess_start) * 1e-6,
171
209
  "predict_ms": (predict_end - predict_start) * 1e-6,
172
- "total_scores": len(model_output),
210
+ "total_streams": len(model_output),
173
211
  },
174
212
  )
175
213
  if model_output:
@@ -5,14 +5,14 @@ from cachetools import TLRUCache
5
5
 
6
6
  def _ttu(_, value: Any, now: float) -> float:
7
7
  """Time-To-Use policy: allow per-item TTL via 'cache_ttl_in_seconds' or fallback."""
8
- ONE_YEAR = 365 * 24 * 60 * 60
8
+ ONE_WEEK = 7 * 24 * 60 * 60
9
9
  try:
10
10
  ttl = int(value.get("cache_ttl_in_seconds", -1))
11
11
  if ttl > 0:
12
12
  return now + ttl
13
13
  except Exception:
14
14
  pass
15
- return now + ONE_YEAR
15
+ return now + ONE_WEEK
16
16
 
17
17
 
18
18
  def make_features_cache(maxsize: int) -> TLRUCache:
@@ -2,19 +2,29 @@ from typing import Any, Dict, List, NamedTuple
2
2
  import logging
3
3
  import time
4
4
  import datetime
5
-
6
- import aiobotocore.session
5
+ from boto3.dynamodb.types import TypeDeserializer
7
6
  import newrelic.agent
7
+ import asyncio
8
8
 
9
9
 
10
10
  logger = logging.getLogger(__name__)
11
11
 
12
12
 
13
+ class FloatDeserializer(TypeDeserializer):
14
+ def _deserialize_n(self, value):
15
+ return float(value)
16
+
17
+
18
+ _deser = FloatDeserializer()
19
+
20
+
13
21
  class FeatureRetrievalMeta(NamedTuple):
14
22
  cache_misses: int
15
23
  retrieval_ms: float
16
24
  success: bool
17
25
  cache_delay_minutes: float
26
+ dynamo_ms: float
27
+ parsing_ms: float
18
28
 
19
29
 
20
30
  @newrelic.agent.function_trace()
@@ -25,68 +35,54 @@ async def async_batch_get(
25
35
  Asynchronous batch_get_item with chunking for requests > 100 keys
26
36
  and handling for unprocessed keys.
27
37
  """
28
- all_items: List[Dict[str, Any]] = []
29
38
  # DynamoDB's BatchGetItem has a 100-item limit per request.
30
39
  CHUNK_SIZE = 100
31
40
 
32
- # Split the keys into chunks of 100
33
- for i in range(0, len(keys), CHUNK_SIZE):
34
- chunk_keys = keys[i : i + CHUNK_SIZE]
35
- to_fetch = {table_name: {"Keys": chunk_keys}}
36
-
37
- # Inner loop to handle unprocessed keys for the current chunk
38
- # Max retries of 3
39
- retries = 3
40
- while to_fetch and retries > 0:
41
- retries -= 1
42
- try:
43
- resp = await dynamo_client.batch_get_item(RequestItems=to_fetch)
44
-
45
- if "Responses" in resp and table_name in resp["Responses"]:
46
- all_items.extend(resp["Responses"][table_name])
47
-
48
- unprocessed = resp.get("UnprocessedKeys", {})
49
- # If there are unprocessed keys, set them to be fetched in the next iteration
50
- if unprocessed and unprocessed.get(table_name):
51
- logger.warning(
52
- "Retrying %d unprocessed keys.",
53
- len(unprocessed[table_name]["Keys"]),
54
- )
55
- to_fetch = unprocessed
56
- else:
57
- # All keys in the chunk were processed, exit the inner loop
58
- to_fetch = {}
59
-
60
- except Exception as e:
61
- logger.error("Error during batch_get_item for a chunk: %s", e)
62
- # Stop trying to process this chunk on error and move to the next
41
+ if len(keys) <= CHUNK_SIZE:
42
+ all_items = await _fetch_chunk(dynamo_client, table_name, keys)
43
+ else:
44
+ chunks = [keys[i : i + CHUNK_SIZE] for i in range(0, len(keys), CHUNK_SIZE)]
45
+ tasks = [_fetch_chunk(dynamo_client, table_name, chunk) for chunk in chunks]
46
+ results = await asyncio.gather(*tasks)
47
+ all_items = [item for batch in results for item in batch]
48
+ return all_items
49
+
50
+
51
+ async def _fetch_chunk(dynamo_client, table_name: str, chunk_keys):
52
+ """Fetch a single chunk of up to 100 keys with retry handling."""
53
+ to_fetch = {table_name: {"Keys": chunk_keys}}
54
+ retries = 3
55
+ items = []
56
+
57
+ while to_fetch and retries > 0:
58
+ retries -= 1
59
+ try:
60
+ resp = await dynamo_client.batch_get_item(RequestItems=to_fetch)
61
+
62
+ # Collect retrieved items
63
+ if "Responses" in resp and table_name in resp["Responses"]:
64
+ items.extend(resp["Responses"][table_name])
65
+
66
+ # Check for unprocessed keys
67
+ unprocessed = resp.get("UnprocessedKeys", {})
68
+ if unprocessed and unprocessed.get(table_name):
69
+ unp = unprocessed[table_name]["Keys"]
70
+ logger.warning("Retrying %d unprocessed keys.", len(unp))
71
+ to_fetch = {table_name: {"Keys": unp}}
72
+ else:
63
73
  to_fetch = {}
64
74
 
65
- return all_items
75
+ except Exception as e:
76
+ logger.error("Error in batch_get_item chunk: %s", e)
77
+ break
78
+
79
+ return items
66
80
 
67
81
 
68
- @newrelic.agent.function_trace()
69
82
  def parse_dynamo_item(item: Dict[str, Any]) -> Dict[str, Any]:
70
83
  """Parse a DynamoDB attribute map (low-level) to Python types."""
71
- out: Dict[str, Any] = {}
72
- for k, v in item.items():
73
- if "N" in v:
74
- out[k] = float(v["N"])
75
- elif "S" in v:
76
- out[k] = v["S"]
77
- elif "SS" in v:
78
- out[k] = v["SS"]
79
- elif "NS" in v:
80
- out[k] = [float(n) for n in v["NS"]]
81
- elif "BOOL" in v:
82
- out[k] = v["BOOL"]
83
- elif "NULL" in v:
84
- out[k] = None
85
- elif "L" in v:
86
- out[k] = [parse_dynamo_item({"value": i})["value"] for i in v["L"]]
87
- elif "M" in v:
88
- out[k] = parse_dynamo_item(v["M"])
89
- return out
84
+ # out: Dict[str, Any] = {}
85
+ return {k: _deser.deserialize(v) for k, v in item.items()}
90
86
 
91
87
 
92
88
  @newrelic.agent.function_trace()
@@ -98,7 +94,7 @@ async def set_stream_features(
98
94
  features_table: str,
99
95
  stream_pk_prefix: str,
100
96
  cache_sep: str,
101
- aio_session: aiobotocore.session.Session | None = None,
97
+ dynamo_client,
102
98
  ) -> FeatureRetrievalMeta:
103
99
  time_start = time.perf_counter_ns()
104
100
  """Fetch missing features for streams from DynamoDB and fill them into streams."""
@@ -108,6 +104,8 @@ async def set_stream_features(
108
104
  retrieval_ms=(time.perf_counter_ns() - time_start) * 1e-6,
109
105
  success=True,
110
106
  cache_delay_minutes=0,
107
+ dynamo_ms=0,
108
+ parsing_ms=0,
111
109
  )
112
110
 
113
111
  cache_miss: Dict[str, Dict[str, Any]] = {}
@@ -122,7 +120,8 @@ async def set_stream_features(
122
120
  if cached["value"] is not None:
123
121
  s[f] = cached["value"]
124
122
  cache_delay_obj[f] = max(
125
- cache_delay_obj[f], (now - cached["updated_at"]).total_seconds()
123
+ cache_delay_obj[f],
124
+ (now - cached["inserted_at"]).total_seconds(),
126
125
  )
127
126
  else:
128
127
  cache_miss[key] = s
@@ -135,6 +134,8 @@ async def set_stream_features(
135
134
  retrieval_ms=(time.perf_counter_ns() - time_start) * 1e-6,
136
135
  success=True,
137
136
  cache_delay_minutes=cache_delay / 60,
137
+ dynamo_ms=0,
138
+ parsing_ms=0,
138
139
  )
139
140
  cache_misses = len(cache_miss)
140
141
  logger.info("Cache miss for %d items", cache_misses)
@@ -146,19 +147,21 @@ async def set_stream_features(
146
147
  pk = f"{stream_pk_prefix}{stream_url}"
147
148
  keys.append({"pk": {"S": pk}, "sk": {"S": sk}})
148
149
 
149
- session = aio_session or aiobotocore.session.get_session()
150
- async with session.create_client("dynamodb") as dynamodb:
151
- try:
152
- items = await async_batch_get(dynamodb, features_table, keys)
153
- except Exception as e:
154
- logger.error("DynamoDB batch_get failed: %s", e)
155
- return FeatureRetrievalMeta(
156
- cache_misses=cache_misses,
157
- retrieval_ms=(time.perf_counter_ns() - time_start) * 1e-6,
158
- success=False,
159
- cache_delay_minutes=cache_delay / 60,
160
- )
161
-
150
+ dynamo_start = time.perf_counter_ns()
151
+ try:
152
+ items = await async_batch_get(dynamo_client, features_table, keys)
153
+ except Exception as e:
154
+ logger.error("DynamoDB batch_get failed: %s", e)
155
+ end_time = time.perf_counter_ns()
156
+ return FeatureRetrievalMeta(
157
+ cache_misses=cache_misses,
158
+ retrieval_ms=(end_time - time_start) * 1e-6,
159
+ success=False,
160
+ cache_delay_minutes=cache_delay / 60,
161
+ dynamo_ms=(end_time - dynamo_start) * 1e-6,
162
+ parsing_ms=0,
163
+ )
164
+ dynamo_end = time.perf_counter_ns()
162
165
  updated_keys = set()
163
166
  for item in items:
164
167
  stream_url = item["pk"]["S"].removeprefix(stream_pk_prefix)
@@ -169,22 +172,23 @@ async def set_stream_features(
169
172
  features_cache[cache_key] = {
170
173
  "value": parsed.get("value"),
171
174
  "cache_ttl_in_seconds": int(parsed.get("cache_ttl_in_seconds", -1)),
172
- "updated_at": datetime.datetime.fromisoformat(
173
- parsed.get("updated_at")
174
- ).replace(tzinfo=None),
175
+ "inserted_at": datetime.datetime.utcnow(),
175
176
  }
176
177
  if cache_key in cache_miss:
177
178
  cache_miss[cache_key][feature_name] = parsed.get("value")
178
179
  updated_keys.add(cache_key)
179
-
180
+ parsing_end = time.perf_counter_ns()
180
181
  # Save keys that were not found in DynamoDB with None value
181
182
  if len(updated_keys) < len(cache_miss):
182
183
  missing_keys = set(cache_miss.keys()) - updated_keys
183
184
  for k in missing_keys:
184
185
  features_cache[k] = {"value": None, "cache_ttl_in_seconds": 300}
186
+ end_time = time.perf_counter_ns()
185
187
  return FeatureRetrievalMeta(
186
188
  cache_misses=cache_misses,
187
- retrieval_ms=(time.perf_counter_ns() - time_start) * 1e-6,
189
+ retrieval_ms=(end_time - time_start) * 1e-6,
188
190
  success=True,
189
191
  cache_delay_minutes=cache_delay / 60,
192
+ dynamo_ms=(dynamo_end - dynamo_start) * 1e-6,
193
+ parsing_ms=(parsing_end - dynamo_end) * 1e-6,
190
194
  )
@@ -4,8 +4,13 @@ import typing as _t
4
4
 
5
5
 
6
6
  def stream_favorites_cleanup(
7
- stream, user_favorite_tags: list[str], user_favorite_authors: list[str]
7
+ stream,
8
+ user_favorite_tags: list[str],
9
+ user_favorite_authors: list[str],
10
+ out: dict = None,
8
11
  ) -> dict:
12
+ if out is None:
13
+ out = {}
9
14
  stream_tags = stream.get("haystackTags", [])
10
15
  is_favorite_tag = (
11
16
  any(stream_tag in user_favorite_tags for stream_tag in stream_tags)
@@ -17,15 +22,15 @@ def stream_favorites_cleanup(
17
22
  if user_favorite_authors is not None
18
23
  else False
19
24
  )
20
- return {
21
- "IS_FAVORITE_TAG": is_favorite_tag,
22
- "IS_FAVORITE_AUTHOR": is_favorite_author,
23
- }
25
+ out["IS_FAVORITE_TAG"] = is_favorite_tag
26
+ out["IS_FAVORITE_AUTHOR"] = is_favorite_author
27
+ return out
24
28
 
25
29
 
26
30
  def browsed_count_cleanups(
27
31
  stream,
28
32
  position_debiasing: _t.Literal["4_browsed", "all_browsed"] = "4_browsed",
33
+ out: dict = None,
29
34
  ) -> dict:
30
35
  position_alias_mapping = {
31
36
  "0": "1ST_POS",
@@ -43,7 +48,8 @@ def browsed_count_cleanups(
43
48
  total_selects = 0
44
49
  total_browsed = 0
45
50
  total_selects_and_watched = 0
46
- feats = {}
51
+ if out is None:
52
+ out = {}
47
53
  for position in position_alias_mapping.keys():
48
54
  pos_counts = browsed_count_obj.get(position, {})
49
55
  total_browsed += pos_counts.get("total_browsed", 0)
@@ -55,16 +61,17 @@ def browsed_count_cleanups(
55
61
  suffix = ""
56
62
  else:
57
63
  raise ValueError("Should not be here.")
58
- feats[f"STREAM_24H_TOTAL_BROWSED{suffix}"] = total_browsed
59
- feats[f"STREAM_24H_TOTAL_SELECTS{suffix}"] = total_selects
60
- feats[f"STREAM_24H_TOTAL_SELECTS_AND_WATCHED{suffix}"] = total_selects_and_watched
61
- return feats
64
+ out[f"STREAM_24H_TOTAL_BROWSED{suffix}"] = total_browsed
65
+ out[f"STREAM_24H_TOTAL_SELECTS{suffix}"] = total_selects
66
+ out[f"STREAM_24H_TOTAL_SELECTS_AND_WATCHED{suffix}"] = total_selects_and_watched
67
+ return out
62
68
 
63
69
 
64
70
  def device_split_browsed_count_cleanups(
65
71
  stream,
66
72
  device_type: _t.Literal["TV", "MOBILE"],
67
73
  position_debiasing: _t.Literal["4_browsed", "all_browsed"] = "4_browsed",
74
+ out: dict = None,
68
75
  ) -> dict:
69
76
  position_alias_mapping = {
70
77
  "0": "1ST_POS",
@@ -87,21 +94,24 @@ def device_split_browsed_count_cleanups(
87
94
  total_selects = 0
88
95
  total_browsed = 0
89
96
  total_selects_and_watched = 0
90
- feats = {}
97
+ if out is None:
98
+ out = {}
91
99
  for position, alias in position_alias_mapping.items():
92
100
  pos_counts = browsed_count_obj.get(position, {})
93
101
  total_browsed = pos_counts.get("total_browsed", 0)
94
102
  total_selects = pos_counts.get("total_selects", 0)
95
103
  total_selects_and_watched = pos_counts.get("total_selects_and_watched", 0)
96
- feats[f"STREAM_{alias}_{device_type}_24H_TOTAL_BROWSED{suffix}"] = total_browsed
97
- feats[f"STREAM_{alias}_{device_type}_24H_TOTAL_SELECTS{suffix}"] = total_selects
98
- feats[f"STREAM_{alias}_{device_type}_24H_TOTAL_SELECTS_AND_WATCHED{suffix}"] = (
104
+ out[f"STREAM_{alias}_{device_type}_24H_TOTAL_BROWSED{suffix}"] = total_browsed
105
+ out[f"STREAM_{alias}_{device_type}_24H_TOTAL_SELECTS{suffix}"] = total_selects
106
+ out[f"STREAM_{alias}_{device_type}_24H_TOTAL_SELECTS_AND_WATCHED{suffix}"] = (
99
107
  total_selects_and_watched
100
108
  )
101
- return feats
109
+ return out
102
110
 
103
111
 
104
- def watched_count_cleanups(stream, entry_contexts: list[str] = None) -> dict:
112
+ def watched_count_cleanups(
113
+ stream, entry_contexts: list[str] = None, out: dict = None
114
+ ) -> dict:
105
115
  if entry_contexts is None:
106
116
  entry_contexts = [
107
117
  "autoplay",
@@ -113,19 +123,20 @@ def watched_count_cleanups(stream, entry_contexts: list[str] = None) -> dict:
113
123
  _validate_pwatched_entry_context(entry_contexts)
114
124
 
115
125
  counts_obj = stream.get(f"PWATCHED#24H", {})
116
- feats = {}
126
+ if out is None:
127
+ out = {}
117
128
  for entry_context in entry_contexts:
118
129
  attempts = counts_obj.get(entry_context, {}).get("attempts", 0)
119
130
  watched = counts_obj.get(entry_context, {}).get("watched", 0)
120
131
  context_key = entry_context if "launch" not in entry_context else "launch"
121
132
  context_key = context_key.upper().replace(" ", "_")
122
- feats[f"STREAM_{context_key}_24H_TOTAL_WATCHED"] = watched
123
- feats[f"STREAM_{context_key}_24H_TOTAL_ATTEMPTS"] = attempts
124
- return feats
133
+ out[f"STREAM_{context_key}_24H_TOTAL_WATCHED"] = watched
134
+ out[f"STREAM_{context_key}_24H_TOTAL_ATTEMPTS"] = attempts
135
+ return out
125
136
 
126
137
 
127
138
  def device_watched_count_cleanups(
128
- stream, device_type: str, entry_contexts: list[str] = None
139
+ stream, device_type: str, entry_contexts: list[str] = None, out: dict = None
129
140
  ) -> dict:
130
141
  if entry_contexts is None:
131
142
  entry_contexts = [
@@ -140,23 +151,24 @@ def device_watched_count_cleanups(
140
151
  _validate_device_type(device_type)
141
152
 
142
153
  counts_obj = stream.get(f"PWATCHED#24H#{device_type}", {})
143
- feats = {}
154
+ if out is None:
155
+ out = {}
144
156
  for entry_context in entry_contexts:
145
157
  attempts = counts_obj.get(entry_context, {}).get("attempts", 0)
146
158
  watched = counts_obj.get(entry_context, {}).get("watched", 0)
147
159
  context_key = entry_context if "launch" not in entry_context else "launch"
148
160
  context_key = context_key.upper().replace(" ", "_")
149
- feats[f"STREAM_{context_key}_{device_type}_24H_TOTAL_WATCHED"] = watched
150
- feats[f"STREAM_{context_key}_{device_type}_24H_TOTAL_ATTEMPTS"] = attempts
151
- return feats
161
+ out[f"STREAM_{context_key}_{device_type}_24H_TOTAL_WATCHED"] = watched
162
+ out[f"STREAM_{context_key}_{device_type}_24H_TOTAL_ATTEMPTS"] = attempts
163
+ return out
152
164
 
153
165
 
154
166
  def generic_beta_adjust_features(
155
167
  data: pd.DataFrame,
156
168
  prefix: str,
157
- pwatched_beta_params: dict,
158
- pselect_beta_params: dict,
159
- pslw_beta_params: dict,
169
+ pwatched_beta_params: dict = None,
170
+ pselect_beta_params: dict = None,
171
+ pslw_beta_params: dict = None,
160
172
  use_low_sample_flags: bool = False,
161
173
  low_sample_threshold: int = 3,
162
174
  use_attempt_features: bool = False,
@@ -164,67 +176,92 @@ def generic_beta_adjust_features(
164
176
  debiased_pselect: bool = True,
165
177
  use_logodds: bool = False,
166
178
  ) -> pd.DataFrame:
167
- pwatched_features = {}
168
- for context, (alpha, beta) in pwatched_beta_params.items():
169
- total_watched = data[f"{prefix}_{context}_TOTAL_WATCHED"].fillna(0)
170
- total_attempts = data[f"{prefix}_{context}_TOTAL_ATTEMPTS"].fillna(0)
171
- pwatched_features[f"{prefix}_{context}_ADJ_PWATCHED"] = (
172
- total_watched + alpha
173
- ) / (total_attempts + alpha + beta)
174
- if use_low_sample_flags:
175
- pwatched_features[f"{prefix}_{context}_LOW_SAMPLE"] = total_attempts.le(
176
- low_sample_threshold
177
- ).astype(int)
178
- if use_attempt_features:
179
- pwatched_features[f"{prefix}_{context}_ATTEMPTS"] = total_attempts.clip(
180
- upper=max_attempt_cap
179
+ features = {}
180
+ counting_feature_cols = [
181
+ c
182
+ for c in data.columns
183
+ if "TOTAL_WATCHED" in c
184
+ or "TOTAL_ATTEMPTS" in c
185
+ or "SELECT" in c
186
+ or "BROWSED" in c
187
+ ]
188
+ data_arr = data[counting_feature_cols].to_numpy(dtype=float)
189
+ col_to_idx = {col: i for i, col in enumerate(counting_feature_cols)}
190
+ if pwatched_beta_params is not None:
191
+ for context, (alpha, beta) in pwatched_beta_params.items():
192
+ total_watched = np.nan_to_num(
193
+ data_arr[:, col_to_idx[f"{prefix}_{context}_TOTAL_WATCHED"]]
194
+ )
195
+ total_attempts = np.nan_to_num(
196
+ data_arr[:, col_to_idx[f"{prefix}_{context}_TOTAL_ATTEMPTS"]]
181
197
  )
198
+ features[f"{prefix}_{context}_ADJ_PWATCHED"] = (total_watched + alpha) / (
199
+ total_attempts + alpha + beta
200
+ )
201
+ low_sample_arr = np.empty_like(total_attempts, dtype=float)
202
+ if use_low_sample_flags:
203
+ features[f"{prefix}_{context}_LOW_SAMPLE"] = np.less_equal(
204
+ total_attempts, low_sample_threshold, out=low_sample_arr
205
+ )
206
+ if use_attempt_features:
207
+ features[f"{prefix}_{context}_ATTEMPTS"] = np.clip(
208
+ total_attempts, a_min=None, a_max=max_attempt_cap
209
+ )
182
210
 
183
- pselect_features = {}
184
211
  debias_suffix = "_UP_TO_4_BROWSED" if debiased_pselect else ""
185
- for key, (alpha, beta) in pselect_beta_params.items():
186
- total_selects = data[f"{prefix}_{key}_TOTAL_SELECTS{debias_suffix}"].fillna(0)
187
- total_browsed = data[f"{prefix}_{key}_TOTAL_BROWSED{debias_suffix}"].fillna(0)
188
- pselect_features[f"{prefix}_{key}_ADJ_PSELECT{debias_suffix}"] = (
189
- total_selects + alpha
190
- ) / (total_selects + total_browsed + alpha + beta)
191
- if use_low_sample_flags:
192
- pselect_features[f"{prefix}_{key}_PSELECT_LOW_SAMPLE{debias_suffix}"] = (
193
- (total_selects + total_browsed).le(low_sample_threshold).astype(int)
194
- )
195
- if use_attempt_features:
196
- pselect_features[f"{prefix}_{key}_PSELECT_ATTEMPTS{debias_suffix}"] = (
197
- total_selects + total_browsed
198
- ).clip(upper=max_attempt_cap)
199
- total_slw = data[
200
- f"{prefix}_{key}_TOTAL_SELECTS_AND_WATCHED{debias_suffix}"
201
- ].fillna(0)
202
- pslw_alpha, pslw_beta = pslw_beta_params[key]
203
- pselect_features[f"{prefix}_{key}_ADJ_PSLW{debias_suffix}"] = (
204
- total_slw + pslw_alpha
205
- ) / (total_selects + total_browsed + pslw_alpha + pslw_beta)
206
- pselect_features[f"{prefix}_{key}_PSelNotW{debias_suffix}"] = (
207
- pselect_features[f"{prefix}_{key}_ADJ_PSELECT{debias_suffix}"]
208
- - pselect_features[f"{prefix}_{key}_ADJ_PSLW{debias_suffix}"]
209
- )
212
+ if pselect_beta_params is not None or pslw_beta_params is not None:
213
+ for key, (alpha, beta) in pselect_beta_params.items():
214
+ total_selects_idx = col_to_idx[
215
+ f"{prefix}_{key}_TOTAL_SELECTS{debias_suffix}"
216
+ ]
217
+ total_browsed_idx = col_to_idx[
218
+ f"{prefix}_{key}_TOTAL_BROWSED{debias_suffix}"
219
+ ]
220
+ total_slw_idx = col_to_idx[
221
+ f"{prefix}_{key}_TOTAL_SELECTS_AND_WATCHED{debias_suffix}"
222
+ ]
223
+ total_selects = np.nan_to_num(data_arr[:, total_selects_idx])
224
+ total_browsed = np.nan_to_num(data_arr[:, total_browsed_idx])
225
+ total_slw = np.nan_to_num(data_arr[:, total_slw_idx])
226
+ if pselect_beta_params is not None:
227
+ features[f"{prefix}_{key}_ADJ_PSELECT{debias_suffix}"] = (
228
+ total_selects + alpha
229
+ ) / (total_selects + total_browsed + alpha + beta)
230
+ if use_low_sample_flags:
231
+ low_sample_arr = np.empty_like(total_selects, dtype=float)
232
+ features[f"{prefix}_{key}_PSELECT_LOW_SAMPLE{debias_suffix}"] = (
233
+ np.less_equal(
234
+ total_selects + total_browsed,
235
+ low_sample_threshold,
236
+ out=low_sample_arr,
237
+ )
238
+ )
239
+ if use_attempt_features:
240
+ features[f"{prefix}_{key}_PSELECT_ATTEMPTS{debias_suffix}"] = np.clip(
241
+ total_selects + total_browsed, a_min=0, a_max=max_attempt_cap
242
+ )
243
+ if pslw_beta_params is not None:
244
+ pslw_alpha, pslw_beta = pslw_beta_params[key]
245
+ features[f"{prefix}_{key}_ADJ_PSLW{debias_suffix}"] = (
246
+ total_slw + pslw_alpha
247
+ ) / (total_selects + total_browsed + pslw_alpha + pslw_beta)
248
+ if pslw_beta_params is not None and pselect_beta_params is not None:
249
+ features[f"{prefix}_{key}_PSelNotW{debias_suffix}"] = (
250
+ features[f"{prefix}_{key}_ADJ_PSELECT{debias_suffix}"]
251
+ - features[f"{prefix}_{key}_ADJ_PSLW{debias_suffix}"]
252
+ )
210
253
 
211
- adjusted_feats = pd.DataFrame({**pwatched_features, **pselect_features})
254
+ adjusted_feats = pd.DataFrame(features, index=data.index)
212
255
  if use_logodds:
213
- adjusted_feats = adjusted_feats.pipe(
214
- lambda x: x.assign(
215
- **x[
216
- [
217
- c
218
- for c in x.columns
219
- if "PSELECT" in c
220
- or "PSLW" in c
221
- or "PWATCHED" in c
222
- or "PSelNotW" in c
223
- ]
224
- ]
225
- .clip(lower=0.001)
226
- .pipe(prob_to_logodds)
227
- )
256
+ arr = adjusted_feats.to_numpy()
257
+ col_idxs = [
258
+ i
259
+ for i, c in enumerate(adjusted_feats.columns)
260
+ if ("PSELECT" in c or "PSLW" in c or "PWATCHED" in c or "PSelNotW" in c)
261
+ and ("LOW_SAMPLE" not in c and "ATTEMPTS" not in c)
262
+ ]
263
+ arr[:, col_idxs] = prob_to_logodds(
264
+ np.clip(arr[:, col_idxs], a_min=0.001, a_max=None)
228
265
  )
229
266
  return adjusted_feats
230
267
 
@@ -251,7 +288,10 @@ def sigmoid(x: float) -> float:
251
288
  def generic_logistic_predict(
252
289
  data: pd.DataFrame, coeffs: pd.Series, intercept: float
253
290
  ) -> pd.Series:
254
- return ((data[coeffs.index] * coeffs).sum(axis=1) + intercept).pipe(sigmoid)
291
+ scores = (data[coeffs.index] * coeffs).sum(axis=1) + intercept
292
+ raw_arr = scores.to_numpy()
293
+ raw_arr[:] = sigmoid(raw_arr)
294
+ return scores
255
295
 
256
296
 
257
297
  def _validate_device_type(device_type: str):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: haystack-ml-stack
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: Functions related to Haystack ML
5
5
  Author-email: Oscar Vega <oscar@haystack.tv>
6
6
  License: MIT
@@ -11,4 +11,5 @@ src/haystack_ml_stack.egg-info/PKG-INFO
11
11
  src/haystack_ml_stack.egg-info/SOURCES.txt
12
12
  src/haystack_ml_stack.egg-info/dependency_links.txt
13
13
  src/haystack_ml_stack.egg-info/requires.txt
14
- src/haystack_ml_stack.egg-info/top_level.txt
14
+ src/haystack_ml_stack.egg-info/top_level.txt
15
+ tests/test_utils.py
@@ -0,0 +1,76 @@
1
+ import pytest
2
+ import pandas as pd
3
+ from haystack_ml_stack import utils
4
+ import numpy as np
5
+
6
+
7
+ def test_sigmoid():
8
+ values_to_test = np.array([-1, 0, 1])
9
+ expected = np.array([0.26894142136992605, 0.5, 0.731058578630074])
10
+ actual = utils.sigmoid(values_to_test)
11
+ assert np.isclose(actual, expected).all()
12
+
13
+
14
+ def test_prob_to_logodds():
15
+ values_to_test = np.array([0.25, 0.5, 0.75])
16
+ expected = np.array([-1.0986122886681096, 0, 1.0986122886681096])
17
+ actual = utils.prob_to_logodds(values_to_test)
18
+ assert np.isclose(actual, expected).all(), print(actual - expected)
19
+
20
+
21
+ def test_generic_beta_adjust_features():
22
+ data_to_test = pd.DataFrame(
23
+ {
24
+ "STREAM_AUTOPLAY_24H_TOTAL_ATTEMPTS": [1, 2],
25
+ "STREAM_AUTOPLAY_24H_TOTAL_WATCHED": [0, 1],
26
+ "STREAM_24H_TOTAL_SELECTS_UP_TO_4_BROWSED": [1, 1],
27
+ "STREAM_24H_TOTAL_SELECTS_AND_WATCHED_UP_TO_4_BROWSED": [0, 1],
28
+ "STREAM_24H_TOTAL_BROWSED_UP_TO_4_BROWSED": [2, 0],
29
+ },
30
+ dtype=float,
31
+ )
32
+ actual = utils.generic_beta_adjust_features(
33
+ data=data_to_test,
34
+ prefix="STREAM",
35
+ pwatched_beta_params={"AUTOPLAY_24H": (2, 1)},
36
+ pselect_beta_params={"24H": (1, 1)},
37
+ pslw_beta_params={"24H": (0.5, 1)},
38
+ use_low_sample_flags=True,
39
+ )
40
+ # print(actual)
41
+ expected = pd.DataFrame(
42
+ {
43
+ "STREAM_AUTOPLAY_24H_ADJ_PWATCHED": [
44
+ (0 + 2) / (1 + 2 + 1),
45
+ (1 + 2) / (2 + 2 + 1),
46
+ ],
47
+ "STREAM_24H_ADJ_PSELECT_UP_TO_4_BROWSED": [
48
+ (1 + 1) / (1 + 2 + 1 + 1),
49
+ (1 + 1) / (1 + 0 + 1 + 1),
50
+ ],
51
+ "STREAM_24H_ADJ_PSLW_UP_TO_4_BROWSED": [
52
+ (0 + 0.5) / (1 + 2 + 0.5 + 1),
53
+ (1 + 0.5) / (1 + 0 + 0.5 + 1),
54
+ ],
55
+ "STREAM_24H_PSelNotW_UP_TO_4_BROWSED": [
56
+ (1 + 1) / (1 + 2 + 1 + 1) - (0 + 0.5) / (1 + 2 + 0.5 + 1),
57
+ (1 + 1) / (1 + 0 + 1 + 1) - (1 + 0.5) / (1 + 0 + 0.5 + 1),
58
+ ],
59
+ "STREAM_AUTOPLAY_24H_LOW_SAMPLE": [1, 1],
60
+ "STREAM_24H_PSELECT_LOW_SAMPLE_UP_TO_4_BROWSED": [1, 1],
61
+ }
62
+ )
63
+ assert (actual[expected.columns] == expected).all(axis=None), actual - expected
64
+
65
+
66
+ def test_generic_logistic_predict():
67
+ features = pd.DataFrame({"feat1": [0, 1, 2], "feat2": [3, 3, 5]}, dtype=float)
68
+ coeffs = pd.Series({"feat1": 1, "feat2": 2})
69
+ intercept = 1
70
+ expected = utils.sigmoid(
71
+ pd.Series([0 * 1 + 2 * 3, 1 * 1 + 2 * 3, 2 * 1 + 5 * 2]) + 1
72
+ )
73
+ actual = utils.generic_logistic_predict(
74
+ data=features, coeffs=coeffs, intercept=intercept
75
+ )
76
+ assert (expected == actual).all(), actual - expected