hindsight-api 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. hindsight_api/__init__.py +10 -9
  2. hindsight_api/alembic/env.py +5 -8
  3. hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
  4. hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
  5. hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
  6. hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
  7. hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
  8. hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
  9. hindsight_api/api/__init__.py +10 -10
  10. hindsight_api/api/http.py +575 -593
  11. hindsight_api/api/mcp.py +30 -28
  12. hindsight_api/banner.py +13 -6
  13. hindsight_api/config.py +9 -13
  14. hindsight_api/engine/__init__.py +9 -9
  15. hindsight_api/engine/cross_encoder.py +22 -21
  16. hindsight_api/engine/db_utils.py +5 -4
  17. hindsight_api/engine/embeddings.py +22 -21
  18. hindsight_api/engine/entity_resolver.py +81 -75
  19. hindsight_api/engine/llm_wrapper.py +61 -79
  20. hindsight_api/engine/memory_engine.py +603 -625
  21. hindsight_api/engine/query_analyzer.py +100 -97
  22. hindsight_api/engine/response_models.py +105 -106
  23. hindsight_api/engine/retain/__init__.py +9 -16
  24. hindsight_api/engine/retain/bank_utils.py +34 -58
  25. hindsight_api/engine/retain/chunk_storage.py +4 -12
  26. hindsight_api/engine/retain/deduplication.py +9 -28
  27. hindsight_api/engine/retain/embedding_processing.py +4 -11
  28. hindsight_api/engine/retain/embedding_utils.py +3 -4
  29. hindsight_api/engine/retain/entity_processing.py +7 -17
  30. hindsight_api/engine/retain/fact_extraction.py +155 -165
  31. hindsight_api/engine/retain/fact_storage.py +11 -23
  32. hindsight_api/engine/retain/link_creation.py +11 -39
  33. hindsight_api/engine/retain/link_utils.py +166 -95
  34. hindsight_api/engine/retain/observation_regeneration.py +39 -52
  35. hindsight_api/engine/retain/orchestrator.py +72 -62
  36. hindsight_api/engine/retain/types.py +49 -43
  37. hindsight_api/engine/search/__init__.py +5 -5
  38. hindsight_api/engine/search/fusion.py +6 -15
  39. hindsight_api/engine/search/graph_retrieval.py +22 -23
  40. hindsight_api/engine/search/mpfp_retrieval.py +76 -92
  41. hindsight_api/engine/search/observation_utils.py +9 -16
  42. hindsight_api/engine/search/reranking.py +4 -7
  43. hindsight_api/engine/search/retrieval.py +87 -66
  44. hindsight_api/engine/search/scoring.py +5 -7
  45. hindsight_api/engine/search/temporal_extraction.py +8 -11
  46. hindsight_api/engine/search/think_utils.py +115 -39
  47. hindsight_api/engine/search/trace.py +68 -39
  48. hindsight_api/engine/search/tracer.py +44 -35
  49. hindsight_api/engine/search/types.py +20 -17
  50. hindsight_api/engine/task_backend.py +21 -26
  51. hindsight_api/engine/utils.py +25 -10
  52. hindsight_api/main.py +21 -40
  53. hindsight_api/mcp_local.py +190 -0
  54. hindsight_api/metrics.py +44 -30
  55. hindsight_api/migrations.py +10 -8
  56. hindsight_api/models.py +60 -72
  57. hindsight_api/pg0.py +22 -23
  58. hindsight_api/server.py +3 -6
  59. hindsight_api-0.1.7.dist-info/METADATA +178 -0
  60. hindsight_api-0.1.7.dist-info/RECORD +64 -0
  61. {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.7.dist-info}/entry_points.txt +1 -0
  62. hindsight_api-0.1.5.dist-info/METADATA +0 -42
  63. hindsight_api-0.1.5.dist-info/RECORD +0 -63
  64. {hindsight_api-0.1.5.dist-info → hindsight_api-0.1.7.dist-info}/WHEEL +0 -0
@@ -16,13 +16,12 @@ Key properties:
16
16
 
17
17
  import asyncio
18
18
  import logging
19
- from dataclasses import dataclass, field
20
- from typing import List, Dict, Optional, Tuple
21
19
  from collections import defaultdict
20
+ from dataclasses import dataclass, field
22
21
 
23
- from .types import RetrievalResult
24
- from .graph_retrieval import GraphRetriever
25
22
  from ..db_utils import acquire_with_retry
23
+ from .graph_retrieval import GraphRetriever
24
+ from .types import RetrievalResult
26
25
 
27
26
  logger = logging.getLogger(__name__)
28
27
 
@@ -31,9 +30,11 @@ logger = logging.getLogger(__name__)
31
30
  # Data Classes
32
31
  # -----------------------------------------------------------------------------
33
32
 
33
+
34
34
  @dataclass
35
35
  class EdgeTarget:
36
36
  """A neighbor node with its edge weight."""
37
+
37
38
  node_id: str
38
39
  weight: float
39
40
 
@@ -41,19 +42,15 @@ class EdgeTarget:
41
42
  @dataclass
42
43
  class TypedAdjacency:
43
44
  """Adjacency lists split by edge type."""
45
+
44
46
  # edge_type -> from_node_id -> list of (to_node_id, weight)
45
- graphs: Dict[str, Dict[str, List[EdgeTarget]]] = field(default_factory=dict)
47
+ graphs: dict[str, dict[str, list[EdgeTarget]]] = field(default_factory=dict)
46
48
 
47
- def get_neighbors(self, edge_type: str, node_id: str) -> List[EdgeTarget]:
49
+ def get_neighbors(self, edge_type: str, node_id: str) -> list[EdgeTarget]:
48
50
  """Get neighbors for a node via a specific edge type."""
49
51
  return self.graphs.get(edge_type, {}).get(node_id, [])
50
52
 
51
- def get_normalized_neighbors(
52
- self,
53
- edge_type: str,
54
- node_id: str,
55
- top_k: int
56
- ) -> List[EdgeTarget]:
53
+ def get_normalized_neighbors(self, edge_type: str, node_id: str, top_k: int) -> list[EdgeTarget]:
57
54
  """Get top-k neighbors with weights normalized to sum to 1."""
58
55
  neighbors = self.get_neighbors(edge_type, node_id)[:top_k]
59
56
  if not neighbors:
@@ -63,45 +60,49 @@ class TypedAdjacency:
63
60
  if total == 0:
64
61
  return []
65
62
 
66
- return [
67
- EdgeTarget(node_id=n.node_id, weight=n.weight / total)
68
- for n in neighbors
69
- ]
63
+ return [EdgeTarget(node_id=n.node_id, weight=n.weight / total) for n in neighbors]
70
64
 
71
65
 
72
66
  @dataclass
73
67
  class PatternResult:
74
68
  """Result from a single pattern traversal."""
75
- pattern: List[str]
76
- scores: Dict[str, float] # node_id -> accumulated mass
69
+
70
+ pattern: list[str]
71
+ scores: dict[str, float] # node_id -> accumulated mass
77
72
 
78
73
 
79
74
  @dataclass
80
75
  class MPFPConfig:
81
76
  """Configuration for MPFP algorithm."""
82
- alpha: float = 0.15 # teleport/keep probability
83
- threshold: float = 1e-6 # mass pruning threshold (lower = explore more)
84
- top_k_neighbors: int = 20 # fan-out limit per node
77
+
78
+ alpha: float = 0.15 # teleport/keep probability
79
+ threshold: float = 1e-6 # mass pruning threshold (lower = explore more)
80
+ top_k_neighbors: int = 20 # fan-out limit per node
85
81
 
86
82
  # Patterns from semantic seeds
87
- patterns_semantic: List[List[str]] = field(default_factory=lambda: [
88
- ['semantic', 'semantic'], # topic expansion
89
- ['entity', 'temporal'], # entity timeline
90
- ['semantic', 'causes'], # reasoning chains (forward)
91
- ['semantic', 'caused_by'], # reasoning chains (backward)
92
- ['entity', 'semantic'], # entity context
93
- ])
83
+ patterns_semantic: list[list[str]] = field(
84
+ default_factory=lambda: [
85
+ ["semantic", "semantic"], # topic expansion
86
+ ["entity", "temporal"], # entity timeline
87
+ ["semantic", "causes"], # reasoning chains (forward)
88
+ ["semantic", "caused_by"], # reasoning chains (backward)
89
+ ["entity", "semantic"], # entity context
90
+ ]
91
+ )
94
92
 
95
93
  # Patterns from temporal seeds
96
- patterns_temporal: List[List[str]] = field(default_factory=lambda: [
97
- ['temporal', 'semantic'], # what was happening then
98
- ['temporal', 'entity'], # who was involved then
99
- ])
94
+ patterns_temporal: list[list[str]] = field(
95
+ default_factory=lambda: [
96
+ ["temporal", "semantic"], # what was happening then
97
+ ["temporal", "entity"], # who was involved then
98
+ ]
99
+ )
100
100
 
101
101
 
102
102
  @dataclass
103
103
  class SeedNode:
104
104
  """An entry point node with its initial score."""
105
+
105
106
  node_id: str
106
107
  score: float # initial mass (e.g., similarity score)
107
108
 
@@ -110,9 +111,10 @@ class SeedNode:
110
111
  # Core Algorithm
111
112
  # -----------------------------------------------------------------------------
112
113
 
114
+
113
115
  def mpfp_traverse(
114
- seeds: List[SeedNode],
115
- pattern: List[str],
116
+ seeds: list[SeedNode],
117
+ pattern: list[str],
116
118
  adjacency: TypedAdjacency,
117
119
  config: MPFPConfig,
118
120
  ) -> PatternResult:
@@ -131,20 +133,18 @@ def mpfp_traverse(
131
133
  if not seeds:
132
134
  return PatternResult(pattern=pattern, scores={})
133
135
 
134
- scores: Dict[str, float] = {}
136
+ scores: dict[str, float] = {}
135
137
 
136
138
  # Initialize frontier with seed masses (normalized)
137
139
  total_seed_score = sum(s.score for s in seeds)
138
140
  if total_seed_score == 0:
139
141
  total_seed_score = len(seeds) # fallback to uniform
140
142
 
141
- frontier: Dict[str, float] = {
142
- s.node_id: s.score / total_seed_score for s in seeds
143
- }
143
+ frontier: dict[str, float] = {s.node_id: s.score / total_seed_score for s in seeds}
144
144
 
145
145
  # Follow pattern hop by hop
146
146
  for edge_type in pattern:
147
- next_frontier: Dict[str, float] = {}
147
+ next_frontier: dict[str, float] = {}
148
148
 
149
149
  for node_id, mass in frontier.items():
150
150
  if mass < config.threshold:
@@ -155,15 +155,10 @@ def mpfp_traverse(
155
155
 
156
156
  # Push (1-α) to neighbors
157
157
  push_mass = (1 - config.alpha) * mass
158
- neighbors = adjacency.get_normalized_neighbors(
159
- edge_type, node_id, config.top_k_neighbors
160
- )
158
+ neighbors = adjacency.get_normalized_neighbors(edge_type, node_id, config.top_k_neighbors)
161
159
 
162
160
  for neighbor in neighbors:
163
- next_frontier[neighbor.node_id] = (
164
- next_frontier.get(neighbor.node_id, 0) +
165
- push_mass * neighbor.weight
166
- )
161
+ next_frontier[neighbor.node_id] = next_frontier.get(neighbor.node_id, 0) + push_mass * neighbor.weight
167
162
 
168
163
  frontier = next_frontier
169
164
 
@@ -176,10 +171,10 @@ def mpfp_traverse(
176
171
 
177
172
 
178
173
  def rrf_fusion(
179
- results: List[PatternResult],
174
+ results: list[PatternResult],
180
175
  k: int = 60,
181
176
  top_k: int = 50,
182
- ) -> List[Tuple[str, float]]:
177
+ ) -> list[tuple[str, float]]:
183
178
  """
184
179
  Reciprocal Rank Fusion to combine pattern results.
185
180
 
@@ -191,28 +186,20 @@ def rrf_fusion(
191
186
  Returns:
192
187
  List of (node_id, fused_score) tuples, sorted by score descending
193
188
  """
194
- fused: Dict[str, float] = {}
189
+ fused: dict[str, float] = {}
195
190
 
196
191
  for result in results:
197
192
  if not result.scores:
198
193
  continue
199
194
 
200
195
  # Rank nodes by their score in this pattern
201
- ranked = sorted(
202
- result.scores.keys(),
203
- key=lambda n: result.scores[n],
204
- reverse=True
205
- )
196
+ ranked = sorted(result.scores.keys(), key=lambda n: result.scores[n], reverse=True)
206
197
 
207
198
  for rank, node_id in enumerate(ranked):
208
199
  fused[node_id] = fused.get(node_id, 0) + 1.0 / (k + rank + 1)
209
200
 
210
201
  # Sort by fused score and return top-k
211
- sorted_results = sorted(
212
- fused.items(),
213
- key=lambda x: x[1],
214
- reverse=True
215
- )
202
+ sorted_results = sorted(fused.items(), key=lambda x: x[1], reverse=True)
216
203
 
217
204
  return sorted_results[:top_k]
218
205
 
@@ -221,6 +208,7 @@ def rrf_fusion(
221
208
  # Database Loading
222
209
  # -----------------------------------------------------------------------------
223
210
 
211
+
224
212
  async def load_typed_adjacency(pool, bank_id: str) -> TypedAdjacency:
225
213
  """
226
214
  Load all edges for a bank, split by edge type.
@@ -237,31 +225,27 @@ async def load_typed_adjacency(pool, bank_id: str) -> TypedAdjacency:
237
225
  AND ml.weight >= 0.1
238
226
  ORDER BY ml.from_unit_id, ml.weight DESC
239
227
  """,
240
- bank_id
228
+ bank_id,
241
229
  )
242
230
 
243
- graphs: Dict[str, Dict[str, List[EdgeTarget]]] = defaultdict(
244
- lambda: defaultdict(list)
245
- )
231
+ graphs: dict[str, dict[str, list[EdgeTarget]]] = defaultdict(lambda: defaultdict(list))
246
232
 
247
233
  for row in rows:
248
- from_id = str(row['from_unit_id'])
249
- to_id = str(row['to_unit_id'])
250
- link_type = row['link_type']
251
- weight = row['weight']
234
+ from_id = str(row["from_unit_id"])
235
+ to_id = str(row["to_unit_id"])
236
+ link_type = row["link_type"]
237
+ weight = row["weight"]
252
238
 
253
- graphs[link_type][from_id].append(
254
- EdgeTarget(node_id=to_id, weight=weight)
255
- )
239
+ graphs[link_type][from_id].append(EdgeTarget(node_id=to_id, weight=weight))
256
240
 
257
241
  return TypedAdjacency(graphs=dict(graphs))
258
242
 
259
243
 
260
244
  async def fetch_memory_units_by_ids(
261
245
  pool,
262
- node_ids: List[str],
246
+ node_ids: list[str],
263
247
  fact_type: str,
264
- ) -> List[RetrievalResult]:
248
+ ) -> list[RetrievalResult]:
265
249
  """Fetch full memory unit details for a list of node IDs."""
266
250
  if not node_ids:
267
251
  return []
@@ -276,7 +260,7 @@ async def fetch_memory_units_by_ids(
276
260
  AND fact_type = $2
277
261
  """,
278
262
  node_ids,
279
- fact_type
263
+ fact_type,
280
264
  )
281
265
 
282
266
  return [RetrievalResult.from_db_row(dict(r)) for r in rows]
@@ -286,6 +270,7 @@ async def fetch_memory_units_by_ids(
286
270
  # Graph Retriever Implementation
287
271
  # -----------------------------------------------------------------------------
288
272
 
273
+
289
274
  class MPFPGraphRetriever(GraphRetriever):
290
275
  """
291
276
  Graph retrieval using Meta-Path Forward Push.
@@ -294,7 +279,7 @@ class MPFPGraphRetriever(GraphRetriever):
294
279
  then fuses results via RRF.
295
280
  """
296
281
 
297
- def __init__(self, config: Optional[MPFPConfig] = None):
282
+ def __init__(self, config: MPFPConfig | None = None):
298
283
  """
299
284
  Initialize MPFP retriever.
300
285
 
@@ -302,7 +287,7 @@ class MPFPGraphRetriever(GraphRetriever):
302
287
  config: Algorithm configuration (uses defaults if None)
303
288
  """
304
289
  self.config = config or MPFPConfig()
305
- self._adjacency_cache: Dict[str, TypedAdjacency] = {}
290
+ self._adjacency_cache: dict[str, TypedAdjacency] = {}
306
291
 
307
292
  @property
308
293
  def name(self) -> str:
@@ -315,10 +300,10 @@ class MPFPGraphRetriever(GraphRetriever):
315
300
  bank_id: str,
316
301
  fact_type: str,
317
302
  budget: int,
318
- query_text: Optional[str] = None,
319
- semantic_seeds: Optional[List[RetrievalResult]] = None,
320
- temporal_seeds: Optional[List[RetrievalResult]] = None,
321
- ) -> List[RetrievalResult]:
303
+ query_text: str | None = None,
304
+ semantic_seeds: list[RetrievalResult] | None = None,
305
+ temporal_seeds: list[RetrievalResult] | None = None,
306
+ ) -> list[RetrievalResult]:
322
307
  """
323
308
  Retrieve facts using MPFP algorithm.
324
309
 
@@ -339,14 +324,12 @@ class MPFPGraphRetriever(GraphRetriever):
339
324
  adjacency = await load_typed_adjacency(pool, bank_id)
340
325
 
341
326
  # Convert seeds to SeedNode format
342
- semantic_seed_nodes = self._convert_seeds(semantic_seeds, 'similarity')
343
- temporal_seed_nodes = self._convert_seeds(temporal_seeds, 'temporal_score')
327
+ semantic_seed_nodes = self._convert_seeds(semantic_seeds, "similarity")
328
+ temporal_seed_nodes = self._convert_seeds(temporal_seeds, "temporal_score")
344
329
 
345
330
  # If no semantic seeds provided, fall back to finding our own
346
331
  if not semantic_seed_nodes:
347
- semantic_seed_nodes = await self._find_semantic_seeds(
348
- pool, query_embedding_str, bank_id, fact_type
349
- )
332
+ semantic_seed_nodes = await self._find_semantic_seeds(pool, query_embedding_str, bank_id, fact_type)
350
333
 
351
334
  # Run all patterns in parallel
352
335
  tasks = []
@@ -407,9 +390,9 @@ class MPFPGraphRetriever(GraphRetriever):
407
390
 
408
391
  def _convert_seeds(
409
392
  self,
410
- seeds: Optional[List[RetrievalResult]],
393
+ seeds: list[RetrievalResult] | None,
411
394
  score_attr: str,
412
- ) -> List[SeedNode]:
395
+ ) -> list[SeedNode]:
413
396
  """Convert RetrievalResult seeds to SeedNode format."""
414
397
  if not seeds:
415
398
  return []
@@ -431,7 +414,7 @@ class MPFPGraphRetriever(GraphRetriever):
431
414
  fact_type: str,
432
415
  limit: int = 20,
433
416
  threshold: float = 0.3,
434
- ) -> List[SeedNode]:
417
+ ) -> list[SeedNode]:
435
418
  """Fallback: find semantic seeds via embedding search."""
436
419
  async with acquire_with_retry(pool) as conn:
437
420
  rows = await conn.fetch(
@@ -445,10 +428,11 @@ class MPFPGraphRetriever(GraphRetriever):
445
428
  ORDER BY embedding <=> $1::vector
446
429
  LIMIT $5
447
430
  """,
448
- query_embedding_str, bank_id, fact_type, threshold, limit
431
+ query_embedding_str,
432
+ bank_id,
433
+ fact_type,
434
+ threshold,
435
+ limit,
449
436
  )
450
437
 
451
- return [
452
- SeedNode(node_id=str(r['id']), score=r['similarity'])
453
- for r in rows
454
- ]
438
+ return [SeedNode(node_id=str(r["id"]), score=r["similarity"]) for r in rows]
@@ -6,7 +6,7 @@ about an entity, without personality influence.
6
6
  """
7
7
 
8
8
  import logging
9
- from typing import List, Dict, Any
9
+
10
10
  from pydantic import BaseModel, Field
11
11
 
12
12
  from ..response_models import MemoryFact
@@ -16,18 +16,17 @@ logger = logging.getLogger(__name__)
16
16
 
17
17
  class Observation(BaseModel):
18
18
  """An observation about an entity."""
19
+
19
20
  observation: str = Field(description="The observation text - a factual statement about the entity")
20
21
 
21
22
 
22
23
  class ObservationExtractionResponse(BaseModel):
23
24
  """Response containing extracted observations."""
24
- observations: List[Observation] = Field(
25
- default_factory=list,
26
- description="List of observations about the entity"
27
- )
28
25
 
26
+ observations: list[Observation] = Field(default_factory=list, description="List of observations about the entity")
29
27
 
30
- def format_facts_for_observation_prompt(facts: List[MemoryFact]) -> str:
28
+
29
+ def format_facts_for_observation_prompt(facts: list[MemoryFact]) -> str:
31
30
  """Format facts as text for observation extraction prompt."""
32
31
  import json
33
32
 
@@ -35,9 +34,7 @@ def format_facts_for_observation_prompt(facts: List[MemoryFact]) -> str:
35
34
  return "[]"
36
35
  formatted = []
37
36
  for fact in facts:
38
- fact_obj = {
39
- "text": fact.text
40
- }
37
+ fact_obj = {"text": fact.text}
41
38
 
42
39
  # Add context if available
43
40
  if fact.context:
@@ -92,11 +89,7 @@ def get_observation_system_message() -> str:
92
89
  return "You are an objective observer synthesizing facts about an entity. Generate clear, factual observations without opinions or personality influence. Be concise and accurate."
93
90
 
94
91
 
95
- async def extract_observations_from_facts(
96
- llm_config,
97
- entity_name: str,
98
- facts: List[MemoryFact]
99
- ) -> List[str]:
92
+ async def extract_observations_from_facts(llm_config, entity_name: str, facts: list[MemoryFact]) -> list[str]:
100
93
  """
101
94
  Extract observations from facts about an entity using LLM.
102
95
 
@@ -118,10 +111,10 @@ async def extract_observations_from_facts(
118
111
  result = await llm_config.call(
119
112
  messages=[
120
113
  {"role": "system", "content": get_observation_system_message()},
121
- {"role": "user", "content": prompt}
114
+ {"role": "user", "content": prompt},
122
115
  ],
123
116
  response_format=ObservationExtractionResponse,
124
- scope="memory_extract_observation"
117
+ scope="memory_extract_observation",
125
118
  )
126
119
 
127
120
  observations = [op.observation for op in result.observations]
@@ -2,7 +2,6 @@
2
2
  Cross-encoder neural reranking for search results.
3
3
  """
4
4
 
5
- from typing import List
6
5
  from .types import MergedCandidate, ScoredResult
7
6
 
8
7
 
@@ -24,14 +23,11 @@ class CrossEncoderReranker:
24
23
  """
25
24
  if cross_encoder is None:
26
25
  from hindsight_api.engine.cross_encoder import create_cross_encoder_from_env
26
+
27
27
  cross_encoder = create_cross_encoder_from_env()
28
28
  self.cross_encoder = cross_encoder
29
29
 
30
- def rerank(
31
- self,
32
- query: str,
33
- candidates: List[MergedCandidate]
34
- ) -> List[ScoredResult]:
30
+ def rerank(self, query: str, candidates: list[MergedCandidate]) -> list[ScoredResult]:
35
31
  """
36
32
  Rerank candidates using cross-encoder scores.
37
33
 
@@ -77,6 +73,7 @@ class CrossEncoderReranker:
77
73
  # Normalize scores using sigmoid to [0, 1] range
78
74
  # Cross-encoder returns logits which can be negative
79
75
  import numpy as np
76
+
80
77
  def sigmoid(x):
81
78
  return 1 / (1 + np.exp(-x))
82
79
 
@@ -89,7 +86,7 @@ class CrossEncoderReranker:
89
86
  candidate=candidate,
90
87
  cross_encoder_score=float(raw_score),
91
88
  cross_encoder_score_normalized=float(norm_score),
92
- weight=float(norm_score) # Initial weight is just cross-encoder score
89
+ weight=float(norm_score), # Initial weight is just cross-encoder score
93
90
  )
94
91
  scored_results.append(scored_result)
95
92