hindsight-api 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +10 -9
- hindsight_api/alembic/env.py +5 -8
- hindsight_api/alembic/versions/5a366d414dce_initial_schema.py +266 -180
- hindsight_api/alembic/versions/b7c4d8e9f1a2_add_chunks_table.py +32 -32
- hindsight_api/alembic/versions/c8e5f2a3b4d1_add_retain_params_to_documents.py +11 -11
- hindsight_api/alembic/versions/d9f6a3b4c5e2_rename_bank_to_interactions.py +7 -12
- hindsight_api/alembic/versions/e0a1b2c3d4e5_disposition_to_3_traits.py +23 -15
- hindsight_api/alembic/versions/rename_personality_to_disposition.py +30 -21
- hindsight_api/api/__init__.py +10 -10
- hindsight_api/api/http.py +575 -593
- hindsight_api/api/mcp.py +31 -33
- hindsight_api/banner.py +13 -6
- hindsight_api/config.py +17 -12
- hindsight_api/engine/__init__.py +9 -9
- hindsight_api/engine/cross_encoder.py +23 -27
- hindsight_api/engine/db_utils.py +5 -4
- hindsight_api/engine/embeddings.py +22 -21
- hindsight_api/engine/entity_resolver.py +81 -75
- hindsight_api/engine/llm_wrapper.py +74 -88
- hindsight_api/engine/memory_engine.py +663 -673
- hindsight_api/engine/query_analyzer.py +100 -97
- hindsight_api/engine/response_models.py +105 -106
- hindsight_api/engine/retain/__init__.py +9 -16
- hindsight_api/engine/retain/bank_utils.py +34 -58
- hindsight_api/engine/retain/chunk_storage.py +4 -12
- hindsight_api/engine/retain/deduplication.py +9 -28
- hindsight_api/engine/retain/embedding_processing.py +4 -11
- hindsight_api/engine/retain/embedding_utils.py +3 -4
- hindsight_api/engine/retain/entity_processing.py +7 -17
- hindsight_api/engine/retain/fact_extraction.py +155 -165
- hindsight_api/engine/retain/fact_storage.py +11 -23
- hindsight_api/engine/retain/link_creation.py +11 -39
- hindsight_api/engine/retain/link_utils.py +166 -95
- hindsight_api/engine/retain/observation_regeneration.py +39 -52
- hindsight_api/engine/retain/orchestrator.py +72 -62
- hindsight_api/engine/retain/types.py +49 -43
- hindsight_api/engine/search/__init__.py +15 -1
- hindsight_api/engine/search/fusion.py +6 -15
- hindsight_api/engine/search/graph_retrieval.py +234 -0
- hindsight_api/engine/search/mpfp_retrieval.py +438 -0
- hindsight_api/engine/search/observation_utils.py +9 -16
- hindsight_api/engine/search/reranking.py +4 -7
- hindsight_api/engine/search/retrieval.py +388 -193
- hindsight_api/engine/search/scoring.py +5 -7
- hindsight_api/engine/search/temporal_extraction.py +8 -11
- hindsight_api/engine/search/think_utils.py +115 -39
- hindsight_api/engine/search/trace.py +68 -38
- hindsight_api/engine/search/tracer.py +49 -35
- hindsight_api/engine/search/types.py +22 -16
- hindsight_api/engine/task_backend.py +21 -26
- hindsight_api/engine/utils.py +25 -10
- hindsight_api/main.py +21 -40
- hindsight_api/mcp_local.py +190 -0
- hindsight_api/metrics.py +44 -30
- hindsight_api/migrations.py +10 -8
- hindsight_api/models.py +60 -72
- hindsight_api/pg0.py +64 -337
- hindsight_api/server.py +3 -6
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/METADATA +6 -5
- hindsight_api-0.1.6.dist-info/RECORD +64 -0
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/entry_points.txt +1 -0
- hindsight_api-0.1.4.dist-info/RECORD +0 -61
- {hindsight_api-0.1.4.dist-info → hindsight_api-0.1.6.dist-info}/WHEEL +0 -0
hindsight_api/engine/query_analyzer.py

@@ -4,11 +4,12 @@ Query analysis abstraction for the memory system.
 Provides an interface for analyzing natural language queries to extract
 structured information like temporal constraints.
 """
-
-from typing import Optional
-from datetime import datetime, timedelta
+
 import logging
 import re
+from abc import ABC, abstractmethod
+from datetime import datetime, timedelta
+
 from pydantic import BaseModel, Field
 
 logger = logging.getLogger(__name__)
@@ -20,6 +21,7 @@ class TemporalConstraint(BaseModel):
 
     Represents a time range with start and end dates.
     """
+
     start_date: datetime = Field(description="Start of the time range (inclusive)")
     end_date: datetime = Field(description="End of the time range (inclusive)")
 
@@ -33,9 +35,9 @@ class QueryAnalysis(BaseModel):
 
     Contains extracted structured information like temporal constraints.
     """
-
-
-        description="Extracted temporal constraint, if any"
+
+    temporal_constraint: TemporalConstraint | None = Field(
+        default=None, description="Extracted temporal constraint, if any"
     )
 
 
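The hunk above switches `QueryAnalysis.temporal_constraint` to the `X | None` union syntax with an explicit `default=None`. A self-contained sketch of that pydantic pattern, reusing the field names that appear in the diff:

```python
from datetime import datetime

from pydantic import BaseModel, Field


class TemporalConstraint(BaseModel):
    start_date: datetime = Field(description="Start of the time range (inclusive)")
    end_date: datetime = Field(description="End of the time range (inclusive)")


class QueryAnalysis(BaseModel):
    # Optional nested model: stays None unless the analyzer extracted a time range.
    temporal_constraint: TemporalConstraint | None = Field(
        default=None, description="Extracted temporal constraint, if any"
    )


print(QueryAnalysis().temporal_constraint)  # None when no constraint was found
```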
@@ -58,9 +60,7 @@ class QueryAnalyzer(ABC):
         pass
 
     @abstractmethod
-    def analyze(
-        self, query: str, reference_date: Optional[datetime] = None
-    ) -> QueryAnalysis:
+    def analyze(self, query: str, reference_date: datetime | None = None) -> QueryAnalysis:
         """
         Analyze a natural language query.
 
@@ -95,11 +95,10 @@ class DateparserQueryAnalyzer(QueryAnalyzer):
         """Load dateparser (lazy import)."""
         if self._search_dates is None:
            from dateparser.search import search_dates
+
            self._search_dates = search_dates
 
-    def analyze(
-        self, query: str, reference_date: Optional[datetime] = None
-    ) -> QueryAnalysis:
+    def analyze(self, query: str, reference_date: datetime | None = None) -> QueryAnalysis:
         """
         Analyze query using dateparser.
 
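For orientation, a hypothetical call against the new single-line `analyze()` signature. The import path is inferred from the file list at the top of this diff, and constructing `DateparserQueryAnalyzer` with no arguments is an assumption, not documented API:

```python
from datetime import datetime

from hindsight_api.engine.query_analyzer import DateparserQueryAnalyzer

analyzer = DateparserQueryAnalyzer()  # assumed zero-argument construction
analysis = analyzer.analyze("what did we discuss last week?", reference_date=datetime(2024, 6, 15))
if analysis.temporal_constraint:
    # Bounds of the previous calendar week relative to the reference date
    print(analysis.temporal_constraint.start_date, analysis.temporal_constraint.end_date)
```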
@@ -126,9 +125,9 @@ class DateparserQueryAnalyzer(QueryAnalyzer):
 
         # Use dateparser's search_dates to find temporal expressions
         settings = {
-
-
-
+            "RELATIVE_BASE": reference_date,
+            "PREFER_DATES_FROM": "past",
+            "RETURN_AS_TIMEZONE_AWARE": False,
         }
 
         results = self._search_dates(query, settings=settings)
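The settings introduced above drive dateparser's behavior. A minimal sketch of what `search_dates` returns with those settings, independent of the package (assumes `dateparser` is installed; the query string is illustrative):

```python
from datetime import datetime

from dateparser.search import search_dates

reference_date = datetime(2024, 6, 15)
settings = {
    "RELATIVE_BASE": reference_date,     # resolve relative phrases against this date
    "PREFER_DATES_FROM": "past",         # ambiguous dates resolve backwards in time
    "RETURN_AS_TIMEZONE_AWARE": False,   # naive datetimes, matching the pydantic fields
}

# Returns a list of (matched_text, datetime) tuples, or None if nothing parses.
results = search_dates("the demo we ran three days ago", settings=settings)
if results:
    for text, date in results:
        print(text, "->", date.isoformat())
```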
@@ -137,11 +136,8 @@ class DateparserQueryAnalyzer(QueryAnalyzer):
             return QueryAnalysis(temporal_constraint=None)
 
         # Filter out false positives (common words parsed as dates)
-        false_positives = {
-        valid_results = [
-            (text, date) for text, date in results
-            if text.lower() not in false_positives or len(text) > 3
-        ]
+        false_positives = {"do", "may", "march", "will", "can", "sat", "sun", "mon", "tue", "wed", "thu", "fri"}
+        valid_results = [(text, date) for text, date in results if text.lower() not in false_positives or len(text) > 3]
 
         if not valid_results:
             return QueryAnalysis(temporal_constraint=None)
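The false-positive set above guards against dateparser matching bare words such as "may" or "will" as dates. A quick illustration of the filter applied to a raw result list (the sample tuples are made up):

```python
from datetime import datetime

false_positives = {"do", "may", "march", "will", "can", "sat", "sun", "mon", "tue", "wed", "thu", "fri"}

# Shape of dateparser.search.search_dates output: (matched_text, parsed_datetime)
results = [
    ("may", datetime(2024, 5, 1)),          # spurious: the verb "may"
    ("last friday", datetime(2024, 6, 7)),  # kept: longer than 3 characters
]

valid = [(t, d) for t, d in results if t.lower() not in false_positives or len(t) > 3]
print(valid)  # only ("last friday", ...) survives
```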
@@ -153,84 +149,94 @@ class DateparserQueryAnalyzer(QueryAnalyzer):
         start_date = parsed_date.replace(hour=0, minute=0, second=0, microsecond=0)
         end_date = parsed_date.replace(hour=23, minute=59, second=59, microsecond=999999)
 
-        return QueryAnalysis(
-            temporal_constraint=TemporalConstraint(
-                start_date=start_date,
-                end_date=end_date
-            )
-        )
+        return QueryAnalysis(temporal_constraint=TemporalConstraint(start_date=start_date, end_date=end_date))
 
-    def _extract_period(
-        self, query: str, reference_date: datetime
-    ) -> Optional[TemporalConstraint]:
+    def _extract_period(self, query: str, reference_date: datetime) -> TemporalConstraint | None:
         """
         Extract period-based temporal expressions (week, month, year, weekend).
 
         These need special handling as they represent date ranges, not single dates.
         Supports multiple languages.
         """
+
         def constraint(start: datetime, end: datetime) -> TemporalConstraint:
             return TemporalConstraint(
                 start_date=start.replace(hour=0, minute=0, second=0, microsecond=0),
-                end_date=end.replace(hour=23, minute=59, second=59, microsecond=999999)
+                end_date=end.replace(hour=23, minute=59, second=59, microsecond=999999),
             )
 
         # Yesterday patterns (English, Spanish, Italian, French, German)
-        if re.search(r
+        if re.search(r"\b(yesterday|ayer|ieri|hier|gestern)\b", query, re.IGNORECASE):
             d = reference_date - timedelta(days=1)
             return constraint(d, d)
 
         # Today patterns
-        if re.search(r
+        if re.search(r"\b(today|hoy|oggi|aujourd\'?hui|heute)\b", query, re.IGNORECASE):
             return constraint(reference_date, reference_date)
 
         # "a couple of days ago" / "a few days ago" patterns
         # These are imprecise so we create a range
-        if re.search(r
+        if re.search(r"\b(a\s+)?couple\s+(of\s+)?days?\s+ago\b", query, re.IGNORECASE):
             # "a couple of days" = approximately 2 days, give range of 1-3 days
             return constraint(reference_date - timedelta(days=3), reference_date - timedelta(days=1))
 
-        if re.search(r
+        if re.search(r"\b(a\s+)?few\s+days?\s+ago\b", query, re.IGNORECASE):
             # "a few days" = approximately 3-4 days, give range of 2-5 days
             return constraint(reference_date - timedelta(days=5), reference_date - timedelta(days=2))
 
         # "a couple of weeks ago" / "a few weeks ago" patterns
-        if re.search(r
+        if re.search(r"\b(a\s+)?couple\s+(of\s+)?weeks?\s+ago\b", query, re.IGNORECASE):
             # "a couple of weeks" = approximately 2 weeks, give range of 1-3 weeks
             return constraint(reference_date - timedelta(weeks=3), reference_date - timedelta(weeks=1))
 
-        if re.search(r
+        if re.search(r"\b(a\s+)?few\s+weeks?\s+ago\b", query, re.IGNORECASE):
             # "a few weeks" = approximately 3-4 weeks, give range of 2-5 weeks
             return constraint(reference_date - timedelta(weeks=5), reference_date - timedelta(weeks=2))
 
         # "a couple of months ago" / "a few months ago" patterns
-        if re.search(r
+        if re.search(r"\b(a\s+)?couple\s+(of\s+)?months?\s+ago\b", query, re.IGNORECASE):
             # "a couple of months" = approximately 2 months, give range of 1-3 months
             return constraint(reference_date - timedelta(days=90), reference_date - timedelta(days=30))
 
-        if re.search(r
+        if re.search(r"\b(a\s+)?few\s+months?\s+ago\b", query, re.IGNORECASE):
             # "a few months" = approximately 3-4 months, give range of 2-5 months
             return constraint(reference_date - timedelta(days=150), reference_date - timedelta(days=60))
 
         # Last week patterns (English, Spanish, Italian, French, German)
-        if re.search(
+        if re.search(
+            r"\b(last\s+week|la\s+semana\s+pasada|la\s+settimana\s+scorsa|la\s+semaine\s+derni[eè]re|letzte\s+woche)\b",
+            query,
+            re.IGNORECASE,
+        ):
             start = reference_date - timedelta(days=reference_date.weekday() + 7)
             return constraint(start, start + timedelta(days=6))
 
         # Last month patterns
-        if re.search(
+        if re.search(
+            r"\b(last\s+month|el\s+mes\s+pasado|il\s+mese\s+scorso|le\s+mois\s+dernier|letzten?\s+monat)\b",
+            query,
+            re.IGNORECASE,
+        ):
             first = reference_date.replace(day=1)
             end = first - timedelta(days=1)
             start = end.replace(day=1)
             return constraint(start, end)
 
         # Last year patterns
-        if re.search(
+        if re.search(
+            r"\b(last\s+year|el\s+a[ñn]o\s+pasado|l\'anno\s+scorso|l\'ann[ée]e\s+derni[eè]re|letztes?\s+jahr)\b",
+            query,
+            re.IGNORECASE,
+        ):
             year = reference_date.year - 1
             return constraint(datetime(year, 1, 1), datetime(year, 12, 31))
 
         # Last weekend patterns
-        if re.search(
+        if re.search(
+            r"\b(last\s+weekend|el\s+fin\s+de\s+semana\s+pasado|lo\s+scorso\s+fine\s+settimana|le\s+week-?end\s+dernier|letztes?\s+wochenende)\b",
+            query,
+            re.IGNORECASE,
+        ):
             days_since_sat = (reference_date.weekday() + 2) % 7
             if days_since_sat == 0:
                 days_since_sat = 7
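The "last week" and "last month" branches above reduce to plain stdlib date arithmetic. A standalone sketch of that logic, extracted from the diff for readability:

```python
from datetime import datetime, timedelta


def last_week_range(reference_date: datetime) -> tuple[datetime, datetime]:
    # Monday of the previous ISO week through the following Sunday.
    start = reference_date - timedelta(days=reference_date.weekday() + 7)
    return start, start + timedelta(days=6)


def last_month_range(reference_date: datetime) -> tuple[datetime, datetime]:
    # Step back to the last day of the previous month, then to its first day.
    end = reference_date.replace(day=1) - timedelta(days=1)
    return end.replace(day=1), end


print(last_week_range(datetime(2024, 6, 15)))   # 2024-06-03 .. 2024-06-09
print(last_month_range(datetime(2024, 6, 15)))  # 2024-05-01 .. 2024-05-31
```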
@@ -239,22 +245,22 @@ class DateparserQueryAnalyzer(QueryAnalyzer):
 
         # Month + Year patterns (e.g., "June 2024", "junio 2024", "giugno 2024")
         month_patterns = {
-
-
-
-
-
-
-
-
-
-
-
-
+            "january|enero|gennaio|janvier|januar": 1,
+            "february|febrero|febbraio|f[ée]vrier|februar": 2,
+            "march|marzo|mars|m[äa]rz": 3,
+            "april|abril|aprile|avril": 4,
+            "may|mayo|maggio|mai": 5,
+            "june|junio|giugno|juin|juni": 6,
+            "july|julio|luglio|juillet|juli": 7,
+            "august|agosto|ao[uû]t": 8,
+            "september|septiembre|settembre|septembre": 9,
+            "october|octubre|ottobre|octobre|oktober": 10,
+            "november|noviembre|novembre": 11,
+            "december|diciembre|dicembre|d[ée]cembre|dezember": 12,
         }
 
         for pattern, month_num in month_patterns.items():
-            match = re.search(rf
+            match = re.search(rf"\b({pattern})\s+(\d{{4}})\b", query, re.IGNORECASE)
             if match:
                 year = int(match.group(2))
                 start = datetime(year, month_num, 1)
@@ -279,11 +285,7 @@ class TransformerQueryAnalyzer(QueryAnalyzer):
     - Model size: ~80M params (~300MB download)
     """
 
-    def __init__(
-        self,
-        model_name: str = "google/flan-t5-small",
-        device: str = "cpu"
-    ):
+    def __init__(self, model_name: str = "google/flan-t5-small", device: str = "cpu"):
         """
         Initialize T5 query analyzer.
 
@@ -304,11 +306,10 @@ class TransformerQueryAnalyzer(QueryAnalyzer):
             return
 
         try:
-            from transformers import
+            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
         except ImportError:
             raise ImportError(
-                "transformers is required for TransformerQueryAnalyzer. "
-                "Install it with: pip install transformers"
+                "transformers is required for TransformerQueryAnalyzer. Install it with: pip install transformers"
             )
 
         logger.info(f"Loading query analyzer model: {self.model_name}...")
@@ -322,9 +323,7 @@ class TransformerQueryAnalyzer(QueryAnalyzer):
         """Lazy load the T5 model for temporal extraction (calls load())."""
         self.load()
 
-    def _extract_with_rules(
-        self, query: str, reference_date: datetime
-    ) -> Optional[TemporalConstraint]:
+    def _extract_with_rules(self, query: str, reference_date: datetime) -> TemporalConstraint | None:
         """
         Extract temporal expressions using rule-based patterns.
 
@@ -332,6 +331,7 @@ class TransformerQueryAnalyzer(QueryAnalyzer):
         patterns that need model-based extraction.
         """
         import re
+
         query_lower = query.lower()
 
         def get_last_weekday(weekday: int) -> datetime:
@@ -343,50 +343,60 @@ class TransformerQueryAnalyzer(QueryAnalyzer):
         def constraint(start: datetime, end: datetime) -> TemporalConstraint:
             return TemporalConstraint(
                 start_date=start.replace(hour=0, minute=0, second=0, microsecond=0),
-                end_date=end.replace(hour=23, minute=59, second=59, microsecond=999999)
+                end_date=end.replace(hour=23, minute=59, second=59, microsecond=999999),
             )
 
         # Yesterday
-        if re.search(r
+        if re.search(r"\byesterday\b", query_lower):
             d = reference_date - timedelta(days=1)
             return constraint(d, d)
 
         # Last week
-        if re.search(r
+        if re.search(r"\blast\s+week\b", query_lower):
             start = reference_date - timedelta(days=reference_date.weekday() + 7)
             return constraint(start, start + timedelta(days=6))
 
         # Last month
-        if re.search(r
+        if re.search(r"\blast\s+month\b", query_lower):
             first = reference_date.replace(day=1)
             end = first - timedelta(days=1)
             start = end.replace(day=1)
             return constraint(start, end)
 
         # Last year
-        if re.search(r
+        if re.search(r"\blast\s+year\b", query_lower):
             y = reference_date.year - 1
             return constraint(datetime(y, 1, 1), datetime(y, 12, 31))
 
         # Last weekend
-        if re.search(r
+        if re.search(r"\blast\s+weekend\b", query_lower):
             sat = get_last_weekday(5)
             return constraint(sat, sat + timedelta(days=1))
 
         # Last <weekday>
-        weekdays = {
-            'friday': 4, 'saturday': 5, 'sunday': 6}
+        weekdays = {"monday": 0, "tuesday": 1, "wednesday": 2, "thursday": 3, "friday": 4, "saturday": 5, "sunday": 6}
         for name, num in weekdays.items():
-            if re.search(rf
+            if re.search(rf"\blast\s+{name}\b", query_lower):
                 d = get_last_weekday(num)
                 return constraint(d, d)
 
         # Month + Year: "June 2024", "in March 2023"
-        months = {
-
-
+        months = {
+            "january": 1,
+            "february": 2,
+            "march": 3,
+            "april": 4,
+            "may": 5,
+            "june": 6,
+            "july": 7,
+            "august": 8,
+            "september": 9,
+            "october": 10,
+            "november": 11,
+            "december": 12,
+        }
         for name, num in months.items():
-            match = re.search(rf
+            match = re.search(rf"\b{name}\s+(\d{{4}})\b", query_lower)
             if match:
                 year = int(match.group(1))
                 if num == 12:
@@ -435,11 +443,11 @@ class TransformerQueryAnalyzer(QueryAnalyzer):
         last_saturday = get_last_weekday(5)
 
         # Build prompt for T5
-        prompt = f"""Today is {reference_date.strftime(
+        prompt = f"""Today is {reference_date.strftime("%Y-%m-%d")}. Extract date range or "none".
 
 June 2024 = 2024-06-01 to 2024-06-30
-yesterday = {yesterday.strftime(
-last Saturday = {last_saturday.strftime(
+yesterday = {yesterday.strftime("%Y-%m-%d")} to {yesterday.strftime("%Y-%m-%d")}
+last Saturday = {last_saturday.strftime("%Y-%m-%d")} to {last_saturday.strftime("%Y-%m-%d")}
 what is the weather = none
 {query} ="""
 
@@ -448,13 +456,7 @@ what is the weather = none
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
         with self._no_grad():
-            outputs = self._model.generate(
-                **inputs,
-                max_new_tokens=30,
-                num_beams=3,
-                do_sample=False,
-                temperature=1.0
-            )
+            outputs = self._model.generate(**inputs, max_new_tokens=30, num_beams=3, do_sample=False, temperature=1.0)
 
         result = self._tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
 
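The prompt/generate/decode flow collapsed into one line above can be exercised outside the class. A minimal sketch assuming `transformers` and `torch` are installed; the prompt text here is illustrative, not the exact template from the file:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

prompt = 'Today is 2024-06-15. Extract date range or "none".\nlast Saturday ='
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    # Deterministic beam search, mirroring the generate() arguments in the diff.
    outputs = model.generate(**inputs, max_new_tokens=30, num_beams=3, do_sample=False)

print(tokenizer.decode(outputs[0], skip_special_tokens=True).strip())
```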
@@ -466,14 +468,14 @@ what is the weather = none
         """Get torch.no_grad context manager."""
         try:
             import torch
+
             return torch.no_grad()
         except ImportError:
             from contextlib import nullcontext
+
             return nullcontext()
 
-    def _parse_generated_output(
-        self, result: str, reference_date: datetime
-    ) -> Optional[TemporalConstraint]:
+    def _parse_generated_output(self, result: str, reference_date: datetime) -> TemporalConstraint | None:
         """
         Parse T5 generated output into TemporalConstraint.
 
|
|
|
492
494
|
try:
|
|
493
495
|
# Parse "YYYY-MM-DD to YYYY-MM-DD"
|
|
494
496
|
import re
|
|
495
|
-
|
|
497
|
+
|
|
498
|
+
pattern = r"(\d{4}-\d{2}-\d{2})\s+to\s+(\d{4}-\d{2}-\d{2})"
|
|
496
499
|
match = re.search(pattern, result, re.IGNORECASE)
|
|
497
500
|
|
|
498
501
|
if match:
|
|
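The regex added above expects the model to emit `YYYY-MM-DD to YYYY-MM-DD`. A sketch of turning such a match into a date range; the `strptime` conversion and end-of-day rounding are assumptions about the surrounding (unchanged) code, only the pattern itself comes from the diff:

```python
import re
from datetime import datetime

PATTERN = r"(\d{4}-\d{2}-\d{2})\s+to\s+(\d{4}-\d{2}-\d{2})"


def parse_range(result: str) -> tuple[datetime, datetime] | None:
    match = re.search(PATTERN, result, re.IGNORECASE)
    if not match:
        return None  # e.g. the model answered "none"
    start = datetime.strptime(match.group(1), "%Y-%m-%d")
    end = datetime.strptime(match.group(2), "%Y-%m-%d").replace(hour=23, minute=59, second=59)
    return start, end


print(parse_range("2024-06-01 to 2024-06-30"))
print(parse_range("none"))
```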
@@ -513,7 +516,7 @@ what is the weather = none
 
                 return TemporalConstraint(start_date=start_date, end_date=end_date)
 
-        except (ValueError, AttributeError)
+        except (ValueError, AttributeError):
             return None
 
         return None