bead-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bead/__init__.py +11 -0
- bead/__main__.py +11 -0
- bead/active_learning/__init__.py +15 -0
- bead/active_learning/config.py +231 -0
- bead/active_learning/loop.py +566 -0
- bead/active_learning/models/__init__.py +24 -0
- bead/active_learning/models/base.py +852 -0
- bead/active_learning/models/binary.py +910 -0
- bead/active_learning/models/categorical.py +943 -0
- bead/active_learning/models/cloze.py +862 -0
- bead/active_learning/models/forced_choice.py +956 -0
- bead/active_learning/models/free_text.py +773 -0
- bead/active_learning/models/lora.py +365 -0
- bead/active_learning/models/magnitude.py +835 -0
- bead/active_learning/models/multi_select.py +795 -0
- bead/active_learning/models/ordinal_scale.py +811 -0
- bead/active_learning/models/peft_adapter.py +155 -0
- bead/active_learning/models/random_effects.py +639 -0
- bead/active_learning/selection.py +354 -0
- bead/active_learning/strategies.py +391 -0
- bead/active_learning/trainers/__init__.py +26 -0
- bead/active_learning/trainers/base.py +210 -0
- bead/active_learning/trainers/data_collator.py +172 -0
- bead/active_learning/trainers/dataset_utils.py +261 -0
- bead/active_learning/trainers/huggingface.py +304 -0
- bead/active_learning/trainers/lightning.py +324 -0
- bead/active_learning/trainers/metrics.py +424 -0
- bead/active_learning/trainers/mixed_effects.py +551 -0
- bead/active_learning/trainers/model_wrapper.py +509 -0
- bead/active_learning/trainers/registry.py +104 -0
- bead/adapters/__init__.py +11 -0
- bead/adapters/huggingface.py +61 -0
- bead/behavioral/__init__.py +116 -0
- bead/behavioral/analytics.py +646 -0
- bead/behavioral/extraction.py +343 -0
- bead/behavioral/merging.py +343 -0
- bead/cli/__init__.py +11 -0
- bead/cli/active_learning.py +513 -0
- bead/cli/active_learning_commands.py +779 -0
- bead/cli/completion.py +359 -0
- bead/cli/config.py +624 -0
- bead/cli/constraint_builders.py +286 -0
- bead/cli/deployment.py +859 -0
- bead/cli/deployment_trials.py +493 -0
- bead/cli/deployment_ui.py +332 -0
- bead/cli/display.py +378 -0
- bead/cli/items.py +960 -0
- bead/cli/items_factories.py +776 -0
- bead/cli/list_constraints.py +714 -0
- bead/cli/lists.py +490 -0
- bead/cli/main.py +430 -0
- bead/cli/models.py +877 -0
- bead/cli/resource_loaders.py +621 -0
- bead/cli/resources.py +1036 -0
- bead/cli/shell.py +356 -0
- bead/cli/simulate.py +840 -0
- bead/cli/templates.py +1158 -0
- bead/cli/training.py +1080 -0
- bead/cli/utils.py +614 -0
- bead/cli/workflow.py +1273 -0
- bead/config/__init__.py +68 -0
- bead/config/active_learning.py +1009 -0
- bead/config/config.py +192 -0
- bead/config/defaults.py +118 -0
- bead/config/deployment.py +217 -0
- bead/config/env.py +147 -0
- bead/config/item.py +45 -0
- bead/config/list.py +193 -0
- bead/config/loader.py +149 -0
- bead/config/logging.py +42 -0
- bead/config/model.py +49 -0
- bead/config/paths.py +46 -0
- bead/config/profiles.py +320 -0
- bead/config/resources.py +47 -0
- bead/config/serialization.py +210 -0
- bead/config/simulation.py +206 -0
- bead/config/template.py +238 -0
- bead/config/validation.py +267 -0
- bead/data/__init__.py +65 -0
- bead/data/base.py +87 -0
- bead/data/identifiers.py +97 -0
- bead/data/language_codes.py +61 -0
- bead/data/metadata.py +270 -0
- bead/data/range.py +123 -0
- bead/data/repository.py +358 -0
- bead/data/serialization.py +249 -0
- bead/data/timestamps.py +89 -0
- bead/data/validation.py +349 -0
- bead/data_collection/__init__.py +11 -0
- bead/data_collection/jatos.py +223 -0
- bead/data_collection/merger.py +154 -0
- bead/data_collection/prolific.py +198 -0
- bead/deployment/__init__.py +5 -0
- bead/deployment/distribution.py +402 -0
- bead/deployment/jatos/__init__.py +1 -0
- bead/deployment/jatos/api.py +200 -0
- bead/deployment/jatos/exporter.py +210 -0
- bead/deployment/jspsych/__init__.py +9 -0
- bead/deployment/jspsych/biome.json +44 -0
- bead/deployment/jspsych/config.py +411 -0
- bead/deployment/jspsych/generator.py +598 -0
- bead/deployment/jspsych/package.json +51 -0
- bead/deployment/jspsych/pnpm-lock.yaml +2141 -0
- bead/deployment/jspsych/randomizer.py +299 -0
- bead/deployment/jspsych/src/lib/list-distributor.test.ts +327 -0
- bead/deployment/jspsych/src/lib/list-distributor.ts +1282 -0
- bead/deployment/jspsych/src/lib/randomizer.test.ts +232 -0
- bead/deployment/jspsych/src/lib/randomizer.ts +367 -0
- bead/deployment/jspsych/src/plugins/cloze-dropdown.ts +252 -0
- bead/deployment/jspsych/src/plugins/forced-choice.ts +265 -0
- bead/deployment/jspsych/src/plugins/plugins.test.ts +141 -0
- bead/deployment/jspsych/src/plugins/rating.ts +248 -0
- bead/deployment/jspsych/src/slopit/index.ts +9 -0
- bead/deployment/jspsych/src/types/jatos.d.ts +256 -0
- bead/deployment/jspsych/src/types/jspsych.d.ts +228 -0
- bead/deployment/jspsych/templates/experiment.css +1 -0
- bead/deployment/jspsych/templates/experiment.js.template +289 -0
- bead/deployment/jspsych/templates/index.html +51 -0
- bead/deployment/jspsych/templates/randomizer.js +241 -0
- bead/deployment/jspsych/templates/randomizer.js.template +313 -0
- bead/deployment/jspsych/trials.py +723 -0
- bead/deployment/jspsych/tsconfig.json +23 -0
- bead/deployment/jspsych/tsup.config.ts +30 -0
- bead/deployment/jspsych/ui/__init__.py +1 -0
- bead/deployment/jspsych/ui/components.py +383 -0
- bead/deployment/jspsych/ui/styles.py +411 -0
- bead/dsl/__init__.py +80 -0
- bead/dsl/ast.py +168 -0
- bead/dsl/context.py +178 -0
- bead/dsl/errors.py +71 -0
- bead/dsl/evaluator.py +570 -0
- bead/dsl/grammar.lark +81 -0
- bead/dsl/parser.py +231 -0
- bead/dsl/stdlib.py +929 -0
- bead/evaluation/__init__.py +13 -0
- bead/evaluation/convergence.py +485 -0
- bead/evaluation/interannotator.py +398 -0
- bead/items/__init__.py +40 -0
- bead/items/adapters/__init__.py +70 -0
- bead/items/adapters/anthropic.py +224 -0
- bead/items/adapters/api_utils.py +167 -0
- bead/items/adapters/base.py +216 -0
- bead/items/adapters/google.py +259 -0
- bead/items/adapters/huggingface.py +1074 -0
- bead/items/adapters/openai.py +323 -0
- bead/items/adapters/registry.py +202 -0
- bead/items/adapters/sentence_transformers.py +224 -0
- bead/items/adapters/togetherai.py +309 -0
- bead/items/binary.py +515 -0
- bead/items/cache.py +558 -0
- bead/items/categorical.py +593 -0
- bead/items/cloze.py +757 -0
- bead/items/constructor.py +784 -0
- bead/items/forced_choice.py +413 -0
- bead/items/free_text.py +681 -0
- bead/items/generation.py +432 -0
- bead/items/item.py +396 -0
- bead/items/item_template.py +787 -0
- bead/items/magnitude.py +573 -0
- bead/items/multi_select.py +621 -0
- bead/items/ordinal_scale.py +569 -0
- bead/items/scoring.py +448 -0
- bead/items/validation.py +723 -0
- bead/lists/__init__.py +30 -0
- bead/lists/balancer.py +263 -0
- bead/lists/constraints.py +1067 -0
- bead/lists/experiment_list.py +286 -0
- bead/lists/list_collection.py +378 -0
- bead/lists/partitioner.py +1141 -0
- bead/lists/stratification.py +254 -0
- bead/participants/__init__.py +73 -0
- bead/participants/collection.py +699 -0
- bead/participants/merging.py +312 -0
- bead/participants/metadata_spec.py +491 -0
- bead/participants/models.py +276 -0
- bead/resources/__init__.py +29 -0
- bead/resources/adapters/__init__.py +19 -0
- bead/resources/adapters/base.py +104 -0
- bead/resources/adapters/cache.py +128 -0
- bead/resources/adapters/glazing.py +508 -0
- bead/resources/adapters/registry.py +117 -0
- bead/resources/adapters/unimorph.py +796 -0
- bead/resources/classification.py +856 -0
- bead/resources/constraint_builders.py +329 -0
- bead/resources/constraints.py +165 -0
- bead/resources/lexical_item.py +223 -0
- bead/resources/lexicon.py +744 -0
- bead/resources/loaders.py +209 -0
- bead/resources/template.py +441 -0
- bead/resources/template_collection.py +707 -0
- bead/resources/template_generation.py +349 -0
- bead/simulation/__init__.py +29 -0
- bead/simulation/annotators/__init__.py +15 -0
- bead/simulation/annotators/base.py +175 -0
- bead/simulation/annotators/distance_based.py +135 -0
- bead/simulation/annotators/lm_based.py +114 -0
- bead/simulation/annotators/oracle.py +182 -0
- bead/simulation/annotators/random.py +181 -0
- bead/simulation/dsl_extension/__init__.py +3 -0
- bead/simulation/noise_models/__init__.py +13 -0
- bead/simulation/noise_models/base.py +42 -0
- bead/simulation/noise_models/random_noise.py +82 -0
- bead/simulation/noise_models/systematic.py +132 -0
- bead/simulation/noise_models/temperature.py +86 -0
- bead/simulation/runner.py +144 -0
- bead/simulation/strategies/__init__.py +23 -0
- bead/simulation/strategies/base.py +123 -0
- bead/simulation/strategies/binary.py +103 -0
- bead/simulation/strategies/categorical.py +123 -0
- bead/simulation/strategies/cloze.py +224 -0
- bead/simulation/strategies/forced_choice.py +127 -0
- bead/simulation/strategies/free_text.py +105 -0
- bead/simulation/strategies/magnitude.py +116 -0
- bead/simulation/strategies/multi_select.py +129 -0
- bead/simulation/strategies/ordinal_scale.py +131 -0
- bead/templates/__init__.py +27 -0
- bead/templates/adapters/__init__.py +17 -0
- bead/templates/adapters/base.py +128 -0
- bead/templates/adapters/cache.py +178 -0
- bead/templates/adapters/huggingface.py +312 -0
- bead/templates/combinatorics.py +103 -0
- bead/templates/filler.py +605 -0
- bead/templates/renderers.py +177 -0
- bead/templates/resolver.py +178 -0
- bead/templates/strategies.py +1806 -0
- bead/templates/streaming.py +195 -0
- bead-0.1.0.dist-info/METADATA +212 -0
- bead-0.1.0.dist-info/RECORD +231 -0
- bead-0.1.0.dist-info/WHEEL +4 -0
- bead-0.1.0.dist-info/entry_points.txt +2 -0
- bead-0.1.0.dist-info/licenses/LICENSE +21 -0
bead/active_learning/models/lora.py

@@ -0,0 +1,365 @@
"""LoRA (Low-Rank Adaptation) implementation for transformer personalization.

Implements participant-specific low-rank updates to attention layers for
efficient parameter-efficient fine-tuning (PEFT) in the GLMM framework.

References
----------
- Hu et al. (2021): "LoRA: Low-Rank Adaptation of Large Language Models"
  https://arxiv.org/abs/2106.09685
- Microsoft LoRA: https://github.com/microsoft/LoRA
"""

from __future__ import annotations

import copy
import math

import torch
import torch.nn as nn
from torch import Tensor

__all__ = ["LoRALayer", "LoRALinear", "ParticipantLoRAAdapter"]


class LoRALayer(nn.Module):
    """Low-rank adaptation layer for attention projections.

    Implements: ΔW = (α/r) * A @ B
    where:
    - A ∈ ℝ^(in_features × rank)
    - B ∈ ℝ^(rank × out_features)
    - r is the rank (much smaller than in_features, out_features)
    - α is a scaling factor

    This additive update is applied to frozen base weights: W' = W + ΔW

    Parameters
    ----------
    in_features : int
        Input dimension.
    out_features : int
        Output dimension.
    rank : int, default=8
        LoRA rank r. Typical values: 4-16.
    alpha : float, default=16.0
        Scaling factor α. Typically 2*rank.
    dropout : float, default=0.1
        Dropout probability for LoRA path.

    Attributes
    ----------
    lora_A : nn.Parameter
        First low-rank matrix, shape (in_features, rank).
        Initialized with Kaiming uniform.
    lora_B : nn.Parameter
        Second low-rank matrix, shape (rank, out_features).
        Initialized with zeros (so ΔW = 0 initially).
    scaling : float
        Computed as α/r.

    Examples
    --------
    >>> lora = LoRALayer(768, 768, rank=8, alpha=16.0)
    >>> x = torch.randn(2, 10, 768)  # (batch, seq_len, in_features)
    >>> delta = lora(x)  # (batch, seq_len, out_features)
    >>> delta.shape
    torch.Size([2, 10, 768])
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.1,
    ) -> None:
        """Initialize LoRA layer."""
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        # Low-rank matrices (trainable)
        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))
        self.dropout = nn.Dropout(dropout)

        # Initialize A with Kaiming uniform, B with zeros
        # This ensures ΔW = 0 at initialization
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: Tensor) -> Tensor:
        """Apply LoRA: x @ (A @ B) * scaling.

        Parameters
        ----------
        x : Tensor
            Input tensor, shape (batch, seq_len, in_features).

        Returns
        -------
        Tensor
            LoRA output, shape (batch, seq_len, out_features).
        """
        # x @ A: (batch, seq_len, in_features) @ (in_features, rank)
        #      = (batch, seq_len, rank)
        # @ B:   (batch, seq_len, rank) @ (rank, out_features)
        #      = (batch, seq_len, out_features)
        result = self.dropout(x) @ self.lora_A @ self.lora_B
        return result * self.scaling


class LoRALinear(nn.Module):
    """Linear layer with LoRA adaptation.

    Wraps a frozen linear layer and adds trainable low-rank updates.
    Forward pass: output = base_layer(x) + lora(x)

    Parameters
    ----------
    base_layer : nn.Linear
        The original linear layer to adapt (will be frozen).
    rank : int, default=8
        LoRA rank r.
    alpha : float, default=16.0
        LoRA scaling factor α.
    dropout : float, default=0.1
        Dropout for LoRA path.

    Attributes
    ----------
    base_layer : nn.Linear
        Frozen base linear layer.
    lora : LoRALayer
        Low-rank adaptation layer.

    Examples
    --------
    >>> base = nn.Linear(768, 768)
    >>> lora_linear = LoRALinear(base, rank=8)
    >>> x = torch.randn(2, 10, 768)
    >>> out = lora_linear(x)
    >>> out.shape
    torch.Size([2, 10, 768])
    """

    def __init__(
        self,
        base_layer: nn.Linear,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.1,
    ) -> None:
        """Initialize LoRA linear layer."""
        super().__init__()
        self.base_layer = base_layer

        # Freeze base layer parameters
        for param in self.base_layer.parameters():
            param.requires_grad = False

        # Add LoRA adaptation
        self.lora = LoRALayer(
            base_layer.in_features,
            base_layer.out_features,
            rank=rank,
            alpha=alpha,
            dropout=dropout,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass: base output + LoRA adaptation.

        Parameters
        ----------
        x : Tensor
            Input tensor, shape (batch, seq_len, in_features).

        Returns
        -------
        Tensor
            Output with LoRA adaptation, shape (batch, seq_len, out_features).
        """
        return self.base_layer(x) + self.lora(x)


class ParticipantLoRAAdapter(nn.Module):
    """Participant-specific LoRA adapters for seq2seq decoder.

    Injects LoRA layers into specified target modules (typically query and value
    projections in attention layers). Used for random slopes mode in GLMM.

    This class wraps a decoder module and applies participant-specific low-rank
    adaptations to attention projections.

    Parameters
    ----------
    decoder : nn.Module
        The decoder module to adapt (e.g., T5 decoder, BART decoder).
    rank : int
        LoRA rank r.
    alpha : float
        LoRA scaling factor α.
    dropout : float
        Dropout for LoRA layers.
    target_modules : list[str]
        Names of modules to inject LoRA into (e.g., ["q_proj", "v_proj"]).

    Attributes
    ----------
    decoder : nn.Module
        The adapted decoder (with LoRA layers injected).
    lora_layers : dict[str, LoRALinear]
        Mapping from module name to LoRA linear layer.

    Examples
    --------
    >>> from transformers import AutoModelForSeq2SeqLM
    >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # doctest: +SKIP
    >>> decoder = model.get_decoder()  # doctest: +SKIP
    >>> adapter = ParticipantLoRAAdapter(  # doctest: +SKIP
    ...     decoder,
    ...     rank=8,
    ...     alpha=16.0, dropout=0.1,
    ...     target_modules=["q", "v"]  # T5 uses "q" and "v"
    ... )
    >>> # adapter.decoder now has LoRA layers injected
    """

    def __init__(
        self,
        decoder: nn.Module,
        rank: int,
        alpha: float,
        dropout: float,
        target_modules: list[str],
    ) -> None:
        """Initialize participant LoRA adapter."""
        super().__init__()
        self.decoder = decoder
        self.rank = rank
        self.alpha = alpha
        self.dropout = dropout
        self.target_modules = target_modules
        self.lora_layers: dict[str, LoRALinear] = {}

        # Inject LoRA into target modules
        self._inject_lora()

    def _inject_lora(self) -> None:
        """Inject LoRA into decoder attention layers.

        Searches for modules matching target_modules (e.g., "q_proj", "v_proj")
        and replaces them with LoRALinear wrappers.
        """
        # Build a mapping of full module paths to modules
        module_dict = dict(self.decoder.named_modules())

        for name, module in list(self.decoder.named_modules()):
            # Check if this module name contains any target module substring
            # e.g., "layer.0.SelfAttention.q" contains "q"
            if any(target in name for target in self.target_modules):
                if isinstance(module, nn.Linear):
                    # Get parent module and attribute name
                    path_parts = name.split(".")
                    if len(path_parts) == 1:
                        # Top-level module
                        parent = self.decoder
                        attr_name = name
                    else:
                        parent_path = ".".join(path_parts[:-1])
                        parent = module_dict[parent_path]
                        attr_name = path_parts[-1]

                    # Create LoRA linear layer
                    lora_layer = LoRALinear(
                        module,
                        rank=self.rank,
                        alpha=self.alpha,
                        dropout=self.dropout,
                    )

                    # Replace original module with LoRA version
                    setattr(parent, attr_name, lora_layer)
                    self.lora_layers[name] = lora_layer

    def forward(
        self, input_ids: Tensor, attention_mask: Tensor | None = None
    ) -> Tensor:
        """Forward pass through decoder with LoRA.

        Parameters
        ----------
        input_ids : Tensor
            Input token IDs, shape (batch_size, seq_len).
        attention_mask : Tensor | None
            Attention mask, shape (batch_size, seq_len). If None, no masking.

        Returns
        -------
        Tensor
            Decoder output tensor.
        """
        if attention_mask is not None:
            return self.decoder(input_ids, attention_mask=attention_mask)
        return self.decoder(input_ids)

    def get_lora_parameters(self) -> list[nn.Parameter]:
        """Get all LoRA parameters for optimization.

        Returns
        -------
        list[nn.Parameter]
            List of all trainable LoRA parameters (A and B matrices).
        """
        params: list[nn.Parameter] = []
        for lora_linear in self.lora_layers.values():
            params.extend(lora_linear.lora.parameters())
        return params


def create_participant_lora_adapter(
    base_decoder: nn.Module,
    rank: int,
    alpha: float,
    dropout: float,
    target_modules: list[str],
) -> ParticipantLoRAAdapter:
    """Create a participant LoRA adapter.

    Creates a deep copy of the base decoder and injects LoRA layers.

    Parameters
    ----------
    base_decoder : nn.Module
        Base decoder to copy and adapt.
    rank : int
        LoRA rank.
    alpha : float
        LoRA scaling factor.
    dropout : float
        LoRA dropout.
    target_modules : list[str]
        Target modules for LoRA injection.

    Returns
    -------
    ParticipantLoRAAdapter
        New adapter with LoRA injected into copied decoder.
    """
    # Deep copy the base decoder
    decoder_copy = copy.deepcopy(base_decoder)

    # Create adapter with LoRA
    adapter = ParticipantLoRAAdapter(
        decoder_copy,
        rank=rank,
        alpha=alpha,
        dropout=dropout,
        target_modules=target_modules,
    )

    return adapter