QuizGenerator 0.1.4__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- QuizGenerator/contentast.py +415 -284
- QuizGenerator/misc.py +96 -8
- QuizGenerator/mixins.py +10 -1
- QuizGenerator/premade_questions/cst334/memory_questions.py +1 -1
- QuizGenerator/premade_questions/cst334/persistence_questions.py +78 -23
- QuizGenerator/premade_questions/cst334/process.py +1 -2
- QuizGenerator/premade_questions/cst463/models/__init__.py +0 -0
- QuizGenerator/premade_questions/cst463/models/attention.py +192 -0
- QuizGenerator/premade_questions/cst463/models/cnns.py +186 -0
- QuizGenerator/premade_questions/cst463/models/matrices.py +24 -0
- QuizGenerator/premade_questions/cst463/models/rnns.py +202 -0
- QuizGenerator/premade_questions/cst463/models/text.py +201 -0
- QuizGenerator/premade_questions/cst463/models/weight_counting.py +227 -0
- QuizGenerator/premade_questions/cst463/neural-network-basics/neural_network_questions.py +138 -94
- QuizGenerator/question.py +3 -2
- QuizGenerator/quiz.py +0 -1
- {quizgenerator-0.1.4.dist-info → quizgenerator-0.3.1.dist-info}/METADATA +3 -1
- {quizgenerator-0.1.4.dist-info → quizgenerator-0.3.1.dist-info}/RECORD +21 -14
- {quizgenerator-0.1.4.dist-info → quizgenerator-0.3.1.dist-info}/WHEEL +1 -1
- {quizgenerator-0.1.4.dist-info → quizgenerator-0.3.1.dist-info}/entry_points.txt +0 -0
- {quizgenerator-0.1.4.dist-info → quizgenerator-0.3.1.dist-info}/licenses/LICENSE +0 -0
QuizGenerator/misc.py
CHANGED
|
@@ -6,16 +6,21 @@ import enum
|
|
|
6
6
|
import itertools
|
|
7
7
|
import logging
|
|
8
8
|
import math
|
|
9
|
+
import numpy as np
|
|
9
10
|
from typing import List, Dict, Tuple, Any
|
|
10
11
|
|
|
11
12
|
import fractions
|
|
12
13
|
|
|
14
|
+
from QuizGenerator.contentast import ContentAST
|
|
15
|
+
|
|
13
16
|
log = logging.getLogger(__name__)
|
|
14
17
|
|
|
15
18
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
+
def fix_negative_zero(value):
    """Normalize IEEE negative zero so it displays as ``0.0`` rather than ``-0.0``.

    Only floating-point values (built-in ``float`` and NumPy floating scalars)
    can carry a sign on zero, so only those are touched. Integers, booleans and
    non-numeric values are returned unchanged; in particular an integer ``0``
    stays an ``int`` and still renders as ``"0"``.

    Args:
        value: Any value; typically the result of ``round(x, digits)``.

    Returns:
        ``value`` with ``-0.0`` replaced by ``+0.0`` of the same float type;
        every other input is returned as-is.
    """
    if isinstance(value, (float, np.floating)) and value == 0:
        # abs() clears the sign bit of -0.0 while preserving the float type
        # (np.float32 stays np.float32, float stays float).
        return abs(value)
    return value
|
|
19
24
|
|
|
20
25
|
|
|
21
26
|
class Answer:
|
|
@@ -38,6 +43,7 @@ class Answer:
|
|
|
38
43
|
AUTOFLOAT = enum.auto()
|
|
39
44
|
LIST = enum.auto()
|
|
40
45
|
VECTOR = enum.auto()
|
|
46
|
+
MATRIX = enum.auto()
|
|
41
47
|
|
|
42
48
|
|
|
43
49
|
def __init__(
|
|
@@ -200,6 +206,7 @@ class Answer:
|
|
|
200
206
|
}
|
|
201
207
|
for possible_state in [self.value] #itertools.permutations(self.value)
|
|
202
208
|
]
|
|
209
|
+
|
|
203
210
|
else:
|
|
204
211
|
# For string answers, check if value is a list of acceptable alternatives
|
|
205
212
|
if isinstance(self.value, list):
|
|
@@ -256,13 +263,16 @@ class Answer:
|
|
|
256
263
|
|
|
257
264
|
elif self.variable_kind == Answer.VariableKind.AUTOFLOAT:
|
|
258
265
|
# Round to default precision for readability
|
|
259
|
-
|
|
266
|
+
rounded = round(self.value, self.DEFAULT_ROUNDING_DIGITS)
|
|
267
|
+
return f"{fix_negative_zero(rounded)}"
|
|
260
268
|
|
|
261
269
|
elif self.variable_kind == Answer.VariableKind.FLOAT:
|
|
262
270
|
# Round to default precision
|
|
263
271
|
if isinstance(self.value, (list, tuple)):
|
|
264
|
-
|
|
265
|
-
|
|
272
|
+
rounded = round(self.value[0], self.DEFAULT_ROUNDING_DIGITS)
|
|
273
|
+
return f"{fix_negative_zero(rounded)}"
|
|
274
|
+
rounded = round(self.value, self.DEFAULT_ROUNDING_DIGITS)
|
|
275
|
+
return f"{fix_negative_zero(rounded)}"
|
|
266
276
|
|
|
267
277
|
elif self.variable_kind == Answer.VariableKind.INT:
|
|
268
278
|
return str(int(self.value))
|
|
@@ -272,12 +282,17 @@ class Answer:
|
|
|
272
282
|
|
|
273
283
|
elif self.variable_kind == Answer.VariableKind.VECTOR:
|
|
274
284
|
# Format as comma-separated rounded values
|
|
275
|
-
return ", ".join(str(round(v, self.DEFAULT_ROUNDING_DIGITS)) for v in self.value)
|
|
285
|
+
return ", ".join(str(fix_negative_zero(round(v, self.DEFAULT_ROUNDING_DIGITS))) for v in self.value)
|
|
276
286
|
|
|
277
287
|
else:
|
|
278
288
|
# Default: use display or value
|
|
279
289
|
return str(self.display if hasattr(self, 'display') else self.value)
|
|
280
290
|
|
|
291
|
+
def get_ast_element(self, label=None):
    """Wrap this answer in a ContentAST.Answer node for rendering.

    Args:
        label: Optional label handed to ContentAST.Answer unchanged.

    Returns:
        A ContentAST.Answer that refers back to this Answer instance.
    """
    # Function-local import, kept exactly as in the original implementation.
    from QuizGenerator.contentast import ContentAST

    # TODO: label handling still needs fixing (carried over from the original "todo fix label").
    node = ContentAST.Answer(answer=self, label=label)
    return node
|
|
295
|
+
|
|
281
296
|
# Factory methods for common answer types
|
|
282
297
|
@classmethod
|
|
283
298
|
def binary_hex(cls, key: str, value: int, length: int = None, **kwargs) -> 'Answer':
|
|
@@ -404,9 +419,18 @@ class Answer:
|
|
|
404
419
|
**kwargs
|
|
405
420
|
)
|
|
406
421
|
|
|
422
|
+
@classmethod
def matrix(cls, key: str, value: np.ndarray | List, **kwargs):
    """Build a MatrixAnswer whose cells are graded individually.

    Args:
        key: Base answer key; MatrixAnswer derives one per-cell key from it.
        value: The expected matrix, as a NumPy array or a (nested) list.
        **kwargs: Extra keyword arguments forwarded to the Answer constructor,
            the same way the other factory methods forward theirs.

    Returns:
        A MatrixAnswer with ``variable_kind`` set to ``VariableKind.MATRIX``.
    """
    return MatrixAnswer(
        key=key,
        # MatrixAnswer reads ``.shape`` and indexes cells as ``[i, j]``, so a
        # plain list has to become an ndarray first (no copy for ndarrays).
        value=np.asarray(value),
        variable_kind=cls.VariableKind.MATRIX,
        **kwargs,
    )
|
|
429
|
+
|
|
407
430
|
@staticmethod
|
|
408
431
|
def _to_fraction(x):
|
|
409
432
|
"""Convert int/float/decimal.Decimal/fractions.Fraction/str('a/b' or decimal) to fractions.Fraction exactly."""
|
|
433
|
+
log.debug(f"x: {x} {x.__class__}")
|
|
410
434
|
if isinstance(x, fractions.Fraction):
|
|
411
435
|
return x
|
|
412
436
|
if isinstance(x, int):
|
|
@@ -488,4 +512,68 @@ class Answer:
|
|
|
488
512
|
whole, rem = divmod(A, b)
|
|
489
513
|
outs.add(f"{sign}{whole} {rem}/{b}")
|
|
490
514
|
|
|
491
|
-
return sorted(outs, key=lambda s: (len(s), s))
|
|
515
|
+
return sorted(outs, key=lambda s: (len(s), s))
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
class MatrixAnswer(Answer):
    """Answer whose value is a matrix; every cell is graded as its own blank.

    Cell ``(i, j)`` uses the key ``f"{key}_{i}_{j}"`` both in the Canvas export
    and in the rendered table.
    """

    # Largest denominator accepted when offering fraction spellings of a cell.
    CANVAS_MAX_DENOMINATOR = 3 * 4 * 5

    def _as_2d(self) -> np.ndarray:
        """Return ``self.value`` as an ndarray with at least two dimensions.

        Nested lists are converted to arrays and 1-D input becomes a single
        row, so the per-cell code below can always index with ``[i, j]``.
        """
        return np.atleast_2d(np.asarray(self.value))

    @staticmethod
    def _to_builtin(entry):
        """Unwrap a NumPy scalar into the equivalent built-in Python scalar."""
        # np.int64 is not a subclass of int, so isinstance(x, int) checks in the
        # fraction helpers would not recognise it; .item() yields a plain int/float.
        if isinstance(entry, np.generic):
            return entry.item()
        return entry

    def get_for_canvas(self, single_answer=False) -> List[Dict]:
        """Build Canvas answer entries, one blank per matrix cell.

        Args:
            single_answer: Part of the signature but not used by this
                implementation.

        Returns:
            A flat list of dicts with ``blank_id``, ``answer_text`` and
            ``answer_weight`` keys. Each accepted spelling of a cell's value
            gets its own entry; all entries for a cell share its ``blank_id``.
        """
        # NOTE(review): callers still need a way to learn these per-cell blank
        # ids ("{key}_{i}_{j}"); that wiring is an open design question.
        values = self._as_2d()
        weight = 100 if self.correct else 0
        canvas_answers = []

        for i, j in np.ndindex(values.shape):
            entry_strings = self.__class__.accepted_strings(
                self._to_builtin(values[i, j]),
                allow_integer=True,
                allow_simple_fraction=True,
                max_denominator=self.CANVAS_MAX_DENOMINATOR,
                allow_mixed=True,
                include_spaces=False,
                include_fixed_even_if_integer=True,
            )
            canvas_answers.extend(
                {
                    # Per-cell id so each blank can be tracked individually.
                    "blank_id": f"{self.key}_{i}_{j}",
                    "answer_text": answer_string,
                    "answer_weight": weight,
                }
                for answer_string in entry_strings
            )

        return canvas_answers

    def get_ast_element(self, label=None):
        """Render the matrix as a ContentAST.Table of per-cell answer blanks.

        The outer loop walks rows and the inner loop walks columns, so table
        row ``i`` holds matrix row ``i``.

        Args:
            label: Optional text rendered as ``"{label} = "`` before the table.

        Returns:
            The table, or a ContentAST.Container of the label text and the
            table when ``label`` is given.
        """
        from QuizGenerator.contentast import ContentAST

        log.debug(f"self.value: {self.value}")

        values = self._as_2d()
        n_rows, n_cols = values.shape[0], values.shape[1]

        data = [
            [
                ContentAST.Answer(
                    Answer.float_value(
                        key=f"{self.key}_{i}_{j}",
                        value=values[i, j],
                    )
                )
                for j in range(n_cols)
            ]
            for i in range(n_rows)
        ]
        table = ContentAST.Table(data)

        if label is None:
            return table
        return ContentAST.Container([
            ContentAST.Text(f"{label} = "),
            table,
        ])
|
QuizGenerator/mixins.py
CHANGED
|
@@ -32,8 +32,17 @@ class TableQuestionMixin:
|
|
|
32
32
|
Returns:
|
|
33
33
|
ContentAST.Table with the information formatted
|
|
34
34
|
"""
|
|
35
|
+
# Don't convert ContentAST elements to strings - let them render properly
|
|
36
|
+
table_data = []
|
|
37
|
+
for key, value in info_dict.items():
|
|
38
|
+
# Keep ContentAST elements as-is, convert others to strings
|
|
39
|
+
if isinstance(value, ContentAST.Element):
|
|
40
|
+
table_data.append([key, value])
|
|
41
|
+
else:
|
|
42
|
+
table_data.append([key, str(value)])
|
|
43
|
+
|
|
35
44
|
return ContentAST.Table(
|
|
36
|
-
data=
|
|
45
|
+
data=table_data,
|
|
37
46
|
transpose=transpose
|
|
38
47
|
)
|
|
39
48
|
|
|
@@ -638,7 +638,7 @@ class Segmentation(MemoryAccessQuestion, TableQuestionMixin, BodyTemplatesMixin)
|
|
|
638
638
|
f"Since we are in the {self.segment} segment, "
|
|
639
639
|
f"we see from our table that our bounds are {self.bounds[self.segment]}. "
|
|
640
640
|
f"Remember that our check for our {self.segment} segment is: ",
|
|
641
|
-
f"`if (offset
|
|
641
|
+
f"`if (offset >= bounds({self.segment})) : INVALID`",
|
|
642
642
|
"which becomes"
|
|
643
643
|
f"`if ({self.offset:0b} > {self.bounds[self.segment]:0b}) : INVALID`"
|
|
644
644
|
]
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
from __future__ import annotations
|
|
3
3
|
|
|
4
4
|
import abc
|
|
5
|
+
import difflib
|
|
5
6
|
import logging
|
|
6
7
|
|
|
7
8
|
from QuizGenerator.question import Question, Answer, QuestionRegistry
|
|
@@ -365,31 +366,85 @@ class VSFS_states(IOQuestion):
|
|
|
365
366
|
def get_explanation(self) -> ContentAST.Section:
    """Explain how to deduce the VSFS operation from the two simulator states.

    The explanation (1) lists every line that differs between
    ``self.start_state`` and ``self.end_state``, (2) interprets any change to
    the inode bitmap and the data bitmap, and (3) shows the end state with the
    changed characters wrapped in ``***`` markers.

    Returns:
        A ContentAST.Section containing the explanation.
    """
    explanation = ContentAST.Section()

    log.debug(f"self.start_state: {self.start_state}")
    log.debug(f"self.end_state: {self.end_state}")

    explanation.add_elements([
        ContentAST.Paragraph([
            "The key thing to pay attention to when solving these problems is where there are differences between the start state and the end state.",
            "In this particular problem, we can see that these lines are different:"
        ])
    ])

    # Compare the two states line by line (positionally) and keep only the
    # pairs that differ.
    lines_that_changed = [
        (start_line, end_line)
        for start_line, end_line in zip(self.start_state.split('\n'), self.end_state.split('\n'))
        if start_line != end_line
    ]
    explanation.add_element(
        ContentAST.Paragraph([
            f" - `{start_line}` -> `{end_line}`"
            for start_line, end_line in lines_that_changed
        ])
    )

    def get_bitmap(line: str) -> str:
        """Return the last whitespace-separated token of ``line`` (the bitmap)."""
        log.debug(f"line: {line}")
        return line.split()[-1]

    def highlight_changes(a: str, b: str) -> str:
        """Return ``b`` with text inserted/replaced relative to ``a`` wrapped in ``***``.

        Text that was deleted from ``a`` is not shown.
        """
        matcher = difflib.SequenceMatcher(None, a, b)
        result = []
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "equal":
                result.append(b[j1:j2])
            elif tag in ("insert", "replace"):
                result.append(f"***{b[j1:j2]}***")
            # "delete": nothing to emit, the text no longer exists in b.
        return "".join(result)

    def first_change_containing(marker: str):
        """Return the first changed (start, end) pair whose start line contains ``marker``, else None."""
        return next((pair for pair in lines_that_changed if marker in pair[0]), None)

    def describe_bitmap_change(name: str, pair, grew_message: str, shrank_message: str) -> list:
        """Describe how one bitmap changed and which operation that points to."""
        before, after = get_bitmap(pair[0]), get_bitmap(pair[1])
        # More set bits afterwards means something was allocated; otherwise the
        # change is treated as a deallocation.
        return [
            f"The {name} bitmap lines have changed from {before} to {after}.",
            grew_message if before.count('1') < after.count('1') else shrank_message,
        ]

    chunk_to_add = [
        "A great place to start is to check to see if the bitmaps have changed as this can quickly tell us a lot of information"
    ]

    inode_bitmap_change = first_change_containing("inode bitmap")
    data_bitmap_change = first_change_containing("data bitmap")

    if inode_bitmap_change is not None:
        chunk_to_add.extend(describe_bitmap_change(
            "inode",
            inode_bitmap_change,
            "We can see that we have added an inode, so we have either called `creat` or `mkdir`.",
            "We can see that we have removed an inode, so we have called `unlink`.",
        ))

    if data_bitmap_change is not None:
        chunk_to_add.extend(describe_bitmap_change(
            "data",
            data_bitmap_change,
            "We can see that we have added a data block, so we have either called `mkdir` or `write`.",
            "We can see that we have removed a data block, so we have `unlink`ed a file.",
        ))

    if inode_bitmap_change is None and data_bitmap_change is None:
        chunk_to_add.append(
            "If they have not changed, then we know we must have either called `link` or `unlink` and must check the references."
        )

    explanation.add_element(
        ContentAST.Paragraph(chunk_to_add)
    )

    explanation.add_elements([
        ContentAST.Paragraph(["The overall changes are highlighted with `*` symbols below"])
    ])

    explanation.add_element(
        ContentAST.Code(
            highlight_changes(self.start_state, self.end_state)
        )
    )

    return explanation
|
|
@@ -15,7 +15,6 @@ from typing import List
|
|
|
15
15
|
|
|
16
16
|
import matplotlib.pyplot as plt
|
|
17
17
|
|
|
18
|
-
from QuizGenerator.misc import OutputFormat
|
|
19
18
|
from QuizGenerator.contentast import ContentAST
|
|
20
19
|
from QuizGenerator.question import Question, Answer, QuestionRegistry, RegenerableChoiceMixin
|
|
21
20
|
from QuizGenerator.mixins import TableQuestionMixin, BodyTemplatesMixin
|
|
@@ -380,7 +379,7 @@ class SchedulingQuestion(ProcessQuestion, RegenerableChoiceMixin, TableQuestionM
|
|
|
380
379
|
# Return whether this workload is interesting
|
|
381
380
|
return self.is_interesting()
|
|
382
381
|
|
|
383
|
-
def get_body(self,
|
|
382
|
+
def get_body(self, *args, **kwargs) -> ContentAST.Section:
|
|
384
383
|
# Create table data for scheduling results
|
|
385
384
|
table_rows = []
|
|
386
385
|
for job_id in sorted(self.job_stats.keys()):
|
|
File without changes
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import logging
|
|
3
|
+
import math
|
|
4
|
+
import keras
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from QuizGenerator.misc import MatrixAnswer
|
|
8
|
+
from QuizGenerator.question import Question, QuestionRegistry, Answer
|
|
9
|
+
from QuizGenerator.contentast import ContentAST
|
|
10
|
+
from QuizGenerator.constants import MathRanges
|
|
11
|
+
from QuizGenerator.mixins import TableQuestionMixin
|
|
12
|
+
|
|
13
|
+
from .matrices import MatrixQuestion
|
|
14
|
+
|
|
15
|
+
log = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@QuestionRegistry.register("cst463.attention.forward-pass")
class AttentionForwardPass(MatrixQuestion, TableQuestionMixin):
    """Question: compute one forward pass of a single-head self-attention layer.

    Students are shown Q, K and V matrices and must produce the softmax
    attention weights and the attended output.
    """

    @staticmethod
    def simple_attention(Q, K, V):
        """Scaled dot-product attention.

        Args:
            Q: (seq_len, d_k) queries.
            K: (seq_len, d_k) keys.
            V: (seq_len, d_v) values.

        Returns:
            Tuple ``(output, attention_weights)``. ``output`` is (seq_len, d_v)
            and ``attention_weights`` is (seq_len, seq_len), with each row
            summing to 1.
        """
        d_k = Q.shape[1]

        # Compute attention scores
        scores = Q @ K.T / np.sqrt(d_k)

        # Row-wise softmax. Subtracting each row's max is mathematically a no-op
        # but keeps np.exp from overflowing when scores are large.
        exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
        attention_weights = exp_scores / exp_scores.sum(axis=1, keepdims=True)

        # Weighted sum of values
        output = attention_weights @ V

        return output, attention_weights

    def refresh(self, *args, **kwargs):
        """Generate new Q/K/V matrices and the expected answers.

        Keyword args:
            seq_len: Sequence length (default 3).
            key_dimension: Query/key dimension d_k (default 1).
            value_dimension: Value dimension d_v (default 1).

        Returns:
            True.
        """
        super().refresh(*args, **kwargs)

        seq_len = kwargs.get("seq_len", 3)
        d_k = kwargs.get("key_dimension", 1)  # key/query dimension
        d_v = kwargs.get("value_dimension", 1)  # value dimension

        # NOTE(review): these integer draws are immediately superseded by the
        # get_rounded_matrix() calls below. They are kept only so the sequence
        # of self.rng draws (and hence seeded regeneration) stays unchanged, in
        # case get_rounded_matrix also draws from self.rng -- confirm and drop
        # them once reproducing old questions no longer matters.
        for shape in ((seq_len, d_k), (seq_len, d_k), (seq_len, d_v)):
            self.rng.randint(0, 3, size=shape)

        self.Q = self.get_rounded_matrix((seq_len, d_k), 0, 3)
        self.K = self.get_rounded_matrix((seq_len, d_k), 0, 3)
        self.V = self.get_rounded_matrix((seq_len, d_v), 0, 3)

        self.output, self.weights = self.simple_attention(self.Q, self.K, self.V)

        # Expected answers: the attention weight matrix and the attended output.
        self.answers["weights"] = MatrixAnswer("weights", self.weights)
        self.answers["output"] = MatrixAnswer("output", self.output)

        return True

    def get_body(self, *args, **kwargs) -> ContentAST.Section:
        """Build the question body: a Q/K/V info table plus answer blanks for weights and output."""
        body = ContentAST.Section()

        body.add_element(
            ContentAST.Text(
                "Given the below information about a self attention layer, "
                "please calculate the attention weights and the output sequence."
            )
        )
        body.add_element(
            self.create_info_table(
                {
                    "Q": ContentAST.Matrix(self.Q),
                    "K": ContentAST.Matrix(self.K),
                    "V": ContentAST.Matrix(self.V),
                }
            )
        )

        body.add_elements([
            ContentAST.LineBreak(),
            self.answers["weights"].get_ast_element(label="Weights"),
            ContentAST.LineBreak(),
            self.answers["output"].get_ast_element(label="Output"),
        ])

        return body

    def get_explanation(self, *args, **kwargs) -> ContentAST.Section:
        """Walk through scores, softmax weights (with a worked row 0) and the output."""
        explanation = ContentAST.Section()
        digits = Answer.DEFAULT_ROUNDING_DIGITS

        explanation.add_element(
            ContentAST.Paragraph([
                "Self-attention uses scaled dot-product attention to compute a weighted combination of values based on query-key similarity."
            ])
        )

        # Step 1: Compute attention scores
        explanation.add_element(
            ContentAST.Paragraph([
                ContentAST.Text("Step 1: Compute attention scores", emphasis=True)
            ])
        )

        d_k = self.Q.shape[1]
        explanation.add_element(
            ContentAST.Equation(f"\\text{{scores}} = \\frac{{Q K^T}}{{\\sqrt{{d_k}}}} = \\frac{{Q K^T}}{{\\sqrt{{{d_k}}}}}")
        )

        scores = self.Q @ self.K.T / np.sqrt(d_k)

        explanation.add_element(
            ContentAST.Paragraph([
                "Raw scores (scaling by ",
                ContentAST.Equation(f'\\sqrt{{{d_k}}}', inline=True),
                " prevents extremely large values):"
            ])
        )
        explanation.add_element(ContentAST.Matrix(np.round(scores, digits)))

        # Step 2: Apply softmax
        explanation.add_element(
            ContentAST.Paragraph([
                ContentAST.Text("Step 2: Apply softmax to get attention weights", emphasis=True)
            ])
        )

        explanation.add_element(
            ContentAST.Equation(r"\alpha_{ij} = \frac{\exp(\text{score}_{ij})}{\sum_k \exp(\text{score}_{ik})}")
        )

        # Show ONE example row
        explanation.add_element(
            ContentAST.Paragraph([
                "Example: Row 0 softmax computation"
            ])
        )

        # Raw (unshifted) exponentials are shown here because they match the
        # formula above; the values involved are small.
        row_scores = scores[0]
        exp_scores = np.exp(row_scores)
        sum_exp = exp_scores.sum()
        weights_row = exp_scores / sum_exp

        exp_terms = " + ".join([f"e^{{{s:.{digits}f}}}" for s in row_scores])

        explanation.add_element(
            ContentAST.Paragraph([
                f"Denominator = {exp_terms} = {sum_exp:.{digits}f}"
            ])
        )

        # Format array with proper rounding
        weights_str = "[" + ", ".join([f"{w:.{digits}f}" for w in weights_row]) + "]"
        explanation.add_element(
            ContentAST.Paragraph([
                f"Resulting weights: {weights_str}"
            ])
        )

        explanation.add_element(
            ContentAST.Paragraph([
                "Complete attention weight matrix:"
            ])
        )
        explanation.add_element(ContentAST.Matrix(np.round(self.weights, digits)))

        # Step 3: Weighted sum of values
        explanation.add_element(
            ContentAST.Paragraph([
                ContentAST.Text("Step 3: Compute weighted sum of values", emphasis=True)
            ])
        )

        explanation.add_element(
            ContentAST.Equation(r"\text{output} = \text{weights} \times V")
        )

        explanation.add_element(
            ContentAST.Paragraph([
                "Final output:"
            ])
        )
        explanation.add_element(ContentAST.Matrix(np.round(self.output, digits)))

        return explanation
|
|
192
|
+
|