learnx-cli 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131) hide show
  1. learnx_cli-0.3.0.dist-info/METADATA +240 -0
  2. learnx_cli-0.3.0.dist-info/RECORD +131 -0
  3. learnx_cli-0.3.0.dist-info/WHEEL +4 -0
  4. learnx_cli-0.3.0.dist-info/entry_points.txt +2 -0
  5. tutor/.env copy.example +4 -0
  6. tutor/__init__.py +0 -0
  7. tutor/__main__.py +4 -0
  8. tutor/assets/__init__.py +5 -0
  9. tutor/assets/html/fonts/Inter-Bold.woff2 +0 -0
  10. tutor/assets/html/fonts/Inter-Regular.woff2 +0 -0
  11. tutor/assets/html/fonts/Inter-SemiBold.woff2 +0 -0
  12. tutor/assets/html/fonts/JetBrainsMono-Regular.woff2 +0 -0
  13. tutor/assets/html/highlight-java.min.js +2 -0
  14. tutor/assets/html/highlight-javascript.min.js +2 -0
  15. tutor/assets/html/highlight-python.min.js +2 -0
  16. tutor/assets/html/highlight.min.js +17 -0
  17. tutor/assets/html/mermaid.min.js +31 -0
  18. tutor/assets/html/slide_base.css +464 -0
  19. tutor/assets/html/theme-learnx-dark.css +12 -0
  20. tutor/audio/__init__.py +0 -0
  21. tutor/audio/audio_builder.py +143 -0
  22. tutor/audio/sanitizer.py +9 -0
  23. tutor/audio/tts_renderer.py +54 -0
  24. tutor/cli/__init__.py +0 -0
  25. tutor/cli/commands.py +391 -0
  26. tutor/cli/logo.py +21 -0
  27. tutor/cli/playback_commands.py +239 -0
  28. tutor/cli/shell.py +91 -0
  29. tutor/cli/shell_context.py +18 -0
  30. tutor/cli/theme.py +39 -0
  31. tutor/cli/video_commands.py +123 -0
  32. tutor/config.py +122 -0
  33. tutor/conftest.py +5 -0
  34. tutor/constants.py +82 -0
  35. tutor/exceptions.py +26 -0
  36. tutor/generation/__init__.py +0 -0
  37. tutor/generation/assembler.py +81 -0
  38. tutor/generation/curriculum.py +97 -0
  39. tutor/generation/dialogue.py +172 -0
  40. tutor/generation/narrator.py +122 -0
  41. tutor/generation/segment_parser.py +223 -0
  42. tutor/generation/segment_planner.py +200 -0
  43. tutor/generation/visual_planner.py +205 -0
  44. tutor/infra/__init__.py +0 -0
  45. tutor/infra/llm.py +152 -0
  46. tutor/ingestion/__init__.py +0 -0
  47. tutor/ingestion/chunker.py +171 -0
  48. tutor/ingestion/doc_analyzer.py +41 -0
  49. tutor/ingestion/parse_content.py +19 -0
  50. tutor/ingestion/summarizer.py +51 -0
  51. tutor/inspector.py +117 -0
  52. tutor/llm_config.toml +58 -0
  53. tutor/models.py +147 -0
  54. tutor/player/__init__.py +0 -0
  55. tutor/player/input_handler.py +45 -0
  56. tutor/player/player.py +308 -0
  57. tutor/player/player_display.py +117 -0
  58. tutor/prompts/curriculum.txt +67 -0
  59. tutor/prompts/dialogue.txt +62 -0
  60. tutor/prompts/narrate.txt +34 -0
  61. tutor/prompts/qa.txt +17 -0
  62. tutor/prompts/summarize.txt +9 -0
  63. tutor/prompts/visual.txt +60 -0
  64. tutor/prompts/visual_v3.txt +91 -0
  65. tutor/qa/__init__.py +0 -0
  66. tutor/qa/qa.py +105 -0
  67. tutor/requirements-dev.txt +2 -0
  68. tutor/requirements.txt +12 -0
  69. tutor/sample_docs/headingless_large.md +1 -0
  70. tutor/sample_docs/headingless_test.md +1 -0
  71. tutor/sample_docs/java-basics.md +78 -0
  72. tutor/tests/__init__.py +0 -0
  73. tutor/tests/audio/__init__.py +0 -0
  74. tutor/tests/audio/test_audio_builder.py +106 -0
  75. tutor/tests/audio/test_sanitizer.py +41 -0
  76. tutor/tests/cli/__init__.py +0 -0
  77. tutor/tests/cli/test_commands.py +67 -0
  78. tutor/tests/cli/test_video_commands.py +190 -0
  79. tutor/tests/e2e/README.md +61 -0
  80. tutor/tests/e2e/__init__.py +0 -0
  81. tutor/tests/e2e/conftest.py +117 -0
  82. tutor/tests/e2e/fixtures/README.md +17 -0
  83. tutor/tests/e2e/fixtures/sample.md +13 -0
  84. tutor/tests/e2e/test_audio_quality.py +40 -0
  85. tutor/tests/e2e/test_av_sync.py +56 -0
  86. tutor/tests/e2e/test_pipeline_smoke.py +37 -0
  87. tutor/tests/e2e/test_slide_render.py +72 -0
  88. tutor/tests/e2e/test_video_streams.py +104 -0
  89. tutor/tests/generation/__init__.py +0 -0
  90. tutor/tests/generation/conftest.py +134 -0
  91. tutor/tests/generation/test_assembler.py +64 -0
  92. tutor/tests/generation/test_curriculum.py +107 -0
  93. tutor/tests/generation/test_narrator.py +165 -0
  94. tutor/tests/generation/test_segment_edge_cases.py +280 -0
  95. tutor/tests/generation/test_segment_planner.py +324 -0
  96. tutor/tests/generation/test_visual_planner.py +319 -0
  97. tutor/tests/ingestion/__init__.py +0 -0
  98. tutor/tests/ingestion/test_chunker.py +94 -0
  99. tutor/tests/ingestion/test_doc_analyzer.py +51 -0
  100. tutor/tests/player/__init__.py +0 -0
  101. tutor/tests/player/test_player_states.py +88 -0
  102. tutor/tests/test_assets.py +39 -0
  103. tutor/tests/test_models_visual.py +180 -0
  104. tutor/tests/visual/__init__.py +0 -0
  105. tutor/tests/visual/test_beat_timer.py +321 -0
  106. tutor/tests/visual/test_pipeline_integration.py +178 -0
  107. tutor/tests/visual/test_slide_renderer.py +298 -0
  108. tutor/tests/visual/test_subtitle_writer.py +165 -0
  109. tutor/tests/visual/test_video_assembler.py +108 -0
  110. tutor/tests/visual/test_visual_pipeline.py +270 -0
  111. tutor/tutor.py +365 -0
  112. tutor/visual/__init__.py +213 -0
  113. tutor/visual/beat_timer.py +222 -0
  114. tutor/visual/slide_renderer.py +236 -0
  115. tutor/visual/subtitle_writer.py +187 -0
  116. tutor/visual/templates/_base.html.j2 +40 -0
  117. tutor/visual/templates/analogy.html.j2 +21 -0
  118. tutor/visual/templates/callout.html.j2 +10 -0
  119. tutor/visual/templates/code_example.html.j2 +12 -0
  120. tutor/visual/templates/comparison.html.j2 +28 -0
  121. tutor/visual/templates/decision_guide.html.j2 +37 -0
  122. tutor/visual/templates/definition.html.j2 +13 -0
  123. tutor/visual/templates/diagram.html.j2 +11 -0
  124. tutor/visual/templates/hook_question.html.j2 +17 -0
  125. tutor/visual/templates/key_insight.html.j2 +9 -0
  126. tutor/visual/templates/memory_hook.html.j2 +7 -0
  127. tutor/visual/templates/outro.html.j2 +16 -0
  128. tutor/visual/templates/question_prompt.html.j2 +13 -0
  129. tutor/visual/templates/step_sequence.html.j2 +14 -0
  130. tutor/visual/templates/title_card.html.j2 +12 -0
  131. tutor/visual/video_assembler.py +299 -0
@@ -0,0 +1,200 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import json
5
+ import logging
6
+ from collections.abc import Callable
7
+ from dataclasses import asdict
8
+ from pathlib import Path
9
+
10
+ from tutor.constants import SUMMARY_CACHE_DIR
11
+ from tutor.generation.segment_parser import fallback_segments, parse_segments_response
12
+ from tutor.infra.llm import load_prompt
13
+ from tutor.models import DialogueLine, SlideSegment
14
+
15
+ log = logging.getLogger(__name__)
16
+
17
+
18
def plan_segments(
    units_json_path: Path,
    video_dir: Path,
    llm_fn: Callable,
    no_cache: bool = False,
) -> dict[int, list[SlideSegment]]:
    """Plan slide segments for every teaching unit that has dialogue lines.

    Units are processed in ascending unit number. A unit without dialogue
    lines is skipped with a warning and does not appear in the result.

    Args:
        units_json_path: Path to tutorial.units.json.
        video_dir: Directory that receives tutorial.segments.json; created
            when missing.
        llm_fn: Callable forwarded to the per-unit planner.
        no_cache: When true, an existing cache file for a unit is deleted
            before that unit is planned.

    Returns:
        Mapping of unit number to its segments, as returned by the per-unit
        planner.
    """
    planned: dict[int, list[SlideSegment]] = {}

    for unit_index, (concept, lines) in sorted(_load_unit_lines(units_json_path).items()):
        if lines:
            cache_file = _cache_path(unit_index, lines)
            # Dropping the cache file is what forces a fresh LLM call below.
            if no_cache and cache_file.exists():
                cache_file.unlink()
            planned[unit_index] = _plan_unit_segments(
                unit_index, concept, lines, llm_fn, cache_file
            )
        else:
            log.warning("Unit %d has no dialogue lines — skipping", unit_index)

    video_dir.mkdir(parents=True, exist_ok=True)
    segments_path = video_dir / "tutorial.segments.json"
    # JSON object keys must be strings, so unit numbers are stringified.
    payload = {
        "version": 1,
        "units": {
            str(number): [asdict(segment) for segment in segments]
            for number, segments in sorted(planned.items())
        },
    }
    segments_path.write_text(
        json.dumps(payload, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    log.info("Segments written: %s (%d units)", segments_path, len(planned))
    return planned
61
+
62
+
63
def _plan_unit_segments(
    unit_index: int,
    unit_concept: str,
    lines: list[DialogueLine],
    llm_fn: Callable,
    cache_file: Path,
) -> list[SlideSegment]:
    """Plan the slide segments of one unit, using the file cache when possible.

    Args:
        unit_index: Number of the unit being planned.
        unit_concept: Concept name, sent to the LLM and used in log messages.
        lines: Dialogue lines of the unit, in order.
        llm_fn: Called as ``llm_fn(messages, call_type="segments")``.
        cache_file: Location of this unit's cached segments.

    Returns:
        The cached segments when the cache file is readable, otherwise the
        parsed LLM response. If the LLM call or the parsing raises, the result
        of ``fallback_segments()`` is returned and nothing is cached, so the
        next run tries the LLM again.
    """
    if cache_file.exists():
        try:
            cached = json.loads(cache_file.read_text(encoding="utf-8"))
            segments = [SlideSegment(**entry) for entry in cached]
        except Exception as exc:
            # A damaged or outdated cache entry must not abort planning, but
            # it should be visible in the logs instead of silently ignored.
            log.warning(
                "Ignoring unreadable segment cache %s for unit %d: %s",
                cache_file,
                unit_index,
                exc,
            )
        else:
            log.debug("Cache hit for segments unit %d (%s)", unit_index, unit_concept)
            return segments

    prompt = load_prompt("visual_v3.txt")
    # Each line is prefixed with its index so the model can refer to lines.
    dialogue_text = "\n".join(f"{i}: [{ln.speaker}] {ln.text}" for i, ln in enumerate(lines))
    unit_context = json.dumps(
        {
            "unit_index": unit_index,
            "concept": unit_concept,
            "total_lines": len(lines),
            "dialogue": dialogue_text,
        },
        ensure_ascii=False,
    )
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": unit_context},
    ]

    try:
        raw = llm_fn(messages, call_type="segments")
        segs = parse_segments_response(raw, unit_index, lines)
    except Exception as exc:
        log.warning(
            "Segment planning failed for unit %d (%s): %s — using fallback",
            unit_index,
            unit_concept,
            exc,
        )
        return fallback_segments(unit_index, lines)

    # The cache is an optimisation only: a read-only or full disk must not
    # discard segments that were just planned successfully.
    try:
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        cache_file.write_text(
            json.dumps([asdict(s) for s in segs], ensure_ascii=False),
            encoding="utf-8",
        )
    except OSError as exc:
        log.warning("Could not write segment cache %s: %s", cache_file, exc)
    return segs
116
+
117
+
118
def _cache_path(unit_index: int, lines: list[DialogueLine]) -> Path:
    """Return the cache file location for a unit's planned segments.

    The file name is the MD5 hex digest of the literal tag 'segments_v3'
    followed by the text of every dialogue line, concatenated without a
    separator. ``unit_index`` and the speakers are not part of the key.
    """
    hasher = hashlib.md5()
    hasher.update("segments_v3".encode())
    for line in lines:
        hasher.update(line.text.encode())
    return Path(SUMMARY_CACHE_DIR) / f"{hasher.hexdigest()}.segments.json"
123
+
124
+
125
+ def _load_unit_lines(units_json_path: Path) -> dict[int, tuple[str, list[DialogueLine]]]:
126
+ """Parse tutorial.units.json.
127
+
128
+ Returns dict: unit_number → (concept, list[DialogueLine]).
129
+ Only includes teaching units (unit_number >= 1).
130
+ Falls back to script.txt + timing.json when units.json has no lines.
131
+ """
132
+ import re
133
+
134
+ raw_units = json.loads(units_json_path.read_text(encoding="utf-8"))
135
+ result: dict[int, tuple[str, list[DialogueLine]]] = {}
136
+ for u in raw_units:
137
+ unit_num = int(u.get("unit", 0))
138
+ if unit_num < 1:
139
+ continue
140
+ concept = str(u.get("concept", ""))
141
+ raw_lines = u.get("lines", [])
142
+ lines = [DialogueLine(**ln) for ln in raw_lines]
143
+ result[unit_num] = (concept, lines)
144
+
145
+ if any(len(lines) > 0 for _, lines in result.values()):
146
+ return result
147
+
148
+ script_path = units_json_path.parent / "tutorial.script.txt"
149
+ if not script_path.exists():
150
+ return result
151
+
152
+ speaker_re = re.compile(r"^(ALEX|MAYA|SAM):\s*(.+)$")
153
+ all_pairs = [
154
+ (m.group(1), m.group(2))
155
+ for ln in script_path.read_text(encoding="utf-8").splitlines()
156
+ if (m := speaker_re.match(ln.strip()))
157
+ ]
158
+ if not all_pairs:
159
+ return result
160
+
161
+ timing_path = units_json_path.parent / "tutorial.timing.json"
162
+ lines_per_unit: dict[int, int] = {}
163
+ if timing_path.exists():
164
+ try:
165
+ timing = json.loads(timing_path.read_text(encoding="utf-8"))
166
+ if timing.get("version") == 1:
167
+ for uk, entries in timing.get("units", {}).items():
168
+ lines_per_unit[int(uk)] = len(entries)
169
+ except Exception:
170
+ pass
171
+
172
+ if lines_per_unit and all(u in lines_per_unit for u in result):
173
+ n_teaching = sum(lines_per_unit.values())
174
+ n_non = max(0, len(all_pairs) - n_teaching)
175
+ n_intro = (n_non + 1) // 2
176
+ cursor = n_intro
177
+ for unit_num in sorted(result.keys()):
178
+ concept = result[unit_num][0]
179
+ count = lines_per_unit[unit_num]
180
+ pairs = all_pairs[cursor : cursor + count]
181
+ cursor += count
182
+ result[unit_num] = (
183
+ concept,
184
+ [DialogueLine(speaker=s, text=t, unit_number=unit_num) for s, t in pairs],
185
+ )
186
+ else:
187
+ n_units = len(result)
188
+ n_lines = len(all_pairs)
189
+ per_unit = max(1, n_lines // max(n_units, 1))
190
+ for i, unit_num in enumerate(sorted(result.keys())):
191
+ concept = result[unit_num][0]
192
+ start = i * per_unit
193
+ end = n_lines if i == n_units - 1 else min(start + per_unit, n_lines)
194
+ pairs = all_pairs[start:end]
195
+ result[unit_num] = (
196
+ concept,
197
+ [DialogueLine(speaker=s, text=t, unit_number=unit_num) for s, t in pairs],
198
+ )
199
+
200
+ return result
@@ -0,0 +1,205 @@
1
+ import hashlib
2
+ import json
3
+ import logging
4
+ from dataclasses import asdict
5
+ from pathlib import Path
6
+
7
+ from tutor.constants import SUMMARY_CACHE_DIR
8
+ from tutor.exceptions import LLMError
9
+ from tutor.infra.llm import LLMFn, load_prompt, parse_json_response
10
+ from tutor.models import TeachingUnit, VisualSpec
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
# Mixed into the visual-spec cache key by _cache_path().
VISUAL_PROMPT_VERSION = "visual_v1"

# Diagram types accepted from the LLM; _parse_visual_response() downgrades
# any other value to "none".
VALID_DIAGRAM_TYPES = frozenset(
    (
        "class_diagram",
        "flowchart",
        "code_comparison",
        "concept_map",
        "none",
    )
)
18
+
19
+
20
def plan_visuals(
    units_json_path: Path,
    doc_title: str,
    session: str,
    llm_fn: LLMFn,
    difficulty: str,
    video_dir: Path,
    no_cache: bool = False,
) -> list[VisualSpec]:
    """Build the visual specs for a tutorial and write tutorial.visuals.json.

    Args:
        units_json_path: JSON file holding the list of teaching units.
        doc_title: Title shown on the title card.
        session: Passed to the title card as its document source.
        llm_fn: Callable forwarded to the per-unit planner.
        difficulty: Sent to the LLM and mixed into each unit's cache key.
        video_dir: Output directory; created when missing.
        no_cache: When true, an existing cache file for a unit is deleted
            before that unit is planned.

    Returns:
        ``[title_card, unit_1, ..., unit_N, outro]``.
    """
    unit_dicts = json.loads(units_json_path.read_text(encoding="utf-8"))
    # Entries without this key get an empty list before TeachingUnit is built.
    for entry in unit_dicts:
        entry.setdefault("prerequisite_concepts", [])
    units = [TeachingUnit(**entry) for entry in unit_dicts]

    video_dir.mkdir(parents=True, exist_ok=True)

    title_card = _build_title_card(doc_title, units, session)
    unit_specs: list[VisualSpec] = []
    for unit in units:
        cache_file = _cache_path(unit, difficulty)
        if no_cache and cache_file.exists():
            cache_file.unlink()
        unit_specs.append(_plan_unit(unit, llm_fn, difficulty, cache_file))
    specs: list[VisualSpec] = [title_card, *unit_specs, _build_outro(units)]

    visuals_path = video_dir / "tutorial.visuals.json"
    serialised = json.dumps([asdict(spec) for spec in specs], indent=2, ensure_ascii=False)
    visuals_path.write_text(serialised, encoding="utf-8")
    log.info("Visual specs written to %s (%d entries)", visuals_path, len(specs))
    return specs
56
+
57
+
58
def _plan_unit(
    unit: TeachingUnit,
    llm_fn: LLMFn,
    difficulty: str,
    cache_file: Path,
) -> VisualSpec:
    """Produce the visual spec of one teaching unit, using the file cache.

    Args:
        unit: Teaching unit to visualise.
        llm_fn: Called as ``llm_fn(messages, call_type="visual")``.
        difficulty: Included in the context sent to the LLM.
        cache_file: Location of this unit's cached spec.

    Returns:
        The cached spec when the cache file is readable, otherwise the spec
        parsed from the LLM response. If the LLM call or the parsing raises,
        the fallback spec is returned and NOT cached, so a transient failure
        is retried on the next run instead of being replayed from the cache.
    """
    if cache_file.exists():
        try:
            cached = json.loads(cache_file.read_text(encoding="utf-8"))
            spec = VisualSpec(**cached)
        except (OSError, ValueError, TypeError) as exc:
            # Corrupt JSON or fields that no longer match VisualSpec:
            # regenerate instead of crashing the whole run.
            log.warning(
                "Ignoring unreadable visual cache %s for unit %d: %s",
                cache_file,
                unit.unit,
                exc,
            )
        else:
            log.debug("Cache hit for visual spec unit %d (%s)", unit.unit, unit.concept)
            return spec

    prompt = load_prompt("visual.txt")
    unit_context = json.dumps(
        {
            "concept": unit.concept,
            "key_facts": unit.key_facts,
            "common_misconception": unit.common_misconception,
            "good_analogy": unit.good_analogy,
            "memory_hook": unit.memory_hook,
            "word_budget": unit.word_budget,
            "difficulty": difficulty,
        },
        indent=2,
    )
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": unit_context},
    ]

    log.info("Generating visual spec for unit %d: %s", unit.unit, unit.concept)
    try:
        raw = llm_fn(messages, call_type="visual")
        spec = _parse_visual_response(raw, unit)
    except Exception as exc:
        log.warning(
            "Visual spec failed for unit %d (%s): %s — using fallback",
            unit.unit,
            unit.concept,
            exc,
        )
        return _fallback_spec(unit)

    # The cache is an optimisation only; failing to write it must not lose
    # a spec that was generated successfully.
    try:
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        cache_file.write_text(
            json.dumps(asdict(spec), ensure_ascii=False),
            encoding="utf-8",
        )
    except OSError as exc:
        log.warning("Could not write visual cache %s: %s", cache_file, exc)
    return spec
106
+
107
+
108
def _parse_visual_response(raw: str, unit: TeachingUnit) -> VisualSpec:
    """Turn a raw LLM reply into a VisualSpec for ``unit``.

    Returns the fallback spec when the reply cannot be parsed as JSON or is
    not a JSON object. Individual fields are normalised so that a malformed
    value degrades to the unit's own data instead of raising or leaking the
    text "None" onto a slide:

    * ``diagram_type``: anything outside VALID_DIAGRAM_TYPES becomes "none";
    * ``key_points``: must be a JSON array, else the unit's first five key
      facts are used;
    * ``hook_question``: JSON null becomes an empty string;
    * ``memory_hook``: JSON null falls back to the unit's memory hook.
    """
    try:
        data = parse_json_response(raw)
    except LLMError:
        log.warning("Cannot parse visual spec JSON for unit %d", unit.unit)
        return _fallback_spec(unit)

    if not isinstance(data, dict):
        log.warning("Visual spec is not a JSON object for unit %d", unit.unit)
        return _fallback_spec(unit)

    diagram_type = data.get("diagram_type", "none")
    # The isinstance check comes first: a list or dict value is unhashable
    # and would raise on the frozenset membership test.
    if not isinstance(diagram_type, str) or diagram_type not in VALID_DIAGRAM_TYPES:
        log.warning(
            "Unknown diagram_type %r for unit %d — falling back to 'none'", diagram_type, unit.unit
        )
        diagram_type = "none"

    diagram_spec = data.get("diagram_spec")
    diagram_type, diagram_spec = _validate_diagram(diagram_type, diagram_spec, unit.unit)

    key_points = data.get("key_points")
    if not isinstance(key_points, list):
        # A missing, null, or scalar value (list("abc") would split a string
        # into characters) falls back to the unit's own facts.
        key_points = unit.key_facts[:5]

    hook_question = data.get("hook_question")
    memory_hook = data.get("memory_hook")

    return VisualSpec(
        unit_index=unit.unit,
        slide_type="unit",
        concept=unit.concept,
        hook_question="" if hook_question is None else str(hook_question),
        key_points=list(key_points),
        code_snippet=data.get("code_snippet") or None,
        diagram_type=diagram_type,
        diagram_spec=diagram_spec,  # type: ignore[arg-type]
        memory_hook=str(unit.memory_hook if memory_hook is None else memory_hook),
        analogy=unit.good_analogy,
    )
141
+
142
+
143
+ def _validate_diagram(diagram_type: str, diagram_spec: object, unit_idx: int) -> tuple[str, object]:
144
+ if diagram_type in ("class_diagram", "flowchart", "concept_map"):
145
+ if not isinstance(diagram_spec, str) or not _looks_like_dot(diagram_spec):
146
+ log.warning("diagram_spec for unit %d is not valid DOT — setting to 'none'", unit_idx)
147
+ return "none", None
148
+ elif diagram_type == "code_comparison":
149
+ if not isinstance(diagram_spec, dict) or not all(
150
+ k in diagram_spec for k in ("wrong", "right")
151
+ ):
152
+ log.warning(
153
+ "code_comparison spec for unit %d is not a valid dict — setting to 'none'", unit_idx
154
+ )
155
+ return "none", None
156
+ else:
157
+ diagram_spec = None
158
+ return diagram_type, diagram_spec
159
+
160
+
161
def _fallback_spec(unit: TeachingUnit) -> VisualSpec:
    """Build a diagram-free spec for ``unit`` from the unit's own fields.

    Uses a generic hook question, at most the first five key facts, and the
    unit's memory hook and analogy. No LLM output is involved.
    """
    fields = {
        "unit_index": unit.unit,
        "slide_type": "unit",
        "concept": unit.concept,
        "hook_question": f"What do you know about {unit.concept}?",
        "key_points": unit.key_facts[:5],
        "code_snippet": None,
        "diagram_type": "none",
        "diagram_spec": None,
        "memory_hook": unit.memory_hook,
        "analogy": unit.good_analogy,
    }
    return VisualSpec(**fields)
174
+
175
+
176
def _build_title_card(
    doc_title: str,
    units: list[TeachingUnit],
    doc_source: str,
    difficulty: str = "beginner",
) -> VisualSpec:
    """Build the opening title-card spec (unit_index 0).

    Args:
        doc_title: Title shown on the card.
        units: Teaching units; only their count is used.
        doc_source: Stored on the spec as ``doc_source``.
        difficulty: Label appended to the subtitle. Defaults to "beginner",
            the value that used to be hard-coded, so existing callers get
            the same subtitle as before.

    Returns:
        A spec whose subtitle reads like ``"3 units - beginner"``.
    """
    n = len(units)
    plural = "" if n == 1 else "s"
    return VisualSpec(
        unit_index=0,
        slide_type="title_card",
        title=doc_title,
        subtitle=f"{n} unit{plural} - {difficulty}",
        doc_source=doc_source,
    )
185
+
186
+
187
def _build_outro(units: list[TeachingUnit]) -> VisualSpec:
    """Build the closing outro spec, placed one index after the last unit.

    Collects every non-empty memory hook in unit order and a short
    "N unit(s)" statistic.
    """
    total = len(units)
    hooks = []
    for unit in units:
        if unit.memory_hook:
            hooks.append(unit.memory_hook)
    suffix = "" if total == 1 else "s"
    return VisualSpec(
        unit_index=total + 1,
        slide_type="outro",
        memory_hooks=hooks,
        session_stats=f"{total} unit{suffix}",
    )
194
+
195
+
196
def _cache_path(unit: TeachingUnit, difficulty: str) -> Path:
    """Return the cache file location for a unit's visual spec.

    The file name is the MD5 hex digest of the unit's concept, the string
    form of its key facts, the difficulty, and VISUAL_PROMPT_VERSION, so a
    change to any of them yields a different cache entry.
    """
    fingerprint = "".join(
        (unit.concept, str(unit.key_facts), difficulty, VISUAL_PROMPT_VERSION)
    )
    digest = hashlib.md5(fingerprint.encode()).hexdigest()
    return Path(SUMMARY_CACHE_DIR) / f"{digest}.visual.json"
201
+
202
+
203
+ def _looks_like_dot(text: str) -> bool:
204
+ stripped = text.strip()
205
+ return stripped.startswith(("digraph", "graph", "strict"))
File without changes
tutor/infra/llm.py ADDED
@@ -0,0 +1,152 @@
1
+ import json
2
+ import logging
3
+ import re
4
+ import time
5
+ import tomllib
6
+ from collections.abc import Callable
7
+ from pathlib import Path
8
+ from typing import Any, TypeAlias
9
+
10
+ from openai import OpenAI
11
+
12
+ from tutor.config import Config
13
+ from tutor.exceptions import ConfigError, LLMError
14
+
15
+ log = logging.getLogger(__name__)
16
+
17
+ LLMFn: TypeAlias = Callable[..., str]
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Config loading — reads tutor/llm_config.toml at import time
21
+ # ---------------------------------------------------------------------------
22
+
23
+ _CONFIG_PATH = Path(__file__).parent.parent / "llm_config.toml"
24
+
25
+
26
def _load() -> dict[str, Any]:
    """Parse the TOML file at ``_CONFIG_PATH`` and return it as a dict."""
    with _CONFIG_PATH.open("rb") as handle:
        parsed = tomllib.load(handle)
    return parsed
29
+
30
+
31
# Parsed once at import time; every value below is read from this dict.
_cfg = _load()

# Public lookup tables built from the TOML. chat() below reads MODEL_MAP and
# MAX_TOKENS_MAP; they are exported for other modules that need to read
# limits (dialogue.py, summarizer.py).
# MODEL_MAP flattens the nested "providers" table to (provider, call_type) -> model.
MODEL_MAP: dict[tuple[str, str], str] = dict(
    ((provider, call_type), model)
    for provider, calls in _cfg["providers"].items()
    for call_type, model in calls.items()
)

MAX_TOKENS_MAP: dict[str, int] = _cfg["max_tokens"]
LIMITS: dict[str, int] = _cfg["limits"]

# Sampling temperature and retry policy applied by chat().
_temperature: float = _cfg["llm"]["temperature"]
_retry_count: int = _cfg["llm"]["retry_count"]
_retry_delay_s: float = _cfg["llm"]["retry_delay_s"]
47
+
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # Client factory
51
+ # ---------------------------------------------------------------------------
52
+
53
+
54
+ def _build_client(provider: str, config: Config) -> OpenAI:
55
+ if provider == "groq":
56
+ if not config.groq_api_key:
57
+ raise ConfigError("GROQ_API_KEY not set. Add it to tutor/.env")
58
+ return OpenAI(
59
+ api_key=config.groq_api_key,
60
+ base_url="https://api.groq.com/openai/v1",
61
+ )
62
+ if provider == "openrouter":
63
+ if not config.openrouter_api_key:
64
+ raise ConfigError(
65
+ "OPENROUTER_API_KEY not set.\n"
66
+ " Get a free key at openrouter.ai and add OPENROUTER_API_KEY to tutor/.env"
67
+ )
68
+ return OpenAI(
69
+ api_key=config.openrouter_api_key,
70
+ base_url="https://openrouter.ai/api/v1",
71
+ default_headers={"HTTP-Referer": "http://localhost"},
72
+ timeout=120.0,
73
+ )
74
+ raise ConfigError(f"Unknown provider: {provider!r}. Use 'groq' or 'openrouter'.")
75
+
76
+
77
+ # ---------------------------------------------------------------------------
78
+ # Chat
79
+ # ---------------------------------------------------------------------------
80
+
81
+
82
def chat(
    messages: list[dict[str, str]],
    config: Config,
    provider: str = "groq",
    call_type: str = "dialogue",
) -> str:
    """Send ``messages`` to the model configured for (provider, call_type).

    Args:
        messages: Chat messages in OpenAI format (``role`` / ``content``).
        config: Supplies the API key for the chosen provider.
        provider: Provider name used to select the client and the model.
        call_type: Selects the model and the ``max_tokens`` limit
            (1000 when the call type has no configured limit).

    Returns:
        The text content of the first choice in the response.

    Raises:
        LLMError: No model is configured for the pair; the provider answered
            400/401/403/413 (not retried); or every attempt failed.
        ConfigError: Raised by the client factory for an unknown provider or
            a missing API key.
    """
    model = MODEL_MAP.get((provider, call_type))
    if model is None:
        raise LLMError(f"No model configured for ({provider!r}, {call_type!r}) in llm_config.toml")

    client = _build_client(provider, config)
    log.debug("LLM call provider=%s call_type=%s model=%s", provider, call_type, model)

    max_tokens = MAX_TOKENS_MAP.get(call_type, 1_000)

    # A retry_count of 0 (or less) in the config must still make one attempt
    # instead of skipping the call entirely.
    attempts = max(1, _retry_count)

    for attempt in range(attempts):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,  # type: ignore[arg-type]
                temperature=_temperature,
                max_tokens=max_tokens,
            )
            content = response.choices[0].message.content
            if content is None:
                # Explicit check instead of `assert`, which is stripped under
                # `python -O`. Raised inside the try so that an empty reply is
                # retried like any other transient failure.
                raise ValueError("LLM returned empty content")
            log.debug("LLM response (first 200 chars): %s", content[:200])
            return content
        except Exception as e:
            status = getattr(e, "status_code", None)
            # Client-side errors will not succeed on retry; fail immediately.
            if status in (400, 401, 403):
                raise LLMError(f"Auth/request error ({status}): {e}") from e
            if status == 413:
                raise LLMError(
                    f"Request too large for {model}.\n"
                    f" Lower max_source_tokens or max_tokens.{call_type} in llm_config.toml."
                ) from e
            if attempt < attempts - 1:
                log.warning("LLM call failed (%s), retrying in %.1fs...", e, _retry_delay_s)
                time.sleep(_retry_delay_s)
                continue
            raise LLMError(f"LLM call failed after {attempts} attempts: {e}") from e

    # The loop always returns or raises; this keeps type checkers satisfied.
    raise LLMError("Unreachable")
125
+
126
+
127
+ # ---------------------------------------------------------------------------
128
+ # Helpers
129
+ # ---------------------------------------------------------------------------
130
+
131
+
132
def parse_json_response(raw: str) -> object:
    """Extract a JSON value from an LLM reply.

    Three strategies are tried in order:

    1. parse ``raw`` unchanged, so a reply that is already valid JSON keeps
       Markdown fences that appear inside its string values;
    2. remove Markdown code fences and parse what remains;
    3. parse the widest ``[...]`` or ``{...}`` span found in the stripped text.

    Args:
        raw: Raw reply text.

    Returns:
        The decoded JSON value.

    Raises:
        LLMError: When none of the strategies yields valid JSON.
    """
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass

    # Drop every ``` / ```json marker, then any stray trailing backticks.
    text = re.sub(r"```(?:json)?\s*", "", raw).strip().rstrip("`").strip()

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Last resort: the model wrapped the JSON in prose.
    match = re.search(r"(\[.*\]|\{.*\})", text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(1))
        except json.JSONDecodeError:
            pass

    raise LLMError(f"Could not parse JSON from response: {raw[:200]}")
148
+
149
+
150
def load_prompt(name: str) -> str:
    """Return the UTF-8 text of prompt file ``name``.

    The file is looked up in the ``prompts`` directory that sits next to
    this module's parent package directory.
    """
    prompt_file = Path(__file__).parent.parent / "prompts" / name
    return prompt_file.read_text(encoding="utf-8")
File without changes