pulse-code 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. pulse/__init__.py +1 -0
  2. pulse/__main__.py +4 -0
  3. pulse/catalog.py +102 -0
  4. pulse/cli.py +984 -0
  5. pulse/data/catalog.json +1599 -0
  6. pulse/data/queries_index.json +328 -0
  7. pulse/data/variable_labels.json +1338 -0
  8. pulse/llm_builder.py +732 -0
  9. pulse/matcher.py +180 -0
  10. pulse/queries/aids-cases-by-year-1981-1999-req.xml +178 -0
  11. pulse/queries/births-by-year-1995-2002-req.xml +226 -0
  12. pulse/queries/births-by-year-2003-2006-req.xml +306 -0
  13. pulse/queries/births-by-year-2007-2024-req.xml +334 -0
  14. pulse/queries/cancer-incidence-by-site-by-year-1999-2022-req.xml +174 -0
  15. pulse/queries/cancer-mortality-by-site-by-year-2018-2023-req.xml +166 -0
  16. pulse/queries/covid-deaths-by-race-2020-2023-req.xml +529 -0
  17. pulse/queries/drug-deaths-by-month-1999-2020-req.xml +436 -0
  18. pulse/queries/drug-deaths-by-month-2018-2024-req.xml +544 -0
  19. pulse/queries/drug-deaths-by-year-1999-2020-req.xml +436 -0
  20. pulse/queries/drug-deaths-by-year-2018-2024-req.xml +536 -0
  21. pulse/queries/fentanyl-deaths-by-month-1999-2020-req.xml +430 -0
  22. pulse/queries/fentanyl-deaths-by-month-2018-2024-req.xml +530 -0
  23. pulse/queries/fetal-deaths-by-cause-by-year-2014-2024-req.xml +530 -0
  24. pulse/queries/fetal-deaths-by-year-2005-2024-req.xml +322 -0
  25. pulse/queries/heart-vs-cancer-by-sex-2018-2023-req.xml +532 -0
  26. pulse/queries/heat-wave-days-by-county-req.xml +154 -0
  27. pulse/queries/infant-mortality-2018-2023-req.xml +531 -0
  28. pulse/queries/infant-mortality-by-cause-by-year-2007-2023-req.xml +290 -0
  29. pulse/queries/maternal-mortality-by-year-1999-2020-req.xml +351 -0
  30. pulse/queries/maternal-mortality-by-year-2018-2024-req.xml +413 -0
  31. pulse/queries/mortality-by-race-sex-2018-2023-req.xml +490 -0
  32. pulse/queries/mortality-by-year-cause-1979-1998-req.xml +222 -0
  33. pulse/queries/mortality-by-year-cause-1999-2020-req.xml +434 -0
  34. pulse/queries/mortality-by-year-cause-2021-2024-req.xml +529 -0
  35. pulse/queries/opioid-overdose-deaths-2018-2024-req.xml +544 -0
  36. pulse/queries/pm25-by-year-2003-2011-req.xml +194 -0
  37. pulse/queries/provisional-births-by-month-2023-req.xml +854 -0
  38. pulse/queries/racial-mortality-gap-2018-2023-req.xml +531 -0
  39. pulse/queries/std-cases-by-disease-by-year-1984-2014-req.xml +178 -0
  40. pulse/queries/suicide-by-sex-1999-2020-req.xml +411 -0
  41. pulse/queries/suicide-by-sex-2021-2024-req.xml +551 -0
  42. pulse/queries/tb-cases-by-year-1993-2023-req.xml +206 -0
  43. pulse/queries/tick-borne-diseases-by-year-2016-2023-req.xml +125 -0
  44. pulse/queries/underlying-cause-mortality-by-year-1999-2020-req.xml +350 -0
  45. pulse/queries/unintentional-injuries-by-age-2018-2023-req.xml +531 -0
  46. pulse/templates/D10-base.xml +226 -0
  47. pulse/templates/D104-base.xml +142 -0
  48. pulse/templates/D117-base.xml +110 -0
  49. pulse/templates/D128-base.xml +182 -0
  50. pulse/templates/D140-base.xml +318 -0
  51. pulse/templates/D141-base.xml +454 -0
  52. pulse/templates/D149-base.xml +878 -0
  53. pulse/templates/D157-base.xml +490 -0
  54. pulse/templates/D158-base.xml +406 -0
  55. pulse/templates/D159-base.xml +774 -0
  56. pulse/templates/D16-base.xml +266 -0
  57. pulse/templates/D176-base.xml +526 -0
  58. pulse/templates/D178-base.xml +158 -0
  59. pulse/templates/D18-base.xml +262 -0
  60. pulse/templates/D192-base.xml +854 -0
  61. pulse/templates/D204-base.xml +142 -0
  62. pulse/templates/D23-base.xml +258 -0
  63. pulse/templates/D27-base.xml +342 -0
  64. pulse/templates/D31-base.xml +262 -0
  65. pulse/templates/D60-base.xml +274 -0
  66. pulse/templates/D61-base.xml +250 -0
  67. pulse/templates/D66-base.xml +378 -0
  68. pulse/templates/D69-base.xml +278 -0
  69. pulse/templates/D73-base.xml +182 -0
  70. pulse/templates/D74-base.xml +254 -0
  71. pulse/templates/D76-base.xml +350 -0
  72. pulse/templates/D77-base.xml +434 -0
  73. pulse/templates/D8-base.xml +314 -0
  74. pulse/templates/D80-base.xml +174 -0
  75. pulse/templates/D81-base.xml +178 -0
  76. pulse/wonder_client.py +161 -0
  77. pulse_code-1.0.1.dist-info/METADATA +249 -0
  78. pulse_code-1.0.1.dist-info/RECORD +82 -0
  79. pulse_code-1.0.1.dist-info/WHEEL +5 -0
  80. pulse_code-1.0.1.dist-info/entry_points.txt +2 -0
  81. pulse_code-1.0.1.dist-info/licenses/LICENSE +121 -0
  82. pulse_code-1.0.1.dist-info/top_level.txt +1 -0
pulse/llm_builder.py ADDED
@@ -0,0 +1,732 @@
1
+ """LLM-powered CDC WONDER query builder using Anthropic Claude."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import re
8
+ import xml.etree.ElementTree as ET
9
+ from pathlib import Path
10
+ from typing import NamedTuple, Optional
11
+
12
+ import anthropic
13
+ import httpx
14
+ from dotenv import load_dotenv
15
+ from pydantic import BaseModel, Field
16
+
17
# Import-time side effect: python-dotenv's loader runs as soon as this module
# is imported, i.e. before any of the os.getenv(...) lookups made by the
# builders below. (Presumably this is what makes values from a .env file
# visible to those lookups — confirm against python-dotenv's documentation.)
load_dotenv()

# Package-relative locations of the bundled resources, resolved from this
# file's own directory.
_TEMPLATES_DIR = Path(__file__).parent / "templates"
_QUERIES_DIR = Path(__file__).parent / "queries"
_QUERIES_INDEX_PATH = Path(__file__).parent / "data" / "queries_index.json"

# Lazily filled cache used by _bundled_query_for_dataset(): maps a dataset id
# to the filename of the first query listed for it in the queries index.
# None means "index not loaded yet".
_dataset_query_index: Optional[dict[str, str]] = None
24
+
25
+
26
def _build_http_client() -> Optional[httpx.Client]:
    """Return an httpx.Client bound to the LLM_HTTP_PROXY proxy, or None.

    The proxy URL is read from the LLM_HTTP_PROXY environment variable and
    handed to httpx unchanged. The intent is to allow socks5h:// (DNS resolved
    remotely, through the proxy) as well as http(s):// proxy URLs, for setups
    where the LLM endpoint isn't directly reachable and has to be bridged
    through a proxy.

    Returns:
        A proxied client when the variable holds a non-empty value; None when
        it is unset or empty, so callers fall back to their default transport.
    """
    proxy_url = os.getenv("LLM_HTTP_PROXY")
    return httpx.Client(proxy=proxy_url) if proxy_url else None
37
+
38
+
39
# Age group-by variables. Age-adjusted rates (AAR) cannot be requested while
# grouping by any of these, so _apply_constraints() switches AAR off when one
# of them shows up in a B_* slot.
#
# The set is generated from two families of datasets: those exposing the
# V5/V51/V52/V6 age variables and the compressed-mortality datasets that only
# expose V5/V6.
_AGE_VARS = {
    f"{dataset}.{variable}"
    for datasets, variables in (
        (
            ("D176", "D157", "D158", "D141", "D77", "D76"),
            ("V5", "V51", "V52", "V6"),
        ),
        (
            ("D74", "D16", "D140"),
            ("V5", "V6"),
        ),
    )
    for dataset in datasets
    for variable in variables
}
72
+
73
+ _SYSTEM_PROMPT = """\
74
+ You are a CDC WONDER query builder. Convert natural language into WONDER API XML queries.
75
+
76
+ ## Dataset Selection Guide
77
+
78
+ ### Mortality — Recent/Provisional (use for current trends)
79
+ - D176: Provisional Mortality 2018–present (weekly updates; default for recent mortality)
80
+ - D157: Final MCD+UCD Single Race 2018–2023 (finalized; single-race detail)
81
+ - D158: Final UCD Single Race 2018–2023 (no MCD filters; use for maternal mortality)
82
+
83
+ ### Mortality — Historical ICD-10 (1999–2020)
84
+ - D77: Multiple Cause of Death 1999–2020 (drug overdose deaths; MCD filters)
85
+ - D76: Underlying Cause of Death 1999–2020 (suicide, cause-specific; no MCD)
86
+ - D141: MCD with US-Mexico Border 1999–2020 (adds border/metro geography)
87
+
88
+ ### Mortality — Older ICD
89
+ - D140: Compressed Mortality 1999–2016 (simpler; no MCD)
90
+ - D16: Compressed Mortality 1979–1998 (ICD-9)
91
+ - D74: Compressed Mortality 1968–1978 (ICD-8)
92
+
93
+ ### Infant Mortality (Linked Birth/Death)
94
+ - D69: Linked Birth/Infant Death 2007–2023 (default for infant mortality)
95
+ - D159: Linked Birth/Infant Death Expanded 2017–2023 (more race/ethnicity detail)
96
+ - D31: Linked Birth/Infant Death 2003–2006
97
+ - D18: Linked Birth/Infant Death 1999–2002
98
+ - D23: Linked Birth/Infant Death 1995–1998
99
+
100
+ ### Natality (Live Births)
101
+ - D66: Natality 2007–2024 (default for birth data)
102
+ - D149: Natality Expanded 2016–2024 (single-race detail)
103
+ - D192: Provisional Natality 2023–present (latest; limited groupings)
104
+ - D27: Natality 2003–2006
105
+ - D10: Natality 1995–2002
106
+
107
+ ### Environmental / Climate
108
+ - D104: Heat Wave Days 1981–2010 (annual county-level)
109
+ - D60: NLDAS Air Temperatures/Heat Index 1979–2011
110
+ - D80: NLDAS Daily Sunlight 1979–2011
111
+ - D81: NLDAS Daily Precipitation 1979–2011
112
+ - D61: MODIS Land Surface Temperature 2003–2008
113
+ - D73: Fine Particulate Matter PM2.5 2003–2011
114
+
115
+ ### Vaccine Safety
116
+ - D8: VAERS 1990–present (adverse event reports, not incidence)
117
+
118
+ ## Group-By Variables (B_1 through B_5)
119
+
120
+ ### D176 (Provisional Mortality) — key B_ values
121
+ D176.V1-level1 Year
122
+ D176.V1-level2 Month
123
+ D176.V9-level1 Residence State
124
+ D176.V9-level2 Residence County
125
+ D176.V10-level1 Census Region
126
+ D176.V27-level1 HHS Region
127
+ D176.V19 2013 Urbanization
128
+ D176.V2-level1 ICD-10 Chapter (cause of death)
129
+ D176.V2-level2 ICD-10 Subcategory
130
+ D176.V13-level3 MCD Drug/Alcohol Cause Code
131
+ D176.V5 Ten-Year Age Groups
132
+ D176.V51 Five-Year Age Groups
133
+ D176.V6 Infant Age Groups
134
+ D176.V7 Gender/Sex
135
+ D176.V42 Race/Ethnicity (bridged)
136
+ D176.V43 Single Race (Hispanic origin)
137
+ D176.V44 Hispanic Origin
138
+
139
+ ### D77 / D76 (Historical Mortality 1999–2020) — key B_ values
140
+ D77.V1-level1 Year D76.V1-level1 Year
141
+ D77.V1-level2 Month D76.V1-level2 Month
142
+ D77.V9-level1 State D76.V9-level1 State
143
+ D77.V2-level1 ICD Chapter D76.V2-level1 ICD Chapter
144
+ D77.V13-level3 MCD Drug Code
145
+ D77.V5 Ten-Year Age D76.V5 Ten-Year Age
146
+ D77.V7 Gender/Sex D76.V7 Gender/Sex
147
+ D77.V8 Race (bridged) D76.V8 Race (bridged)
148
+
149
+ ### D158 (UCD Single Race 2018–2023) — key B_ values
150
+ D158.V1-level1 Year
151
+ D158.V1-level2 Month
152
+ D158.V9-level1 State
153
+ D158.V2-level1 ICD Chapter
154
+ D158.V2-level2 ICD Subcategory
155
+ D158.V5 Ten-Year Age
156
+ D158.V7 Gender/Sex
157
+ D158.V42 Single Race
158
+
159
+ ### D66 (Natality 2007–2024) — key B_ values
160
+ D66.V6-level1 Year
161
+ D66.V6-level2 Month
162
+ D66.V9-level1 State
163
+ D66.V2 Mother's Age
164
+ D66.V7 Race/Hispanic origin (4-category)
165
+ D66.V13 Gestational age (weekly)
166
+ D66.V14 Birth weight (grams)
167
+ D66.V5 Delivery method
168
+
169
+ ### D69 (Infant Mortality 2007–2023) — key B_ values
170
+ D69.V1-level1 Year
171
+ D69.V9-level1 State
172
+ D69.V2-level1 ICD Chapter (cause of death)
173
+ D69.V4 Age at death (neonatal/post-neonatal)
174
+ D69.V7 Gender
175
+ D69.V8 Race (bridged)
176
+
177
+ ### D8 (VAERS) — key B_ values
178
+ D8.V14-level1 Vaccine Type
179
+ D8.V14-level2 Vaccine (specific product)
180
+ D8.V13-level2 Symptom
181
+ D8.V2-level1 Year Received
182
+ D8.V1 Age Group
183
+ D8.V5 Sex
184
+ D8.V11 Event Category (Death, Hospitalized, Life Threatening)
185
+
186
+ ## Filters (F_* and V_*)
187
+
188
+ ### Common filter patterns (D176):
189
+ F_D176.V1 = *All* (or year codes like "2020","2021")
190
+ F_D176.V9 = *All* (state FIPS codes for specific states)
191
+ F_D176.V2 = *All* (all ICD chapters; or specific chapter codes)
192
+ F_D176.V13 = *All* (all drug codes; V_D176.V13 for specific ICD codes)
193
+ V_D176.V13 = T40.1\\nT40.2\\nT40.3\\nT40.4 (specific opioid codes — newline separated)
194
+ V_D176.V7 = M or F (sex filter)
195
+ V_D176.V42 = *All* (all races)
196
+
197
+ ### Drug ICD-10 codes (for V_*.V13 in D176/D77):
198
+ T40.1 Heroin
199
+ T40.2 Other opioids (oxycodone, hydrocodone, etc.)
200
+ T40.3 Methadone
201
+ T40.4 Other synthetic narcotics (fentanyl)
202
+ T40.5 Cocaine
203
+ T40.7 Cannabis
204
+ T43.6 Psychostimulants (meth, amphetamines, MDMA)
205
+
206
+ ### Suicide ICD-10 codes (for F_*.V2 underlying cause):
207
+ X60-X84 Intentional self-harm (ICD-10 chapter for suicide)
208
+
209
+ ### Maternal mortality (D158):
210
+ Underlying cause filter: O00-O99 (pregnancy/childbirth chapter)
211
+
212
+ ## Mode Selectors (must match active filter/groupby)
213
+ O_ucd = D{N}.V2 when filtering by ICD chapter
214
+ O_ucd = D{N}.V25 when filtering by drug/alcohol cause (simple)
215
+ O_mcd = D{N}.V13 when filtering by MCD drug codes
216
+ O_age = D{N}.V5 when grouping by 10-year age
217
+ O_age = D{N}.V51 when grouping by 5-year age
218
+ O_age = D{N}.V6 when grouping by infant age
219
+
220
+ ## Measures
221
+ M_1 = D{N}.M1 Deaths (or Births/Events)
222
+ M_2 = D{N}.M2 Population
223
+ M_3 = D{N}.M3 Crude Rate
224
+ M_9 = D{N}.M9 Age-Adjusted Rate (mortality only; disable with O_aar_enable=false when grouping by age)
225
+
226
+ ## Output Options
227
+ O_rate_per = 100000 rate denominator
228
+ O_show_totals = true include grand total row
229
+ O_aar_enable = false disable AAR (required when grouping by age)
230
+ O_aar = aar_none (goes with O_aar_enable=false)
231
+
232
+ ## Rules
233
+ 1. Select the most appropriate dataset based on topic and year range.
234
+ 2. Specify B_1..B_5 group-by slots — use *None* for unused slots.
235
+ 3. Set mode selectors (O_ucd/O_age) to match your active filter or group-by.
236
+ 4. Set O_aar_enable=false when grouping by any age variable.
237
+ 5. Output OVERRIDES ONLY — the base template fills in all boilerplate (V_*, I_*, finder-stage-*, VM_*).
238
+ 6. Do NOT output finder-stage-*, O_*_fmode, I_*, or VM_* — those come from the template.
239
+ 7. Use the build_comparison_query tool instead of build_wonder_query when the
240
+ request compares two or more distinct causes, subjects, or datasets (e.g.
241
+ "opioid deaths vs suicide deaths", "COVID deaths vs flu deaths by state").
242
+ Each sub-query in the comparison gets its own short label, dataset_id, and
243
+ parameters, following the same rules above.
244
+
245
+ ## Dataset-Specific Quirks
246
+ - D128 (STD Morbidity by Age/Race/Sex): the disease filter (V_D128.V3)
247
+ defaults to *All* (chlamydia + gonorrhea + syphilis together). CDC WONDER
248
+ requires Disease (D128.V3) to be one of the B_1..B_5 groupings whenever
249
+ more than one disease is in scope — otherwise it returns HTTP 500 ("must
250
+ be grouped by Disease when more than one disease is selected"). Either
251
+ include D128.V3 in the group-by, or restrict V_D128.V3 to a single
252
+ disease code.
253
+ """
254
+
255
# Tool definition for a single query. It is written in the Anthropic tool
# shape (name / description / input_schema); AzureOpenAIQueryBuilder converts
# it to the OpenAI "function" shape before sending.
_TOOL_SCHEMA = dict(
    name="build_wonder_query",
    description=(
        "Output OVERRIDES for a CDC WONDER XML query. The base template "
        "fills in boilerplate. You only need B_1..B_5, F_* filters, V_* "
        "value filters, O_ucd/O_age mode selectors, O_aar_enable, and "
        "non-default measures."
    ),
    input_schema=dict(
        type="object",
        properties=dict(
            dataset_id=dict(
                type="string",
                description="CDC WONDER dataset code (e.g. D176, D77, D66)",
            ),
            parameters=dict(
                type="array",
                # Each override is a parameter name plus its list of values.
                items=dict(
                    type="object",
                    properties=dict(
                        name=dict(type="string"),
                        values=dict(type="array", items=dict(type="string")),
                    ),
                    required=["name", "values"],
                ),
                description="Override parameters only (B_*, F_*, V_*, O_*, M_*)",
            ),
        ),
        required=["dataset_id", "parameters"],
    ),
)
286
+
287
# Tool definition for side-by-side comparisons: a list of at least two
# labelled sub-queries, each carrying the same dataset_id/parameters payload
# as build_wonder_query.
_COMPARISON_TOOL_SCHEMA = dict(
    name="build_comparison_query",
    description=(
        "Output OVERRIDES for two or more CDC WONDER XML queries to "
        "compare distinct causes, subjects, or datasets side by side "
        "(e.g. opioid deaths vs suicide deaths). Each sub-query follows "
        "the same override rules as build_wonder_query."
    ),
    input_schema=dict(
        type="object",
        properties=dict(
            queries=dict(
                type="array",
                minItems=2,
                items=dict(
                    type="object",
                    properties=dict(
                        label=dict(
                            type="string",
                            description="Short human-readable label for this sub-query",
                        ),
                        dataset_id=dict(
                            type="string",
                            description="CDC WONDER dataset code (e.g. D176, D77, D66)",
                        ),
                        parameters=dict(
                            type="array",
                            items=dict(
                                type="object",
                                properties=dict(
                                    name=dict(type="string"),
                                    values=dict(
                                        type="array",
                                        items=dict(type="string"),
                                    ),
                                ),
                                required=["name", "values"],
                            ),
                        ),
                    ),
                    required=["label", "dataset_id", "parameters"],
                ),
            ),
        ),
        required=["queries"],
    ),
)
334
+
335
+
336
class WonderParam(BaseModel):
    """One request parameter: a name plus its list of string values.

    Mirrors a single <parameter> element of the request XML (a <name> child
    and zero or more <value> children), as read by _parse_xml_params() and
    written by WonderRequest.to_xml().
    """

    # Parameter name, e.g. a B_*, F_*, V_*, O_* or M_* key.
    name: str
    # Values in document order; an empty string stands for an empty <value/>.
    values: list[str]
339
+
340
+
341
class WonderRequest(BaseModel):
    """A CDC WONDER request: a dataset code plus its ordered parameters."""

    # Dataset code such as "D176". It is carried alongside the parameters and
    # is not itself written into the XML produced by to_xml().
    dataset_id: str
    # Parameters in the order they are emitted into the request document.
    parameters: list[WonderParam] = Field(default_factory=list)

    def to_xml(self) -> str:
        """Serialize the parameters as a <request-parameters> XML document.

        Names and values are XML-escaped, so text containing ``&``, ``<`` or
        ``>`` yields a well-formed document. This matters for the
        template round trip: _parse_xml_params() decodes entities while
        reading a template, so they have to be re-encoded on the way out or
        the merged document can no longer be parsed.

        Returns:
            The XML document as a string: an XML declaration followed by one
            tab-indented <parameter> element per parameter. Empty values are
            written as self-closing ``<value/>`` elements.
        """
        # Function-scope import: the module only imports ElementTree at the
        # top, and escape() is needed nowhere else.
        from xml.sax.saxutils import escape

        lines = ['<?xml version="1.0" encoding="UTF-8"?><request-parameters>']
        for p in self.parameters:
            lines.append("\t<parameter>")
            lines.append(f"\t\t<name>{escape(p.name)}</name>")
            for v in p.values:
                if v:
                    lines.append(f"\t\t<value>{escape(v)}</value>")
                else:
                    lines.append("\t\t<value/>")
            lines.append("\t</parameter>")
        lines.append("</request-parameters>")
        return "\n".join(lines)
358
+
359
+
360
class WonderRequestSet(BaseModel):
    """Several finalized requests returned together for a comparison.

    Produced by the query builder's build_any() when the model calls the
    build_comparison_query tool.
    """

    # One finalized request per sub-query.
    requests: list[WonderRequest]
    # Labels taken from the same sub-queries, in the same order as `requests`.
    labels: list[str]
363
+
364
+
365
def _build_user_content(
    prompt: str,
    base_xml: Optional[str],
    reference_queries: Optional[list[tuple[str, str]]],
) -> str:
    """Assemble the text of the first user message sent to the model.

    Args:
        prompt: The natural-language request.
        base_xml: Existing query XML to refine; when non-empty the prompt is
            framed as a modification request against it.
        reference_queries: Optional (description, xml) pairs that are shown
            to the model as worked examples ahead of the request.

    Returns:
        The sections (reference preamble, one <example> per reference query,
        then the request itself) joined by blank lines.
    """
    sections: list[str] = []

    if reference_queries:
        sections.append(
            "Here are real working CDC WONDER queries for reference. Use them as "
            "structural inspiration (parameter combos, mode selectors) when relevant "
            "— do not copy them blindly if the request calls for something different."
        )
        sections.extend(
            f'<example description="{description}">\n{xml}\n</example>'
            for description, xml in reference_queries
        )

    if base_xml:
        request = (
            f"Starting from this existing query, modify it as requested:\n\n"
            f"<existing-query>\n{base_xml}\n</existing-query>\n\n"
            f"Modification request: {prompt}"
        )
    else:
        request = prompt
    sections.append(request)

    return "\n\n".join(sections)
388
+
389
+
390
def _bundled_query_for_dataset(dataset_id: str) -> Optional[str]:
    """Return the first bundled example query for a dataset, if any.

    Used as a merge target when the dataset has no `{id}-base.xml` template
    (e.g. D202, D133, D150 — see docs/building-xml-queries.md). Without
    something to merge onto, required radio-button selectors (O_age, O_race,
    etc.) go missing and CDC WONDER returns HTTP 500.

    The queries index is read once and cached in the module-level
    `_dataset_query_index`; for each dataset only the first listed filename
    is kept.

    Args:
        dataset_id: Dataset code to look up in the index.

    Returns:
        The XML text of the bundled query, or None when the dataset is not in
        the index or the indexed file does not exist.
    """
    global _dataset_query_index
    if _dataset_query_index is None:
        # Decode explicitly as UTF-8 instead of the platform's locale
        # encoding, so the bundled files read the same on every OS.
        raw = json.loads(_QUERIES_INDEX_PATH.read_text(encoding="utf-8"))
        index: dict[str, str] = {}
        for q in raw["queries"]:
            # setdefault keeps the first filename seen for each dataset.
            index.setdefault(q["dataset_id"], q["filename"])
        _dataset_query_index = index

    filename = _dataset_query_index.get(dataset_id)
    if not filename:
        return None
    path = _QUERIES_DIR / filename
    return path.read_text(encoding="utf-8") if path.exists() else None
409
+
410
+
411
def _load_template(dataset_id: str) -> Optional[str]:
    """Return the XML that LLM overrides for `dataset_id` are merged onto.

    Prefers the dataset's `{dataset_id}-base.xml` file in the templates
    directory and otherwise falls back to the first bundled example query
    for that dataset.

    Args:
        dataset_id: Dataset code, e.g. "D176".

    Returns:
        The template XML text, or None when neither a base template nor a
        bundled query is available.
    """
    path = _TEMPLATES_DIR / f"{dataset_id}-base.xml"
    if path.exists():
        # Decode explicitly as UTF-8 rather than the platform's locale
        # encoding, so templates read identically on every OS.
        return path.read_text(encoding="utf-8")
    return _bundled_query_for_dataset(dataset_id)
416
+
417
+
418
def _parse_xml_params(xml_str: str) -> list[WonderParam]:
    """Read the <parameter> elements of a request document.

    Only <parameter> elements that are direct children of the root are
    considered. A parameter without a <name> element, or whose <name> has no
    text, is skipped. A <value> element without text contributes an empty
    string.

    Args:
        xml_str: A request-parameters XML document.

    Returns:
        One WonderParam per usable <parameter>, in document order.
    """
    found = []
    for node in ET.fromstring(xml_str).findall("parameter"):
        label = node.find("name")
        label_text = None if label is None else label.text
        if label_text is None:
            continue
        texts = [item.text or "" for item in node.findall("value")]
        found.append(WonderParam(name=label_text, values=texts))
    return found
428
+
429
+
430
def _merge_overrides(template_xml: str, overrides: list[WonderParam]) -> str:
    """Overlay override parameters on a template and return the merged XML.

    An override whose name already occurs in the template replaces that
    parameter in place, keeping its position; any other override is appended
    after the template's parameters.

    Args:
        template_xml: Base request document to start from.
        overrides: Parameters that take precedence over the template's.

    Returns:
        The merged request serialized with WonderRequest.to_xml().
    """
    merged = _parse_xml_params(template_xml)
    # Position of each template parameter, captured before anything is appended.
    slot_of = {param.name: pos for pos, param in enumerate(merged)}

    for item in overrides:
        if item.name not in slot_of:
            merged.append(item)
        else:
            merged[slot_of[item.name]] = item

    # The request model needs a dataset id: take the first value of the merged
    # dataset_code parameter, defaulting to "D176" when there is none.
    dataset_id = "D176"
    for param in merged:
        if param.name == "dataset_code":
            dataset_id = param.values[0]
            break

    request = WonderRequest(dataset_id=dataset_id, parameters=merged)
    return request.to_xml()
445
+
446
+
447
def _finalize_request(raw: WonderRequest) -> WonderRequest:
    """Turn raw LLM overrides into a complete request.

    When a template (or bundled query) exists for the dataset, the overrides
    are passed through _apply_constraints(), merged onto the template, and
    the merged document is parsed back into parameters. When no template
    exists, `raw` is returned as-is — in that case the constraints are not
    applied either.

    Args:
        raw: The request built directly from the model's tool input.

    Returns:
        A request holding the merged parameters, or `raw` itself.
    """
    template = _load_template(raw.dataset_id)
    if template:
        merged_xml = _merge_overrides(template, _apply_constraints(raw.parameters))
        return WonderRequest(
            dataset_id=raw.dataset_id,
            parameters=_parse_xml_params(merged_xml),
        )
    return raw
456
+
457
+
458
def _apply_constraints(overrides: list[WonderParam]) -> list[WonderParam]:
    """Enforce CDC WONDER rules: disable AAR when grouping by age.

    Overrides are first collapsed by name (a later duplicate wins but keeps
    the position of the first occurrence). If any B_* override lists an age
    variable from _AGE_VARS, the three age-adjusted-rate switches are forced
    off, replacing whatever the model supplied for them.

    Args:
        overrides: Parameters produced by the model.

    Returns:
        The de-duplicated overrides, with the AAR switches set when required.
    """
    by_name = {param.name: param for param in overrides}

    # Every value that appears in any group-by slot.
    grouped: set[str] = set()
    for name, param in by_name.items():
        if name.startswith("B_"):
            grouped.update(param.values)

    if not grouped.isdisjoint(_AGE_VARS):
        forced = (
            ("O_aar_enable", "false"),
            ("O_aar", "aar_none"),
            ("O_aar_CI", "false"),
        )
        for name, value in forced:
            by_name[name] = WonderParam(name=name, values=[value])

    return list(by_name.values())
469
+
470
+
471
class ModelTurn(NamedTuple):
    """A normalized model response, independent of LLM provider."""

    # Name of the tool the model called, or None for a plain-text reply.
    tool_name: Optional[str]
    # Arguments of that tool call as a dict, or None for a plain-text reply.
    tool_input: Optional[dict]
    # Text of the reply; the builders in this file pass "" on tool-call turns.
    text: str
    # "tool_use" on tool-call turns; otherwise the provider's stop reason
    # (the Azure builder reports its "stop" finish reason as "end_turn").
    stop_reason: str
478
+
479
+
480
class _BaseQueryBuilder:
    """Shared tool-calling loop for building/refining CDC WONDER queries.

    Subclasses only need to implement `_call()` — everything else (dataset
    template merging, AAR constraints, comparison-query assembly, the
    end_turn retry) is provider-agnostic.
    """

    # Upper bound on "Please proceed with dataset ..." follow-ups per
    # build()/build_any() call. Without a bound, a model that keeps replying
    # in prose (while mentioning a dataset id) would be re-prompted forever,
    # issuing an endless stream of API calls.
    _MAX_NUDGES = 3

    def _call(
        self, tools: list[dict], messages: list[dict], max_tokens: int
    ) -> ModelTurn:
        """Send one request to the provider and return the normalized turn.

        The builders in this file also append the assistant's reply to
        `messages`, so a follow-up user message continues the conversation.
        """
        raise NotImplementedError

    @staticmethod
    def _single_request(tool_input: dict) -> WonderRequest:
        """Turn a build_wonder_query tool payload into a finalized request."""
        return _finalize_request(WonderRequest(**tool_input))

    @staticmethod
    def _comparison_request(tool_input: dict) -> WonderRequestSet:
        """Turn a build_comparison_query payload into a finalized request set."""
        sub_queries = tool_input["queries"]
        requests = [
            _finalize_request(
                WonderRequest(dataset_id=sq["dataset_id"], parameters=sq["parameters"])
            )
            for sq in sub_queries
        ]
        labels = [sq["label"] for sq in sub_queries]
        return WonderRequestSet(requests=requests, labels=labels)

    def _run_tool_loop(
        self,
        tools: list[dict],
        messages: list[dict],
        max_tokens: int,
        on_thinking: Optional[callable],
        handlers: dict,
    ):
        """Call the model until it invokes one of the expected tools.

        Args:
            tools: Tool schemas offered to the model.
            messages: Conversation so far; follow-up nudges are appended here.
            max_tokens: Max tokens per model call.
            on_thinking: Optional callback(text) for non-tool replies.
            handlers: Maps a tool name to the function that converts that
                tool's input into the value to return.

        Raises:
            ValueError: if the model ends its turn without a usable dataset
                id, keeps answering in prose after `_MAX_NUDGES` follow-ups,
                or stops for any reason other than "end_turn".
        """
        nudges = 0
        while True:
            turn = self._call(tools, messages, max_tokens)

            handler = handlers.get(turn.tool_name)
            if handler is not None:
                return handler(turn.tool_input)

            if on_thinking:
                on_thinking(turn.text)

            if turn.stop_reason != "end_turn":
                raise ValueError(f"Unexpected stop reason: {turn.stop_reason}")

            # The model answered in prose. If it named a dataset, ask it to
            # go ahead with the first one mentioned — but only a bounded
            # number of times.
            dataset_matches = re.findall(r"\b(D\d+)\b", turn.text)
            if not dataset_matches or nudges >= self._MAX_NUDGES:
                raise ValueError(
                    f"LLM did not produce a query. Response: {turn.text[:300]}"
                )
            nudges += 1
            messages.append(
                {
                    "role": "user",
                    "content": f"Please proceed with dataset {dataset_matches[0]}.",
                }
            )

    def build(
        self,
        prompt: str,
        base_xml: Optional[str] = None,
        reference_queries: Optional[list[tuple[str, str]]] = None,
        max_tokens: int = 4096,
        on_thinking: Optional[callable] = None,
    ) -> WonderRequest:
        """
        Build a WONDER query from natural language.

        Args:
            prompt: Natural language description of the desired query.
            base_xml: Optional existing query XML to use as starting context for refinement.
            reference_queries: Optional [(description, xml)] of real working
                queries to use as structural inspiration (parameter combos,
                mode selectors) — not to be copied blindly.
            max_tokens: Max tokens for LLM.
            on_thinking: Optional callback(text) called with LLM reasoning text.

        Returns:
            The finalized request (overrides merged onto the dataset template
            when one exists).

        Raises:
            ValueError: if the model does not produce a query.
        """
        user_content = _build_user_content(prompt, base_xml, reference_queries)
        messages = [{"role": "user", "content": user_content}]
        return self._run_tool_loop(
            [_TOOL_SCHEMA],
            messages,
            max_tokens,
            on_thinking,
            {"build_wonder_query": self._single_request},
        )

    def build_any(
        self,
        prompt: str,
        reference_queries: Optional[list[tuple[str, str]]] = None,
        max_tokens: int = 4096,
        on_thinking: Optional[callable] = None,
    ) -> WonderRequest | WonderRequestSet:
        """
        Build a WONDER query or, when the request compares multiple causes/
        datasets, a WonderRequestSet of side-by-side sub-queries.

        Args:
            prompt: Natural language description of the desired query.
            reference_queries: Optional [(description, xml)] of real working
                queries to use as structural inspiration.
            max_tokens: Max tokens for LLM.
            on_thinking: Optional callback(text) called with LLM reasoning text.

        Returns:
            A single finalized request, or a WonderRequestSet when the model
            used the comparison tool.

        Raises:
            ValueError: if the model does not produce a query.
        """
        user_content = _build_user_content(prompt, None, reference_queries)
        messages = [{"role": "user", "content": user_content}]
        return self._run_tool_loop(
            [_TOOL_SCHEMA, _COMPARISON_TOOL_SCHEMA],
            messages,
            max_tokens,
            on_thinking,
            {
                "build_wonder_query": self._single_request,
                "build_comparison_query": self._comparison_request,
            },
        )
601
+
602
+
603
class LLMQueryBuilder(_BaseQueryBuilder):
    """Build or refine CDC WONDER queries using Claude as the reasoning engine."""

    def __init__(
        self, api_key: Optional[str] = None, model: str = "claude-sonnet-4-6"
    ) -> None:
        """Create the Anthropic client.

        Args:
            api_key: API key; when omitted or empty, ANTHROPIC_API_KEY from
                the environment is used.
            model: Model identifier passed on every request.
        """
        resolved_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic(
            api_key=resolved_key,
            http_client=_build_http_client(),
        )
        self.model = model

    def _call(
        self, tools: list[dict], messages: list[dict], max_tokens: int
    ) -> ModelTurn:
        """Run one Messages API call and normalize the reply.

        The assistant's content blocks are appended to `messages` so that a
        follow-up user message continues the same conversation. The first
        tool_use block, if any, becomes the turn; otherwise the text of all
        blocks is concatenated and returned with the API's stop reason.
        """
        reply = self.client.messages.create(
            model=self.model,
            max_tokens=max_tokens,
            system=_SYSTEM_PROMPT,
            tools=tools,
            messages=messages,
        )
        blocks = reply.content
        messages.append({"role": "assistant", "content": blocks})

        for block in blocks:
            if block.type == "tool_use":
                return ModelTurn(
                    tool_name=block.name,
                    tool_input=block.input,
                    text="",
                    stop_reason="tool_use",
                )

        # No tool call: gather whatever text the blocks carry.
        pieces = [getattr(block, "text", "") for block in blocks]
        return ModelTurn(
            tool_name=None,
            tool_input=None,
            text="".join(pieces),
            stop_reason=reply.stop_reason,
        )
634
+
635
+
636
class AzureOpenAIQueryBuilder(_BaseQueryBuilder):
    """Build or refine CDC WONDER queries using an Azure OpenAI Foundry deployment (e.g. GPT-5.4)."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        deployment: Optional[str] = None,
        api_version: Optional[str] = None,
    ) -> None:
        """Create the Azure OpenAI client.

        Each argument falls back to its AZURE_OPENAI_* environment variable
        when omitted or empty.

        Raises:
            RuntimeError: if any of the four settings is still missing; the
                message lists the corresponding environment variable names.
        """
        # Imported here rather than at module level, so the openai package is
        # only needed when this builder is actually constructed.
        import openai

        api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
        endpoint = endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        deployment = deployment or os.getenv("AZURE_OPENAI_DEPLOYMENT")
        api_version = api_version or os.getenv("AZURE_OPENAI_API_VERSION")

        missing = [
            name
            for name, value in [
                ("AZURE_OPENAI_API_KEY", api_key),
                ("AZURE_OPENAI_ENDPOINT", endpoint),
                ("AZURE_OPENAI_DEPLOYMENT", deployment),
                ("AZURE_OPENAI_API_VERSION", api_version),
            ]
            if not value
        ]
        if missing:
            raise RuntimeError(
                "Missing Azure OpenAI configuration: "
                + ", ".join(missing)
                + ". Set these in your environment or a .env file."
            )

        self.client = openai.AzureOpenAI(
            api_key=api_key,
            azure_endpoint=endpoint,
            api_version=api_version,
            http_client=_build_http_client(),
        )
        self.deployment = deployment

    @staticmethod
    def _to_openai_tools(tools: list[dict]) -> list[dict]:
        """Convert Anthropic-style tool schemas to OpenAI "function" tools."""
        return [
            {
                "type": "function",
                "function": {
                    "name": t["name"],
                    "description": t["description"],
                    "parameters": t["input_schema"],
                },
            }
            for t in tools
        ]

    def _call(
        self, tools: list[dict], messages: list[dict], max_tokens: int
    ) -> ModelTurn:
        """Run one chat-completions call and normalize the reply.

        The system prompt is prepended on every call and is not stored in
        `messages`. The assistant message is appended to `messages` so a
        follow-up user message continues the conversation. Only the first
        tool call of a reply is used.
        """
        full_messages = [{"role": "system", "content": _SYSTEM_PROMPT}, *messages]
        response = self.client.chat.completions.create(
            model=self.deployment,
            messages=full_messages,
            tools=self._to_openai_tools(tools),
            max_completion_tokens=max_tokens,
        )

        choice = response.choices[0]
        message = choice.message
        # Drop None-valued fields before storing the reply: a plain-text reply
        # otherwise carries keys such as tool_calls=None into the next request.
        # NOTE(review): whether the service rejects such null fields depends
        # on the API version — confirm against the deployment in use.
        messages.append(message.model_dump(exclude_none=True))

        if message.tool_calls:
            call = message.tool_calls[0]
            return ModelTurn(
                call.function.name, json.loads(call.function.arguments), "", "tool_use"
            )

        # Map OpenAI's "stop" onto the "end_turn" value the shared loop checks.
        finish_reason = choice.finish_reason
        stop_reason = "end_turn" if finish_reason == "stop" else finish_reason
        return ModelTurn(None, None, message.content or "", stop_reason)
717
+
718
+
719
def get_query_builder(provider: Optional[str] = None) -> _BaseQueryBuilder:
    """Return an LLM query builder for the configured provider.

    Selected via the `provider` argument, falling back to the
    `LLM_PROVIDER` env var, defaulting to "anthropic". An empty argument or
    an empty `LLM_PROVIDER` (e.g. a blank `LLM_PROVIDER=` line in a .env
    file) counts as "not set". The name is matched case-insensitively and
    surrounding whitespace is ignored.

    Args:
        provider: "anthropic", or one of "azure_openai" / "azure-openai" /
            "azure".

    Returns:
        A builder instance constructed with its default settings.

    Raises:
        ValueError: if the provider name is not recognised.
    """
    # `or` chaining (rather than a getenv default) so that an empty
    # environment value also falls through to "anthropic".
    provider = (provider or os.getenv("LLM_PROVIDER") or "anthropic").strip().lower()
    if provider == "anthropic":
        return LLMQueryBuilder()
    if provider in ("azure_openai", "azure-openai", "azure"):
        return AzureOpenAIQueryBuilder()
    raise ValueError(
        f"Unknown LLM_PROVIDER {provider!r}. Use 'anthropic' or 'azure_openai'."
    )