Add LoReFT method #577

Open · wants to merge 3 commits into main
10 changes: 10 additions & 0 deletions hparams/Steer/loreft_hparams/apply_loreft.yaml
@@ -0,0 +1,10 @@
# Model related
alg_name: loreft
reft_layers:
- 5
- 10
- 15
- 20
max_length: 512
low_rank_dimension: 4
position: "f5+l5"
15 changes: 15 additions & 0 deletions hparams/Steer/loreft_hparams/generate_loreft.yaml
@@ -0,0 +1,15 @@
# Model related
alg_name: loreft
reft_layers:
- 5
- 10
- 15
- 20
lr: 0.0009
weight_decay: 0.00
max_length: 512
low_rank_dimension: 4
n_epochs: 24
batch_size: 1
gradient_accumulation_steps: 2
position: "f5+l5"
35 changes: 35 additions & 0 deletions steer/datasets/loreft_data.py
@@ -0,0 +1,35 @@
from ..utils import build_model_input


def load_loreft_data(train_data, subset, tokenizer, system_prompt=None, use_chat_template=False):
    # LoReFT trains on positive examples only: each question paired with its matching answer.
    pos_data = [{"input": item["question"], "output": item["matching"], "label": 1} for item in train_data]
    if subset is not None:
        pos_data = pos_data[:subset]
    new_dataset = []
    for _datum in pos_data:
        if not isinstance(_datum["input"], str):
            _datum["input"] = str(_datum["input"])
        if not isinstance(_datum["output"], str):
            _datum["output"] = str(_datum["output"])
        # "input" carries the prompt plus the target output; "prompt" is the prompt alone.
        inputs = build_model_input(_datum["input"], tokenizer, system_prompt, use_chat_template, _datum["output"])
        prompts = build_model_input(_datum["input"], tokenizer, system_prompt, use_chat_template)

        new_dataset.append({"input": inputs, "prompt": prompts, "output": _datum["output"], "label": _datum["label"]})

    return new_dataset

def load_reft_eval_data(eval_data, subset, tokenizer, system_prompt=None, use_chat_template=False):
    # Evaluation only needs prompts; no target outputs are attached.
    data = [{"input": item["input"]} for item in eval_data]
    if subset is not None:
        data = data[:subset]
    new_dataset = []
    for _datum in data:
        if not isinstance(_datum["input"], str):
            _datum["input"] = str(_datum["input"])
        inputs = build_model_input(_datum["input"], tokenizer, system_prompt, use_chat_template)
        new_dataset.append({"input": inputs})
    return new_dataset
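
A usage sketch, assuming a Hugging Face tokenizer and the raw item shape consumed above ("question"/"matching" keys); the model name is a placeholder:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
train = [{"question": "How do you feel today?", "matching": "I feel great, thanks!"}]
data = load_loreft_data(train, subset=None, tokenizer=tok)
# each entry: {"input": prompt + answer, "prompt": prompt only, "output": "I feel great, thanks!", "label": 1}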
24 changes: 15 additions & 9 deletions steer/models/model_wrapper.py
@@ -120,15 +120,21 @@ def forward(self, *args, **kwargs):
self.dot_products.append((top_token, dot_product.cpu().item()))
if self.add_activations_dict:
augmented_output = output[0]
for activations in self.add_activations_dict.values():
if activations is not None:
position_ids = kwargs.get("position_ids", None)
augmented_output = add_vector_from_position(
matrix=augmented_output,
vector=activations,
position_ids=position_ids,
from_pos=self.from_position,
)
method_names = self.add_activations_dict.keys()
if "loreft" in method_names:
intervention_cls = self.add_activations_dict["loreft"].get("intervention_cls", None)
if intervention_cls is not None:
augmented_output = intervention_cls.forward(augmented_output)
else:
for activations in self.add_activations_dict.values():
if activations is not None:
position_ids = kwargs.get("position_ids", None)
augmented_output = add_vector_from_position(
matrix=augmented_output,
vector=activations,
position_ids=position_ids,
from_pos=self.from_position,
)
output = (augmented_output,) + output[1:]

if not self.save_internal_decodings:
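The "loreft" branch above only assumes this contract: the activations dict registered for a layer holds an "intervention_cls" whose forward maps hidden states to a same-shape tensor. A sketch, where Wp, Ws, bs stand for the tensors that apply_loreft (further down in this diff) loads from disk:

activations = {"intervention_cls": LoReFTIntervention(layer=5, device="cuda", weight_proj=Wp, weight_source=Ws, bias_source=bs)}
model.set_add_activations(layer=5, activations=activations, method_name="loreft")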
12 changes: 9 additions & 3 deletions steer/utils/alg_dict.py
@@ -4,7 +4,8 @@
VectorPromptHyperParams,
CAAHyperParams,
LmSteerHyperParams,
MergeVectorHyperParams
MergeVectorHyperParams,
LoReFTHyperParams
)
from ..vector_appliers import(
ApplySaeFeatureHyperParams,
@@ -14,6 +15,7 @@
ApplyLmSteerHyperParams,
ApplyPromptHyperParams,
ApplyMergeVectorHyperParams,
ApplyLoReFTHyperParams
)

from ..vector_generators import (
@@ -23,6 +25,7 @@
generate_sae_feature_vectors,
generate_sta_vectors,
generate_merge_vector,
generate_LoReFT_vectors
)
from ..vector_appliers import (
apply_lm_steer,
@@ -32,6 +35,7 @@
apply_sta,
apply_prompt,
apply_merge_vector,
apply_loreft
)
import torch
DTYPES_DICT ={
@@ -52,7 +56,8 @@
'vector_prompt': {'train': VectorPromptHyperParams, 'apply': ApplyVectorPromptHyperParams},
'caa': {'train': CAAHyperParams, 'apply': ApplyCAAHyperParams},
'prompt': {'apply': ApplyPromptHyperParams},
'merge_vector': {'train': MergeVectorHyperParams, 'apply': ApplyMergeVectorHyperParams}
'merge_vector': {'train': MergeVectorHyperParams, 'apply': ApplyMergeVectorHyperParams},
'loreft': {'train': LoReFTHyperParams, 'apply': ApplyLoReFTHyperParams}
}

METHODS_CLASS_DICT = {
@@ -62,5 +67,6 @@
'sae_feature': {'train': generate_sae_feature_vectors, 'apply': apply_sae_feature},
'sta': {'train': generate_sta_vectors, 'apply': apply_sta},
'prompt': {'apply': apply_prompt},
'merge_vector': {'train': generate_merge_vector, 'apply': apply_merge_vector}
'merge_vector': {'train': generate_merge_vector, 'apply': apply_merge_vector},
'loreft': {'train': generate_LoReFT_vectors, 'apply': apply_loreft}
}
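
At apply time the registry is consulted by algorithm name, e.g. (a sketch; hparams is an ApplyLoReFTHyperParams instance):

apply_fn = METHODS_CLASS_DICT["loreft"]["apply"]  # -> apply_loreft
steered_model = apply_fn(hparams)  # loreft builds its own model internally, so no model argument is passed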
3 changes: 2 additions & 1 deletion steer/vector_appliers/__init__.py
@@ -6,4 +6,5 @@
from .prompt import *
from .sta import *
from .sae_feature import *
from .vector_applier import *
from .loreft import *
2 changes: 2 additions & 0 deletions steer/vector_appliers/loreft/__init__.py
@@ -0,0 +1,2 @@
from .apply_loreft_intervention import *
from .apply_loreft_intervention_hparam import *
67 changes: 67 additions & 0 deletions steer/vector_appliers/loreft/apply_loreft_intervention.py
@@ -0,0 +1,67 @@
import torch
from torch import nn

from .apply_loreft_intervention_hparam import ApplyLoReFTHyperParams


class LoReFTIntervention():
    """Phi(h) = h + R^T(Wh + b - Rh)"""
    def __init__(self, layer, device, weight_proj, weight_source, bias_source):
        self.layer = layer
        self.device = device
        self.W_proj = weight_proj.to(device)      # R (projection onto the low-rank subspace)
        self.W_source = weight_source.to(device)  # W
        self.b_source = bias_source.to(device)    # b

    def forward(self, h):
        # Compute in float32 for numerical stability, then cast back to the input dtype.
        h_f = h.float()
        rotated_base = torch.bmm(h_f, self.W_proj)  # Rh: (batch_size, seq_len, low_rank_dimension)
        output = h_f + torch.bmm(
            (torch.bmm(h_f, self.W_source) + self.b_source) - rotated_base,  # Wh + b - Rh
            self.W_proj.transpose(-1, -2),  # project back to hidden size via R^T
        )
        return output.to(h.dtype)


def apply_loreft(hparams: ApplyLoReFTHyperParams, model=None):
    from ...models import get_model

    # Load the trained LoReFT parameters saved by the generator.
    dump_dir = hparams.steer_vector_load_dir
    model_name = "loreft"
    weight = torch.load(f"{dump_dir}/{model_name}_weight.pt", weights_only=True)
    bias = torch.load(f"{dump_dir}/{model_name}_bias.pt", weights_only=True)
    device = hparams.device
    model, _ = get_model(hparams)
    reft_layers = hparams.reft_layers
    method = "loreft"
    W_projs, W_sources, b_sources = {}, {}, {}
    # Saved keys are expected to look like "layer_{idx}...proj_weight" /
    # "layer_{idx}...source_weight" (weights) and "layer_{idx}..." (biases).
    weight_keys = list(weight.keys())
    bias_keys = list(bias.keys())
    for layer in reft_layers:
        for weight_key in weight_keys:
            if weight_key.startswith(f"layer_{layer}"):
                if weight_key.endswith("proj_weight"):
                    W_projs[layer] = nn.Parameter(weight[weight_key])
                elif weight_key.endswith("source_weight"):
                    W_sources[layer] = nn.Parameter(weight[weight_key])
        for bias_key in bias_keys:
            if bias_key.startswith(f"layer_{layer}"):
                b_sources[layer] = nn.Parameter(bias[bias_key])
    for layer in reft_layers:
        intervention_cls = LoReFTIntervention(
            layer=layer,
            device=device,
            weight_proj=W_projs[layer],
            weight_source=W_sources[layer],
            bias_source=b_sources[layer],
        )
        activations = {
            "intervention_cls": intervention_cls,
        }
        model.set_add_activations(layer=layer, activations=activations, method_name=method)
    return model
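
A quick shape check of the intervention math, assuming the saved weights carry a leading batch dimension (torch.bmm requires 3-D operands on both sides; the configs above use batch_size: 1):

import torch
from steer.vector_appliers.loreft import LoReFTIntervention

hidden, rank, seq = 8, 2, 4
W_proj = torch.linalg.qr(torch.randn(1, hidden, rank)).Q  # orthonormal columns, standing in for LoReFT's R
W_source = torch.randn(1, hidden, rank)                   # W
b_source = torch.zeros(rank)                              # b
iv = LoReFTIntervention(layer=5, device="cpu", weight_proj=W_proj, weight_source=W_source, bias_source=b_source)
h = torch.randn(1, seq, hidden)
assert iv.forward(h).shape == h.shape  # Phi(h) = h + R^T(Wh + b - Rh) preserves the hidden shape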

34 changes: 34 additions & 0 deletions steer/vector_appliers/loreft/apply_loreft_intervention_hparam.py
@@ -0,0 +1,34 @@
import yaml
from typing import List
from ...utils import HyperParams
from dataclasses import dataclass, field



@dataclass
class ApplyLoReFTHyperParams(HyperParams):
# Method (with predefined values)
alg_name: str = 'loreft'
steer_vector_load_dir: str = None
reft_layers: List[int] = field(default_factory=lambda: [20,21])
max_length: int = 512
batch_size: int = 1
device: str = 'cuda'
low_rank_dimension: int = 1
position: str = "l1"
temperature: float = 1.0


@classmethod
def from_hparams(cls, hparams_name_or_path: str):

if '.yaml' not in hparams_name_or_path:
hparams_name_or_path = hparams_name_or_path + '.yaml'

with open(hparams_name_or_path, "r") as stream:
config = yaml.safe_load(stream)
config = super().construct_float_from_scientific_notation(config)

assert (config and config['alg_name'] == 'loreft') or print(f'ApplyLoReFTHyperParams cannot load from {hparams_name_or_path}, '
f'alg_name is {config["alg_name"]}')
return cls(**config)
13 changes: 9 additions & 4 deletions steer/vector_appliers/vector_applier.py
@@ -18,12 +18,17 @@ def __init__(self, top_cfg: DictConfig):
self.model = None
self.tokenizer = None
self.device = None
# for loreft generation
self.reft_model = None

def _load_model(self):
if self.model is None:
from ..models import get_model
self.model, self.tokenizer = get_model(self.config)
self.device = self.model.device
def reset_loreft(self):
if self.hparams_dict.get('loreft') is not None:
self.reft_model = None

def apply_steering(self, hparams_dict, model=None, vectors=None):
from ..utils.alg_dict import METHODS_CLASS_DICT
@@ -33,6 +38,8 @@ def apply_steering(self, hparams_dict, model=None, vectors=None):
# print(f"Applying {alg_name} vectors to model ...")
if alg_name == 'prompt':
model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name] , model)
elif alg_name == "loreft":
model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name])
elif vectors is None or vectors.get(alg_name) is None:
assert hparams_dict[alg_name].steer_vector_load_dir is not None, f"Steer vector load path {hparams_dict[alg_name].steer_vector_load_dir} does not exist !"
model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name] , model)
@@ -84,16 +91,15 @@ def generate(self, datasets, save_results=True, **kwargs):
if generation_data_size is None:
generation_data_size = -1
dataset = datasets[generation_data_name][:generation_data_size] if generation_data_size > 0 else datasets[generation_data_name]
num_responses = self.config.get('num_responses', 1)
for item in tqdm(dataset, desc=f"Evaluating dataset {generation_data_name}"):
if not item.get('input'):
continue
current_preds = []
current_output = []
input_text = self._process_input_text(item['input'])
inputs = self.tokenizer(input_text, return_tensors="pt", add_special_tokens = not self.config.use_chat_template).to(self.device)

num_responses = self.config.get('num_responses', 1)

for j in range(num_responses):
if num_responses > 1:
set_seed(j)
@@ -115,7 +121,6 @@
output=output[0][inputs['input_ids'].shape[1]:]
text = self.tokenizer.decode(output, skip_special_tokens=True)
orig_preds.append([text])

formatted_results = self._format_result(dataset, orig_preds=orig_preds,preds=preds, complete_output=complete_output)
if save_results:
self.save_results(formatted_results, generation_data_name)
Expand Down
2 changes: 2 additions & 0 deletions steer/vector_generators/LoReFT/__init__.py
@@ -0,0 +1,2 @@
from .generate_LoReFT_hparams import *
from .generate_LoReFT_vectors import *
42 changes: 42 additions & 0 deletions steer/vector_generators/LoReFT/generate_LoReFT_hparams.py
@@ -0,0 +1,42 @@
import yaml
from typing import List
from ...utils import HyperParams
from dataclasses import dataclass, field
import torch


@dataclass
class LoReFTHyperParams(HyperParams):
# Method (with predefined values)
alg_name: str = 'loreft'
steer_vector_output_dir: str = None
steer_train_dataset: str=None
reft_layers: List[int] = field(default_factory=lambda: [20,21])
lr: float = 0.0001
n_epochs: int = 3
max_length: int = 256
batch_size: int = 3
gradient_accumulation_steps: int = 1
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
seed: int = 42
subset: int = None
low_rank_dimension: int = 1
position: str = "l1"
weight_decay: float = 0.01
save_vectors: bool = True
use_cache : bool = True


@classmethod
def from_hparams(cls, hparams_name_or_path: str):

if '.yaml' not in hparams_name_or_path:
hparams_name_or_path = hparams_name_or_path + '.yaml'

with open(hparams_name_or_path, "r") as stream:
config = yaml.safe_load(stream)
config = super().construct_float_from_scientific_notation(config)

assert (config and config['alg_name'] == 'loreft') or print(f'LoReFTHyperParams cannot load from {hparams_name_or_path}, '
f'alg_name is {config["alg_name"]}')
return cls(**config)
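
Loading mirrors the apply-side loader; the '.yaml' suffix is appended when missing (values per generate_loreft.yaml above):

hp = LoReFTHyperParams.from_hparams("hparams/Steer/loreft_hparams/generate_loreft")
print(hp.reft_layers, hp.lr, hp.position)  # [5, 10, 15, 20] 0.0009 f5+l5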