From 98b95fa47bd21d45ef4a75ec18ab9de4a3e8e3d8 Mon Sep 17 00:00:00 2001
From: violetevergarden1111 <2451239178@qq.com>
Date: Wed, 9 Jul 2025 21:38:11 +0800
Subject: [PATCH 1/2] Add LoReFT method
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The run results are in the EasyEdit_Example_LoReFT_translate.ipynb file.
---
 .../Steer/loreft_hparams/apply_loreft.yaml    |  10 +
 .../Steer/loreft_hparams/generate_loreft.yaml |  15 ++
 steer/datasets/loreft_data.py                 |  35 +++
 steer/utils/alg_dict.py                       |  12 +-
 steer/vector_appliers/__init__.py             |   3 +-
 steer/vector_appliers/loreft/__init__.py      |   2 +
 .../loreft/apply_loreft_intervention.py       | 117 +++++++++
 .../apply_loreft_intervention_hparam.py       |  34 +++
 steer/vector_appliers/vector_applier.py       | 194 +++++++++++---
 steer/vector_generators/LoReFT/__init__.py    |   2 +
 .../LoReFT/generate_LoReFT_hparams.py         |  42 ++++
 .../LoReFT/generate_LoReFT_vectors.py         | 186 ++++++++++++++
 steer/vector_generators/__init__.py           |   3 +-
 steer/vector_generators/vector_generators.py  |   3 +
 .../EasyEdit_Example_LoReFT_translate.ipynb   | 238 ++++++++++++++++++
 .../config_loreft_translate.yaml              |  39 +++
 16 files changed, 900 insertions(+), 35 deletions(-)
 create mode 100644 hparams/Steer/loreft_hparams/apply_loreft.yaml
 create mode 100644 hparams/Steer/loreft_hparams/generate_loreft.yaml
 create mode 100644 steer/datasets/loreft_data.py
 create mode 100644 steer/vector_appliers/loreft/__init__.py
 create mode 100644 steer/vector_appliers/loreft/apply_loreft_intervention.py
 create mode 100644 steer/vector_appliers/loreft/apply_loreft_intervention_hparam.py
 create mode 100644 steer/vector_generators/LoReFT/__init__.py
 create mode 100644 steer/vector_generators/LoReFT/generate_LoReFT_hparams.py
 create mode 100644 steer/vector_generators/LoReFT/generate_LoReFT_vectors.py
 create mode 100644 tutorial-notebooks/EasyEdit_Example_LoReFT_translate.ipynb
 create mode 100644 tutorial-notebooks/config_loreft_translate.yaml

diff --git a/hparams/Steer/loreft_hparams/apply_loreft.yaml b/hparams/Steer/loreft_hparams/apply_loreft.yaml
new file mode 100644
index 00000000..7dccb2df
--- /dev/null
+++ b/hparams/Steer/loreft_hparams/apply_loreft.yaml
@@ -0,0 +1,10 @@
+# Model related
+alg_name: loreft
+reft_layers:
+  - 5
+  - 10
+  - 15
+  - 20
+max_length: 512
+low_rank_dimension: 4
+position: "f5+l5"
\ No newline at end of file
diff --git a/hparams/Steer/loreft_hparams/generate_loreft.yaml b/hparams/Steer/loreft_hparams/generate_loreft.yaml
new file mode 100644
index 00000000..931adbee
--- /dev/null
+++ b/hparams/Steer/loreft_hparams/generate_loreft.yaml
@@ -0,0 +1,15 @@
+# Model related
+alg_name: loreft
+reft_layers:
+  - 5
+  - 10
+  - 15
+  - 20
+lr: 0.0009
+weight_decay: 0.00
+max_length: 512
+low_rank_dimension: 4
+n_epochs: 24
+batch_size: 1
+gradient_accumulation_steps: 2
+position: "f5+l5"
\ No newline at end of file
diff --git a/steer/datasets/loreft_data.py b/steer/datasets/loreft_data.py
new file mode 100644
index 00000000..dae41f6e
--- /dev/null
+++ b/steer/datasets/loreft_data.py
@@ -0,0 +1,35 @@
+from ..utils import build_model_input
+
+
+def load_loreft_data(train_data, subset, tokenizer, system_prompt = None, use_chat_template = False):
+    dataset = []
+    new_dataset = []
+    pos_data = [{"input":item["question"],"output": item["matching"], "label": 1} for item in train_data]
+    if subset is not None:
+        pos_data = pos_data[:subset]
+    dataset = pos_data # loreft just needs pos data
+    for _datum in dataset:
+        if type(_datum['input']) != str:
+            _datum['input'] = str(_datum['input'])
+        if type(_datum['output']) != str:
+            _datum['output'] = str(_datum['output'])
+        inputs = build_model_input(_datum["input"], tokenizer, system_prompt, use_chat_template, _datum["output"])
+        prompts = build_model_input(_datum["input"], tokenizer, system_prompt, use_chat_template)
+
+        new_dataset.append({"input": inputs, "prompt": prompts,"output": _datum["output"], "label": _datum["label"]})
+
+    return new_dataset
+
+def load_reft_eval_data(eval_data, subset, tokenizer, system_prompt = None, use_chat_template = False):
+    dataset = []
+    new_dataset = []
+    data = [{"input":item["input"]} for item in eval_data]
+    if subset is not None:
+        data = data[:subset]
+    dataset = data
+    for _datum in dataset:
+        if type(_datum['input']) != str:
+            _datum['input'] = str(_datum['input'])
+        inputs = build_model_input(_datum["input"], tokenizer, system_prompt, use_chat_template)
+        new_dataset.append({"input": inputs})
+    return new_dataset
\ No newline at end of file
diff --git a/steer/utils/alg_dict.py b/steer/utils/alg_dict.py
index 82e94226..c9268bc1 100644
--- a/steer/utils/alg_dict.py
+++ b/steer/utils/alg_dict.py
@@ -4,7 +4,8 @@
     VectorPromptHyperParams,
     CAAHyperParams,
     LmSteerHyperParams,
-    MergeVectorHyperParams
+    MergeVectorHyperParams,
+    LoReFTHyperParams
 )
 from ..vector_appliers import(
     ApplySaeFeatureHyperParams,
@@ -14,6 +15,7 @@
     ApplyLmSteerHyperParams,
     ApplyPromptHyperParams,
     ApplyMergeVectorHyperParams,
+    ApplyLoReFTHyperParams
 )
 
 from ..vector_generators import (
@@ -23,6 +25,7 @@
     generate_sae_feature_vectors,
     generate_sta_vectors,
     generate_merge_vector,
+    generate_LoReFT_vectors
 )
 from ..vector_appliers import (
     apply_lm_steer,
@@ -32,6 +35,7 @@
     apply_sta,
     apply_prompt,
     apply_merge_vector,
+    apply_loreft
 )
 import torch
 DTYPES_DICT ={
@@ -52,7 +56,8 @@
     'vector_prompt': {'train': VectorPromptHyperParams, 'apply': ApplyVectorPromptHyperParams},
     'caa': {'train': CAAHyperParams, 'apply': ApplyCAAHyperParams},
     'prompt': {'apply': ApplyPromptHyperParams},
-    'merge_vector': {'train': MergeVectorHyperParams, 'apply': ApplyMergeVectorHyperParams}
+    'merge_vector': {'train': MergeVectorHyperParams, 'apply': ApplyMergeVectorHyperParams},
+    'loreft':{'train':LoReFTHyperParams,'apply':ApplyLoReFTHyperParams}
 }
 
 METHODS_CLASS_DICT = {
@@ -62,5 +67,6 @@
     'sae_feature': {'train': generate_sae_feature_vectors, 'apply': apply_sae_feature},
     'sta': {'train': generate_sta_vectors, 'apply': apply_sta},
     'prompt': {'apply': apply_prompt},
-    'merge_vector': {'train': generate_merge_vector, 'apply': apply_merge_vector}
+    'merge_vector': {'train': generate_merge_vector, 'apply': apply_merge_vector},
+    'loreft':{'train':generate_LoReFT_vectors,'apply': apply_loreft}
 }
\ No newline at end of file
diff --git a/steer/vector_appliers/__init__.py b/steer/vector_appliers/__init__.py
index e60e64af..392ca590 100644
--- a/steer/vector_appliers/__init__.py
+++ b/steer/vector_appliers/__init__.py
@@ -6,4 +6,5 @@
 from .prompt import *
 from .sta import *
 from .sae_feature import *
-from .vector_applier import *
\ No newline at end of file
+from .vector_applier import *
+from .loreft import *
\ No newline at end of file
diff --git a/steer/vector_appliers/loreft/__init__.py b/steer/vector_appliers/loreft/__init__.py
new file mode 100644
index 00000000..a066a6f9
--- /dev/null
+++ b/steer/vector_appliers/loreft/__init__.py
@@ -0,0 +1,2 @@
+from .apply_loreft_intervention import *
+from .apply_loreft_intervention_hparam import *
\ No newline at end of file
diff --git a/steer/vector_appliers/loreft/apply_loreft_intervention.py b/steer/vector_appliers/loreft/apply_loreft_intervention.py
new file mode 100644
index 00000000..088b36ea
--- /dev/null
+++ b/steer/vector_appliers/loreft/apply_loreft_intervention.py
@@ -0,0 +1,117 @@
+from dataclasses import dataclass
+import pyreft
+import transformers
+import torch
+from typing import Sequence, Dict
+from .apply_loreft_intervention_hparam import ApplyLoReFTHyperParams
+from torch import nn
+from pyvene import SourcelessIntervention, TrainableIntervention, DistributedRepresentationIntervention
+class ConceptReFTIntervention(
+    SourcelessIntervention,
+    TrainableIntervention,
+    DistributedRepresentationIntervention
+):
+    """
+    Phi(h) = h + R^T(Wh + b - Rh)
+    Ref: https://arxiv.org/pdf/2404.03592
+
+    Note that this intervention is used for concept-based Direft.
+    The main difference is that weights are assumed to be trained and saved as 3D tensors.
+    """
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs, keep_last_dim=True)
+        self.W_proj = nn.Parameter(torch.zeros(
+            kwargs["n_concepts"], self.embed_dim, kwargs["low_rank_dimension"]))
+        self.W_source = nn.Parameter(torch.zeros(
+            kwargs["n_concepts"], self.embed_dim, kwargs["low_rank_dimension"]))
+        self.b_source = nn.Parameter(torch.zeros(
+            kwargs["n_concepts"], kwargs["low_rank_dimension"]))
+
+    def encode(
+        self, base, source=None, subspaces=None
+    ):
+        """High-dimensional concept space."""
+        proj_weight = self.W_proj[subspaces["input_subspaces"]] # batch_size, embed_dim, low_rank_dimension
+        rotated_base = torch.bmm(base, proj_weight) # [batch_size, seq_len, embed_dim] X [batch_size, embed_dim, low_rank_dimension]
+
+        return rotated_base # batch_size, seq_len, low_rank_dimension
+
+    def forward(
+        self, base, source=None, subspaces=None
+    ):
+        proj_weight = self.W_proj[subspaces["idx"]] # batch_size, embed_dim, low_rank_dimension
+        source_weight = self.W_source[subspaces["idx"]] # batch_size, embed_dim, low_rank_dimension
+        source_bias = self.b_source[subspaces["idx"]].unsqueeze(dim=1) # batch_size, 1, low_rank_dimension
+
+        rotated_base = torch.bmm(base.float(), proj_weight) # batch_size, seq_len, low_rank_dimension
+        output = base + torch.bmm(
+            ((torch.bmm(base, source_weight) + source_bias) - rotated_base), # batch_size, seq_len, low_rank_dimension
+            proj_weight.transpose(-1, -2)
+        )
+        return output.to(base.dtype)
+
+def apply_loreft(hparams: ApplyLoReFTHyperParams,model = None):
+    from ...models import get_model
+    # make_model
+    dump_dir = hparams.steer_vector_load_dir
+    model_name = "loreft"
+    weight = torch.load(
+        f"{dump_dir}/{model_name}_weight.pt",weights_only= True
+    )
+    bias = torch.load(
+        f"{dump_dir}/{model_name}_bias.pt",weights_only= True
+    )
+    device = hparams.device
+    weight_keys = list(weight.keys())
+    n_concepts = weight[weight_keys[0]].shape[0]
+    low_rank_dimension = weight[weight_keys[0]].shape[-1]
+    model, _ = get_model(hparams)
+    reft_layers = hparams.reft_layers
+    dtype = model.torch_dtype
+    intervention_cls = ConceptReFTIntervention
+    reft_config = pyreft.ReftConfig(representations=[{
+        "layer": l, "component": "block_output",
+        "low_rank_dimension": low_rank_dimension,
+        "intervention": intervention_cls(n_concepts=n_concepts,embed_dim=model.model.config.hidden_size,
+            low_rank_dimension=low_rank_dimension,dtype =dtype)} for l in reft_layers])
+    reft_model = pyreft.get_reft_model(model.model, reft_config)
+    reft_model.set_device(device)
+    for intervention_name, intervention in reft_model.interventions.items():
weight[f"{intervention_name}.proj_weight"] + intervention.W_source.data = weight[f"{intervention_name}.source_weight"] + intervention.b_source.data = bias[f"{intervention_name}.bias"] + for k,v in reft_model.interventions.items(): + v.eval() + return reft_model,model +@dataclass +class InterventionEvalDataCollator(object): + """Collate examples for Intervention.""" + + tokenizer: transformers.AutoTokenizer + data_collator: transformers.DefaultDataCollator + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + """ + intervention_locations will be something like [1,10,0,0,0] where all 0s are padding intervention locations. + """ + max_intervention_len = max([len(inst["intervention_locations"][0]) for inst in instances]) + max_seq_len = max([len(inst["input_ids"]) for inst in instances]) + + for inst in instances: + non_pad_len = len(inst["input_ids"]) + _intervention_location_paddings = torch.tensor( + [[-1 for _ in range(max_intervention_len - len(inst["intervention_locations"][0]))] for _ in range(inst["intervention_locations"].shape[0])]) # pointing to the first padding token + inst["intervention_locations"] = torch.cat([inst["intervention_locations"], _intervention_location_paddings], dim=-1).int() + inst["intervention_locations"] = inst["intervention_locations"] + 1 # shift by 1 to point to the first non-padding token, and all paddings will be 0. + + _input_id_paddings = torch.tensor( + [self.tokenizer.pad_token_id for _ in range(max_seq_len - non_pad_len)]) + offset = max_seq_len - non_pad_len + inst["intervention_locations"] = inst["intervention_locations"] + offset + inst["input_ids"] = torch.cat((_input_id_paddings, torch.tensor([self.tokenizer.pad_token_id]), inst["input_ids"])).int() + + inst["attention_mask"] = (inst["input_ids"] != self.tokenizer.pad_token_id).int() + + batch_inputs = self.data_collator(instances) + return batch_inputs + \ No newline at end of file diff --git a/steer/vector_appliers/loreft/apply_loreft_intervention_hparam.py b/steer/vector_appliers/loreft/apply_loreft_intervention_hparam.py new file mode 100644 index 00000000..c10ef7b0 --- /dev/null +++ b/steer/vector_appliers/loreft/apply_loreft_intervention_hparam.py @@ -0,0 +1,34 @@ +import yaml +from typing import List +from ...utils import HyperParams +from dataclasses import dataclass, field + + + +@dataclass +class ApplyLoReFTHyperParams(HyperParams): + # Method (with predefined values) + alg_name: str = 'loreft' + steer_vector_load_dir: str = None + reft_layers: List[int] = field(default_factory=lambda: [20,21]) + max_length: int = 512 + batch_size: int = 1 + device: str = 'cuda' + low_rank_dimension: int = 1 + position: str = "l1" + temperature: float = 1.0 + + + @classmethod + def from_hparams(cls, hparams_name_or_path: str): + + if '.yaml' not in hparams_name_or_path: + hparams_name_or_path = hparams_name_or_path + '.yaml' + + with open(hparams_name_or_path, "r") as stream: + config = yaml.safe_load(stream) + config = super().construct_float_from_scientific_notation(config) + + assert (config and config['alg_name'] == 'loReFT') or print(f'LoReFTHyperParams can not load from {hparams_name_or_path}, ' + f'alg_name is {config["alg_name"]} ') + return cls(**config) diff --git a/steer/vector_appliers/vector_applier.py b/steer/vector_appliers/vector_applier.py index 92cca5e6..a13fa41f 100644 --- a/steer/vector_appliers/vector_applier.py +++ b/steer/vector_appliers/vector_applier.py @@ -18,12 +18,17 @@ def __init__(self, top_cfg: DictConfig): self.model = None self.tokenizer = None 
         self.device = None
+        # for loreft generation
+        self.reft_model = None
 
     def _load_model(self):
         if self.model is None:
             from ..models import get_model
             self.model, self.tokenizer = get_model(self.config)
             self.device = self.model.device
+    def reset_loreft(self):
+        if self.hparams_dict.get('loreft') is not None:
+            self.reft_model = None
 
     def apply_steering(self, hparams_dict, model=None, vectors=None):
         from ..utils.alg_dict import METHODS_CLASS_DICT
@@ -33,6 +38,9 @@ def apply_steering(self, hparams_dict, model=None, vectors=None):
             # print(f"Applying {alg_name} vectors to model ...")
             if alg_name == 'prompt':
                 model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name] , model)
+            elif alg_name == "loreft":
+                reft_model, model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name])
+                self.reft_model = reft_model
             elif vectors is None or vectors.get(alg_name) is None:
                 assert hparams_dict[alg_name].steer_vector_load_dir is not None, f"Steer vector load path {hparams_dict[alg_name].steer_vector_load_dir} does not exist !"
                 model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name] , model)
@@ -84,38 +92,40 @@ def generate(self, datasets, save_results=True, **kwargs):
             if generation_data_size is None:
                 generation_data_size = -1
             dataset = datasets[generation_data_name][:generation_data_size] if generation_data_size > 0 else datasets[generation_data_name]
-
-            for item in tqdm(dataset, desc=f"Evaluating dataset {generation_data_name}"):
-                if not item.get('input'):
-                    continue
-                current_preds = []
-                current_output = []
-                input_text = self._process_input_text(item['input'])
-                inputs = self.tokenizer(input_text, return_tensors="pt", add_special_tokens = not self.config.use_chat_template).to(self.device)
-
-                num_responses = self.config.get('num_responses', 1)
-                for j in range(num_responses):
-                    if num_responses > 1:
-                        set_seed(j)
-                    with torch.no_grad():
-                        if self.config.get('steer_from_end_position', False):
-                            instr_pos = self.find_instruction_end_postion(inputs['input_ids'][0])
-                            print("Steering from end position:", instr_pos)
-                            self.model.set_from_positions(instr_pos)
-                        output = self.model.model.generate(**inputs, **generation_params)
-                        current_output.append(self.tokenizer.decode(output[0], skip_special_tokens=False))
-                        output=output[0][inputs['input_ids'].shape[1]:]
-                        text = self.tokenizer.decode(output, skip_special_tokens=True)
-                        current_preds.append(text)
-                preds.append(current_preds)
-                complete_output.append(current_output)
+            num_responses = self.config.get('num_responses', 1)
+            # judge method type
+            if self.hparams_dict.get('loreft') is not None:
+                preds,orig_preds,complete_output = self.loreft_generate(dataset, num_responses)
+            else:
+                for item in tqdm(dataset, desc=f"Evaluating dataset {generation_data_name}"):
+                    if not item.get('input'):
+                        continue
+                    current_preds = []
+                    current_output = []
+                    input_text = self._process_input_text(item['input'])
+                    inputs = self.tokenizer(input_text, return_tensors="pt", add_special_tokens = not self.config.use_chat_template).to(self.device)
 
-                if self.config.get('generate_orig_output', False):
-                    output = self.model.ori_generate(**inputs, **generation_params)
-                    output=output[0][inputs['input_ids'].shape[1]:]
-                    text = self.tokenizer.decode(output, skip_special_tokens=True)
-                    orig_preds.append([text])
+                    for j in range(num_responses):
+                        if num_responses > 1:
+                            set_seed(j)
+                        with torch.no_grad():
+                            if self.config.get('steer_from_end_position', False):
+                                instr_pos = self.find_instruction_end_postion(inputs['input_ids'][0])
position:", instr_pos) + self.model.set_from_positions(instr_pos) + output = self.model.model.generate(**inputs, **generation_params) + current_output.append(self.tokenizer.decode(output[0], skip_special_tokens=False)) + output=output[0][inputs['input_ids'].shape[1]:] + text = self.tokenizer.decode(output, skip_special_tokens=True) + current_preds.append(text) + preds.append(current_preds) + complete_output.append(current_output) + if self.config.get('generate_orig_output', False): + output = self.model.ori_generate(**inputs, **generation_params) + output=output[0][inputs['input_ids'].shape[1]:] + text = self.tokenizer.decode(output, skip_special_tokens=True) + orig_preds.append([text]) formatted_results = self._format_result(dataset, orig_preds=orig_preds,preds=preds, complete_output=complete_output) if save_results: self.save_results(formatted_results, generation_data_name) @@ -163,3 +173,127 @@ def find_instruction_end_postion(self, tokens): start_pos = tokens.size(0) - 1 return start_pos + + def loreft_generate(self,dataset,num_responses): + import pyreft + import transformers + from torch.utils.data import DataLoader + import datasets + from .loreft.apply_loreft_intervention import InterventionEvalDataCollator + from ..datasets.loreft_data import load_reft_eval_data + def make_eval_data_module( + tokenizer:transformers.PreTrainedTokenizer, + model,df,positions="all", + num_interventions = 1, + nonstop = True, + share_weights = True, + max_length = 512 + ): + all_base_input_ids, all_intervention_locations = [], [] + for row in df: + base_prompt = row["input"] + base_prompt_ids = tokenizer( + base_prompt, max_length=max_length, truncation=True, return_tensors="pt")["input_ids"][0] + base_prompt_length = len(base_prompt_ids) + if positions == "all_prompt": + intervention_locations = torch.tensor([[i for i in range(base_prompt_length)]]) + else: + first_n, last_n = pyreft.parse_positions(positions) + intervention_locations = pyreft.get_intervention_locations( + last_position=base_prompt_length, + first_n=first_n, + last_n=last_n, + pad_mode="first", + num_interventions=num_interventions, + share_weights=share_weights, + ) + all_base_input_ids.append(base_prompt_ids) + all_intervention_locations.append(intervention_locations) + eval_dataset = datasets.Dataset.from_dict({ + "input_ids": all_base_input_ids, + "intervention_locations": all_intervention_locations, + }) + eval_dataset.set_format( + type='torch', columns=[ + 'input_ids', 'intervention_locations',]) + data_collator_fn = transformers.DefaultDataCollator( + return_tensors="pt" + ) + data_collator = InterventionEvalDataCollator(tokenizer=tokenizer, data_collator=data_collator_fn) + return dict(train_dataset=None, eval_dataset=eval_dataset, data_collator=data_collator) + hparams = self.hparams_dict["loreft"] + eval_dataset = load_reft_eval_data(dataset,None, self.tokenizer,"You are a helpful assistant.", use_chat_template=True) + batch_size = 1 # it will be something wrong if it is not 1 + eval_output_length = hparams.max_length + temperature = hparams.temperature if hasattr(hparams, "temperature") else 1.0 + reft_layers = hparams.reft_layers + number_of_interventions = len(reft_layers) + position = "l1" if not hasattr(hparams,"position") else hparams.position + data_module = make_eval_data_module( + tokenizer=self.tokenizer, + model=self.model.model, + df=eval_dataset, + positions=position, + num_interventions=number_of_interventions, + nonstop=True, + share_weights=True, + max_length=hparams.max_length + ) + eval_dataloader = 
+        eval_dataloader = DataLoader(
+            data_module["eval_dataset"],shuffle=False,
+            batch_size=batch_size,
+            collate_fn = data_module["data_collator"],
+        )
+        self.reft_model.set_device("cuda")
+        result_generations = []
+        result_origins = []
+        result_complete = []
+        for j in range(num_responses):
+            all_generations = []
+            all_origins = []
+            all_complete = []
+            for i, batch in enumerate(eval_dataloader):
+                inputs = {k: v.to(self.device) for k, v in batch.items()}
+                if "intervention_locations" in inputs:
+                    if inputs["intervention_locations"].dim() == 3:
+                        unit_locations={"sources->base": (
+                            None,
+                            inputs["intervention_locations"].permute(1, 0, 2).tolist()
+                        )}
+                    else:
+                        # this is dummy for lora only baseline
+                        unit_locations={"sources->base": (None, 0)}
+                origin_outputs, intervention_outputs = self.reft_model.generate(
+                    {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]},
+                    unit_locations=unit_locations, intervene_on_prompt=True,
+                    subspaces=[{"idx":[0]}] * number_of_interventions,
+                    max_new_tokens=eval_output_length, do_sample=True,
+                    temperature=temperature,output_original_output = True
+                )
+                # Decode and print only the generated text without prompt tokens
+                input_lengths = [len(input_ids) for input_ids in inputs["input_ids"]]
+                generated_texts = [
+                    self.tokenizer.decode(generation[input_length:], skip_special_tokens=True)
+                    for generation, input_length in zip(intervention_outputs, input_lengths)
+                ]
+                origin_texts = [
+                    self.tokenizer.decode(generation[input_length:], skip_special_tokens=True)
+                    for generation, input_length in zip(origin_outputs, input_lengths)
+                ]
+                complete_output = [
+                    self.tokenizer.decode(generation, skip_special_tokens=False)
+                    for generation in intervention_outputs
+                ]
+                all_generations += generated_texts
+                all_origins += origin_texts
+                all_complete += complete_output
+            for idx in range(len(all_generations)):
+                if j == 0:
+                    result_generations.append([all_generations[idx]])
+                    result_origins.append([all_origins[idx]])
+                    result_complete.append([all_complete[idx]])
+                else:
+                    result_generations[idx].append(all_generations[idx])
+                    result_origins[idx].append(all_origins[idx])
+                    result_complete[idx].append(all_complete[idx])
+        return result_generations, result_origins, result_complete
\ No newline at end of file
diff --git a/steer/vector_generators/LoReFT/__init__.py b/steer/vector_generators/LoReFT/__init__.py
new file mode 100644
index 00000000..631b4556
--- /dev/null
+++ b/steer/vector_generators/LoReFT/__init__.py
@@ -0,0 +1,2 @@
+from .generate_LoReFT_hparams import *
+from .generate_LoReFT_vectors import *
\ No newline at end of file
diff --git a/steer/vector_generators/LoReFT/generate_LoReFT_hparams.py b/steer/vector_generators/LoReFT/generate_LoReFT_hparams.py
new file mode 100644
index 00000000..95ed4544
--- /dev/null
+++ b/steer/vector_generators/LoReFT/generate_LoReFT_hparams.py
@@ -0,0 +1,42 @@
+import yaml
+from typing import List
+from ...utils import HyperParams
+from dataclasses import dataclass, field
+import torch
+
+
+@dataclass
+class LoReFTHyperParams(HyperParams):
+    # Method (with predefined values)
+    alg_name: str = 'loreft'
+    steer_vector_output_dir: str = None
+    steer_train_dataset: str=None
+    reft_layers: List[int] = field(default_factory=lambda: [20,21])
+    lr: float = 0.0001
+    n_epochs: int = 3
+    max_length: int = 256
+    batch_size: int = 3
+    gradient_accumulation_steps: int = 1
+    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+    seed: int = 42
+    subset: int = None
+    low_rank_dimension: int = 1
+    position: str = "l1"
+    weight_decay: float = 0.01
+    save_vectors: bool = True
+    use_cache : bool = True
+
+
+    @classmethod
+    def from_hparams(cls, hparams_name_or_path: str):
+
+        if '.yaml' not in hparams_name_or_path:
+            hparams_name_or_path = hparams_name_or_path + '.yaml'
+
+        with open(hparams_name_or_path, "r") as stream:
+            config = yaml.safe_load(stream)
+            config = super().construct_float_from_scientific_notation(config)
+
+        assert (config and config['alg_name'] == 'loreft') or print(f'LoReFTHyperParams can not load from {hparams_name_or_path}, '
+                                                                    f'alg_name is {config["alg_name"]} ')
+        return cls(**config)
diff --git a/steer/vector_generators/LoReFT/generate_LoReFT_vectors.py b/steer/vector_generators/LoReFT/generate_LoReFT_vectors.py
new file mode 100644
index 00000000..1389c654
--- /dev/null
+++ b/steer/vector_generators/LoReFT/generate_LoReFT_vectors.py
@@ -0,0 +1,186 @@
+import sys
+
+import torch.utils.data.dataset
+
+
+
+
+
+
+sys.path.append("../../")
+sys.path.append("../")
+sys.path.append("./")
+import os
+import torch
+import argparse
+from tqdm import tqdm
+from dotenv import load_dotenv
+import random
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+from transformers import DataCollatorForSeq2Seq, TrainingArguments, set_seed
+import pyreft
+import matplotlib.pyplot as plt
+from copy import deepcopy
+
+from .generate_LoReFT_hparams import LoReFTHyperParams
+IGNORE_INDEX = -100
+
+load_dotenv()
+
+HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
+
+def get_lr(optimizer):
+    for param_group in optimizer.param_groups:
+        return param_group["lr"]
+def save_interventions(reft_model, dump_dir,model_name):
+    proj_weights = []
+    source_weights = []
+    source_biases = []
+    intervention_names = []
+    for intervention_name, intervention in reft_model.interventions.items():
+        intervention_names.append(intervention_name)
+        intervention_state_dict = intervention.state_dict()
+        proj_weight = intervention_state_dict["rotate_layer"] # [embed_dim, low_rank_dimension]
+        source_weight = intervention_state_dict["weight"].T # [embed_dim, low_rank_dimension]
+        source_bias = intervention_state_dict["bias"] # [low_rank_dimension]
+        proj_weights.append(proj_weight)
+        source_weights.append(source_weight)
+        source_biases.append(source_bias)
+    weight_file = dump_dir / f"{model_name}_weight.pt"
+    if weight_file.exists():
+        existing_weight = torch.load(weight_file)
+        for i, intervention_name in enumerate(intervention_names):
+            existing_weight[f"{intervention_name}.proj_weight"] = torch.cat(
+                [existing_weight[f"{intervention_name}.proj_weight"], proj_weights[i].cpu().unsqueeze(dim=0)], dim=0)
+            existing_weight[f"{intervention_name}.source_weight"] = torch.cat(
+                [existing_weight[f"{intervention_name}.source_weight"], source_weights[i].cpu().unsqueeze(dim=0)], dim=0)
+    else:
+        existing_weight = {}
+        for i, intervention_name in enumerate(intervention_names):
+            existing_weight[f"{intervention_name}.proj_weight"] = proj_weights[i].cpu().unsqueeze(dim=0)
+            existing_weight[f"{intervention_name}.source_weight"] = source_weights[i].cpu().unsqueeze(dim=0)
+    torch.save(existing_weight, weight_file)
+    bias_file = dump_dir / f"{model_name}_bias.pt"
+    if bias_file.exists():
+        existing_bias = torch.load(bias_file)
+        for i, intervention_name in enumerate(intervention_names):
+            existing_bias[f"{intervention_name}.bias"] = torch.cat(
+                [existing_bias[f"{intervention_name}.bias"], source_biases[i].cpu().unsqueeze(dim=0)], dim=0)
+    else:
+        existing_bias = {}
+        for i, intervention_name in enumerate(intervention_names):
existing_bias[f"{intervention_name}.bias"] = source_biases[i].cpu().unsqueeze(dim=0) + torch.save(existing_bias, bias_file) + +def generate_LoReFT_vectors(hparams:LoReFTHyperParams, dataset, model = None): + from ...models.get_model import get_model + from ...datasets.loreft_data import load_loreft_data + from transformers import get_scheduler + import transformers + from pathlib import Path + del_model = True + if model is None: + model, tokenizer = get_model(hparams) + else: + del_model = False + model, tokenizer = model, model.tokenizer + model.hparams = hparams + system_prompt = "" if hparams.system_prompt is None else hparams.system_prompt + use_chat_template = True if hparams.use_chat_template is None else hparams.use_chat_template + subset = None if hparams.subset is None else hparams.subset + tokenizer.model_max_length = hparams.max_length + model.model.eval() + device = hparams.device + model.model.to(device) + train_data = load_loreft_data(dataset,subset,tokenizer,system_prompt,use_chat_template) + torch_dtype = model.torch_dtype + batch_size = hparams.batch_size + low_rank_dimension = hparams.low_rank_dimension + if hparams.steer_vector_output_dir is None: + assert "Need Steer Vector Output Dir" + position = hparams.position + reft_layers = hparams.reft_layers + num_interventions = len(reft_layers) + data_module = pyreft.make_multiple_position_supervised_data_module( + tokenizer=tokenizer,model=model.model, + inputs = [item["prompt"] for item in train_data], + outputs= [item["output"] for item in train_data], + positions = position, + num_interventions= num_interventions, + nonstop=True, + share_weights=True + ) + + train_dataloader = DataLoader( + data_module["train_dataset"],shuffle=True, + batch_size=batch_size,collate_fn=data_module["data_collator"] + ) + # make_model + intervention_cls = pyreft.LoreftIntervention + + reft_config = pyreft.ReftConfig(representations=[{ + "layer": l, "component": "block_output", + "low_rank_dimension": low_rank_dimension, + "intervention": intervention_cls(embed_dim=model.model.config.hidden_size, + low_rank_dimension=low_rank_dimension,dtype = torch_dtype)} for l in reft_layers]) + reft_model = pyreft.get_reft_model(model.model, reft_config) + lr = hparams.lr + weight_decay = hparams.weight_decay + optimizer = torch.optim.AdamW( + reft_model.parameters(), lr=lr,weight_decay=weight_decay, + betas=(0.9,0.999),eps = 1e-8 + ) + n_epochs= hparams.n_epochs + gradient_accumulation_steps = 1 if hparams.gradient_accumulation_steps is None else hparams.gradient_accumulation_steps + num_training_steps = n_epochs * max(1, len(train_dataloader))//gradient_accumulation_steps + lr_scheduler = get_scheduler( + "linear", optimizer=optimizer, + num_warmup_steps=0,num_training_steps=num_training_steps + ) + progress_bar, curr_step = tqdm(range(num_training_steps), leave=True), 0 + + reft_model.print_trainable_parameters() + losses = [] + for epoch in range(n_epochs): + for step, batch in enumerate(train_dataloader): + inputs = {k : v.to(device) for k,v in batch.items()} + unit_locations={"sources->base": ( + None, + inputs["intervention_locations"].permute(1, 0, 2).tolist() + )} + _, cf_outputs = reft_model.forward( + base = { + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + },unit_locations=unit_locations, + labels=inputs["labels"], + use_cache=False + ) + loss = cf_outputs.loss.mean() + loss.backward() + loss = loss.mean() + if gradient_accumulation_steps > 1: + loss = loss / gradient_accumulation_steps + if(step + 1) % 
+            if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader):
+                torch.nn.utils.clip_grad_norm_(reft_model.parameters(), 1.0)
+                curr_step += 1
+                curr_lr = get_lr(optimizer)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+                progress_bar.update(1)
+                losses.append(loss.item())
+                progress_bar.set_description(
+                    "lr %.6f || loss %.6f " % (curr_lr, loss))
+    progress_bar.close()
+    save_directory = os.path.join(hparams.steer_vector_output_dir, hparams.alg_name + '_vector')
+    if not os.path.exists(save_directory):
+        os.makedirs(save_directory)
+    save_interventions(reft_model,Path(save_directory),"loreft")
+    if del_model:
+        del model
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
diff --git a/steer/vector_generators/__init__.py b/steer/vector_generators/__init__.py
index 4f3b8d79..10706191 100644
--- a/steer/vector_generators/__init__.py
+++ b/steer/vector_generators/__init__.py
@@ -4,4 +4,5 @@
 from .sta import *
 from .sae_feature import *
 from .merge import *
-from .vector_generators import *
\ No newline at end of file
+from .vector_generators import *
+from .LoReFT import *
\ No newline at end of file
diff --git a/steer/vector_generators/vector_generators.py b/steer/vector_generators/vector_generators.py
index 28ffb7ac..a1e5bca3 100644
--- a/steer/vector_generators/vector_generators.py
+++ b/steer/vector_generators/vector_generators.py
@@ -39,6 +39,9 @@ def generate_vectors(self, datasets = None,):
             print(f"Generating {key} vectors ...")
             if alg_name in ['lm_steer', 'caa', 'vector_prompt', 'sta']:
                 vectors = METHODS_CLASS_DICT[alg_name]['train'](hparams, datasets[dataset_name])
+            elif alg_name == "loreft":
+                vectors = None
+                METHODS_CLASS_DICT[alg_name]['train'](hparams, datasets[dataset_name])
             else:
                 vectors = METHODS_CLASS_DICT[alg_name]['train'](hparams)
             generated_vectors[dataset_name][key] = vectors
diff --git a/tutorial-notebooks/EasyEdit_Example_LoReFT_translate.ipynb b/tutorial-notebooks/EasyEdit_Example_LoReFT_translate.ipynb
new file mode 100644
index 00000000..1b5e467e
--- /dev/null
+++ b/tutorial-notebooks/EasyEdit_Example_LoReFT_translate.ipynb
@@ -0,0 +1,238 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'model_name_or_path': 'C:/git/axbench-main/google/gemma-2-2b-it/', 'torch_dtype': 'bfloat16', 'device': 'cuda', 'use_chat_template': True, 'system_prompt': '', 'steer_train_hparam_paths': ['../hparams/Steer/loreft_hparams/generate_loreft.yaml'], 'steer_train_dataset': ['translate'], 'steer_vector_output_dirs': 'vectors/gemma-2-2b-it/', 'apply_steer_hparam_paths': ['../hparams/Steer/loreft_hparams/apply_loreft.yaml'], 'steer_vector_load_dir': ['vectors/gemma-2-2b-it/translate/loreft_vector'], 'generation_data': ['nontoxic'], 'generation_data_size': 5, 'generation_output_dir': 'vectors/gemma-2-2b-it/translate_loreft_results/', 'num_responses': 1, 'steer_from_end_position': False, 'generate_orig_output': True, 'generation_params': {'max_new_tokens': 100, 'temperature': 0.9}}"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import sys\n",
+    "sys.path.append('..')\n",
+    "from omegaconf import OmegaConf, DictConfig\n",
+    "from steer.vector_generators.vector_generators import BaseVectorGenerator\n",
+    "from steer.datasets import prepare_train_dataset\n",
+    "from steer.vector_appliers.vector_applier import BaseVectorApplier\n",
+    "from steer.datasets import prepare_generation_datasets\n",
+    "\n",
+    "top_cfg = OmegaConf.load(\"./config_loreft_translate.yaml\")\n",
+    "# top_cfg.model_name_or_path = model_path\n",
+    "# top_cfg.device = \"cuda:0\"\n",
+    "top_cfg"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Generate Steering Vector"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading loreft hparams from ../hparams/Steer/loreft_hparams/generate_loreft.yaml ...\n",
+      "LOREFT_0 Generator Hyperparameters:\n",
+      "LoReFTHyperParams(use_chat_template=True, system_prompt='', torch_dtype='bfloat16', seed=42, model_name_or_path='C:/git/axbench-main/google/gemma-2-2b-it/', device='cuda', use_cache=True, generate_orig_output=True, alg_name='loreft', steer_vector_output_dir='vectors/gemma-2-2b-it/', steer_train_dataset=['translate'], reft_layers=[5, 10, 15, 20], lr=0.0009, n_epochs=24, max_length=512, batch_size=1, gradient_accumulation_steps=2, subset=None, low_rank_dimension=4, position='f5+l5', weight_decay=0.0, save_vectors=True)\n",
+      "Generating loreft_0 vectors ...\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "2123b85af7da4fdebe47b3e7cfb21f1a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Loading checkpoint shards:   0%|          | 0/2 [00:00

From: violetevergarden1111 <2451239178@qq.com>
Date: Thu, 10 Jul 2025 22:01:54 +0800
Subject: [PATCH 2/2] Remove pyreft package from apply

---
 steer/models/model_wrapper.py                 |  24 ++-
 .../loreft/apply_loreft_intervention.py       | 138 +++++--------
 steer/vector_appliers/vector_applier.py       | 183 +++---------------
 3 files changed, 86 insertions(+), 259 deletions(-)

diff --git a/steer/models/model_wrapper.py b/steer/models/model_wrapper.py
index 74bc7ca6..bd02d8be 100644
--- a/steer/models/model_wrapper.py
+++ b/steer/models/model_wrapper.py
@@ -120,15 +120,21 @@ def forward(self, *args, **kwargs):
                 self.dot_products.append((top_token, dot_product.cpu().item()))
         if self.add_activations_dict:
             augmented_output = output[0]
-            for activations in self.add_activations_dict.values():
-                if activations is not None:
-                    position_ids = kwargs.get("position_ids", None)
-                    augmented_output = add_vector_from_position(
-                        matrix=augmented_output,
-                        vector=activations,
-                        position_ids=position_ids,
-                        from_pos=self.from_position,
-                    )
+            method_names = self.add_activations_dict.keys()
+            if "loreft" in method_names:
+                intervention_cls = self.add_activations_dict["loreft"].get("intervention_cls", None)
+                if intervention_cls is not None:
+                    augmented_output = intervention_cls.forward(augmented_output)
+            else:
+                for activations in self.add_activations_dict.values():
+                    if activations is not None:
+                        position_ids = kwargs.get("position_ids", None)
+                        augmented_output = add_vector_from_position(
+                            matrix=augmented_output,
+                            vector=activations,
+                            position_ids=position_ids,
+                            from_pos=self.from_position,
+                        )
 
             output = (augmented_output,) + output[1:]
 
         if not self.save_internal_decodings:
diff --git a/steer/vector_appliers/loreft/apply_loreft_intervention.py b/steer/vector_appliers/loreft/apply_loreft_intervention.py
index 088b36ea..d208b22c 100644
--- a/steer/vector_appliers/loreft/apply_loreft_intervention.py
+++ b/steer/vector_appliers/loreft/apply_loreft_intervention.py
@@ -1,54 +1,24 @@
 from dataclasses import dataclass
-import pyreft
-import transformers
 import torch
-from typing import Sequence, Dict
 from .apply_loreft_intervention_hparam import ApplyLoReFTHyperParams
 from torch import nn
-from pyvene import SourcelessIntervention, TrainableIntervention, DistributedRepresentationIntervention
-class ConceptReFTIntervention(
-    SourcelessIntervention,
-    TrainableIntervention,
-    DistributedRepresentationIntervention
-):
-    """
-    Phi(h) = h + R^T(Wh + b - Rh)
-    Ref: https://arxiv.org/pdf/2404.03592
+class LoReFTIntervention():
+    """Phi(h) = h + R^T(Wh + b - Rh)"""
+    def __init__(self, layer, device,weight_proj,weight_source,bias_source):
+        self.layer = layer
+        self.device = device
+        self.W_proj = weight_proj.to(device)
+        self.W_source = weight_source.to(device)
+        self.b_source = bias_source.to(device)
 
-    Note that this intervention is used for concept-based Direft.
-    The main difference is that weights are assumed to be trained and saved as 3D tensors.
- """ - def __init__(self, **kwargs): - super().__init__(**kwargs, keep_last_dim=True) - self.W_proj = nn.Parameter(torch.zeros( - kwargs["n_concepts"], self.embed_dim, kwargs["low_rank_dimension"])) - self.W_source = nn.Parameter(torch.zeros( - kwargs["n_concepts"], self.embed_dim, kwargs["low_rank_dimension"])) - self.b_source = nn.Parameter(torch.zeros( - kwargs["n_concepts"], kwargs["low_rank_dimension"])) - - def encode( - self, base, source=None, subspaces=None - ): - """High-dimensional concept space.""" - proj_weight = self.W_proj[subspaces["input_subspaces"]] # batch_size, embed_dim, low_rank_dimension - rotated_base = torch.bmm(base, proj_weight) # [batch_size, seq_len, embed_dim] X [batch_size, embed_dim, low_rank_dimension] - - return rotated_base # batch_size, seq_len, low_rank_dimension - - def forward( - self, base, source=None, subspaces=None - ): - proj_weight = self.W_proj[subspaces["idx"]] # batch_size, embed_dim, low_rank_dimension - source_weight = self.W_source[subspaces["idx"]] # batch_size, embed_dim, low_rank_dimension - source_bias = self.b_source[subspaces["idx"]].unsqueeze(dim=1) # batch_size, 1, low_rank_dimension - - rotated_base = torch.bmm(base.float(), proj_weight) # batch_size, seq_len, low_rank_dimension - output = base + torch.bmm( - ((torch.bmm(base, source_weight) + source_bias) - rotated_base), # batch_size, seq_len, low_rank_dimension - proj_weight.transpose(-1, -2) + def forward(self, h): + rotated_base = torch.bmm(h.float(), self.W_proj) + output = h + torch.bmm( + ((torch.bmm(h, self.W_source) + self.b_source) - rotated_base), # batch_size, seq_len, low_rank_dimension + self.W_proj.transpose(-1, -2) ) - return output.to(base.dtype) + return output.to(h.dtype) + def apply_loreft(hparams: ApplyLoReFTHyperParams,model = None): from ...models import get_model @@ -63,55 +33,35 @@ def apply_loreft(hparams: ApplyLoReFTHyperParams,model = None): ) device = hparams.device weight_keys = list(weight.keys()) - n_concepts = weight[weight_keys[0]].shape[0] - low_rank_dimension = weight[weight_keys[0]].shape[-1] model, _ = get_model(hparams) reft_layers = hparams.reft_layers - dtype = model.torch_dtype - intervention_cls = ConceptReFTIntervention - reft_config = pyreft.ReftConfig(representations=[{ - "layer": l, "component": "block_output", - "low_rank_dimension": low_rank_dimension, - "intervention": intervention_cls(n_concepts=n_concepts,embed_dim=model.model.config.hidden_size, - low_rank_dimension=low_rank_dimension,dtype =dtype)} for l in reft_layers]) - reft_model = pyreft.get_reft_model(model.model, reft_config) - reft_model.set_device(device) - for intervention_name, intervention in reft_model.interventions.items(): - intervention.W_proj.data = weight[f"{intervention_name}.proj_weight"] - intervention.W_source.data = weight[f"{intervention_name}.source_weight"] - intervention.b_source.data = bias[f"{intervention_name}.bias"] - for k,v in reft_model.interventions.items(): - v.eval() - return reft_model,model -@dataclass -class InterventionEvalDataCollator(object): - """Collate examples for Intervention.""" - - tokenizer: transformers.AutoTokenizer - data_collator: transformers.DefaultDataCollator - - def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: - """ - intervention_locations will be something like [1,10,0,0,0] where all 0s are padding intervention locations. 
- """ - max_intervention_len = max([len(inst["intervention_locations"][0]) for inst in instances]) - max_seq_len = max([len(inst["input_ids"]) for inst in instances]) - - for inst in instances: - non_pad_len = len(inst["input_ids"]) - _intervention_location_paddings = torch.tensor( - [[-1 for _ in range(max_intervention_len - len(inst["intervention_locations"][0]))] for _ in range(inst["intervention_locations"].shape[0])]) # pointing to the first padding token - inst["intervention_locations"] = torch.cat([inst["intervention_locations"], _intervention_location_paddings], dim=-1).int() - inst["intervention_locations"] = inst["intervention_locations"] + 1 # shift by 1 to point to the first non-padding token, and all paddings will be 0. - - _input_id_paddings = torch.tensor( - [self.tokenizer.pad_token_id for _ in range(max_seq_len - non_pad_len)]) - offset = max_seq_len - non_pad_len - inst["intervention_locations"] = inst["intervention_locations"] + offset - inst["input_ids"] = torch.cat((_input_id_paddings, torch.tensor([self.tokenizer.pad_token_id]), inst["input_ids"])).int() - - inst["attention_mask"] = (inst["input_ids"] != self.tokenizer.pad_token_id).int() - - batch_inputs = self.data_collator(instances) - return batch_inputs + method = "loreft" + W_projs = {} + W_sources = {} + b_sources = {} + weight_keys = list(weight.keys()) + bias_keys = list(bias.keys()) + for layer in reft_layers: + for weight_key in weight_keys: + if weight_key.startswith(f"layer_{layer}"): + if(weight_key.endswith("proj_weight")): + W_projs[layer] = nn.Parameter(weight[weight_key]) + elif(weight_key.endswith("source_weight")): + W_sources[layer] = nn.Parameter(weight[weight_key]) + for bias_key in bias_keys: + if bias_key.startswith(f"layer_{layer}"): + b_sources[layer] = nn.Parameter(bias[bias_key]) + for layer in reft_layers: + intervention_cls = LoReFTIntervention( + layer=layer, + device=device, + weight_proj=W_projs[layer], + weight_source=W_sources[layer], + bias_source=b_sources[layer] + ) + activations = { + "intervention_cls": intervention_cls, + } + model.set_add_activations(layer=layer, activations=activations,method_name=method) + return model \ No newline at end of file diff --git a/steer/vector_appliers/vector_applier.py b/steer/vector_appliers/vector_applier.py index a13fa41f..144bd0a7 100644 --- a/steer/vector_appliers/vector_applier.py +++ b/steer/vector_appliers/vector_applier.py @@ -39,8 +39,7 @@ def apply_steering(self, hparams_dict, model=None, vectors=None): if alg_name == 'prompt': model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name] , model) elif alg_name == "loreft": - reft_model, model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name]) - self.reft_model = reft_model + model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name]) elif vectors is None or vectors.get(alg_name) is None: assert hparams_dict[alg_name].steer_vector_load_dir is not None, f"Steer vector load path {hparams_dict[alg_name].steer_vector_load_dir} does not exist !" 
             model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name] , model)
@@ -93,39 +92,35 @@ def generate(self, datasets, save_results=True, **kwargs):
                 generation_data_size = -1
             dataset = datasets[generation_data_name][:generation_data_size] if generation_data_size > 0 else datasets[generation_data_name]
             num_responses = self.config.get('num_responses', 1)
-            # judge method type
-            if self.hparams_dict.get('loreft') is not None:
-                preds,orig_preds,complete_output = self.loreft_generate(dataset, num_responses)
-            else:
-                for item in tqdm(dataset, desc=f"Evaluating dataset {generation_data_name}"):
-                    if not item.get('input'):
-                        continue
-                    current_preds = []
-                    current_output = []
-                    input_text = self._process_input_text(item['input'])
-                    inputs = self.tokenizer(input_text, return_tensors="pt", add_special_tokens = not self.config.use_chat_template).to(self.device)
+            for item in tqdm(dataset, desc=f"Evaluating dataset {generation_data_name}"):
+                if not item.get('input'):
+                    continue
+                current_preds = []
+                current_output = []
+                input_text = self._process_input_text(item['input'])
+                inputs = self.tokenizer(input_text, return_tensors="pt", add_special_tokens = not self.config.use_chat_template).to(self.device)
 
-                    for j in range(num_responses):
-                        if num_responses > 1:
-                            set_seed(j)
-                        with torch.no_grad():
-                            if self.config.get('steer_from_end_position', False):
-                                instr_pos = self.find_instruction_end_postion(inputs['input_ids'][0])
-                                print("Steering from end position:", instr_pos)
-                                self.model.set_from_positions(instr_pos)
-                            output = self.model.model.generate(**inputs, **generation_params)
-                            current_output.append(self.tokenizer.decode(output[0], skip_special_tokens=False))
-                            output=output[0][inputs['input_ids'].shape[1]:]
-                            text = self.tokenizer.decode(output, skip_special_tokens=True)
-                            current_preds.append(text)
-                    preds.append(current_preds)
-                    complete_output.append(current_output)
+                for j in range(num_responses):
+                    if num_responses > 1:
+                        set_seed(j)
+                    with torch.no_grad():
+                        if self.config.get('steer_from_end_position', False):
+                            instr_pos = self.find_instruction_end_postion(inputs['input_ids'][0])
+                            print("Steering from end position:", instr_pos)
+                            self.model.set_from_positions(instr_pos)
+                        output = self.model.model.generate(**inputs, **generation_params)
+                        current_output.append(self.tokenizer.decode(output[0], skip_special_tokens=False))
+                        output=output[0][inputs['input_ids'].shape[1]:]
+                        text = self.tokenizer.decode(output, skip_special_tokens=True)
+                        current_preds.append(text)
+                preds.append(current_preds)
+                complete_output.append(current_output)
+
+                if self.config.get('generate_orig_output', False):
+                    output = self.model.ori_generate(**inputs, **generation_params)
+                    output=output[0][inputs['input_ids'].shape[1]:]
+                    text = self.tokenizer.decode(output, skip_special_tokens=True)
+                    orig_preds.append([text])
             formatted_results = self._format_result(dataset, orig_preds=orig_preds,preds=preds, complete_output=complete_output)
             if save_results:
                 self.save_results(formatted_results, generation_data_name)
@@ -173,127 +168,3 @@ def find_instruction_end_postion(self, tokens):
             start_pos = tokens.size(0) - 1
         return start_pos
-
-
-    def loreft_generate(self,dataset,num_responses):
-        import pyreft
-        import transformers
-        from torch.utils.data import DataLoader
-        import datasets
-        from .loreft.apply_loreft_intervention import InterventionEvalDataCollator
-        from ..datasets.loreft_data import load_reft_eval_data
-        def make_eval_data_module(
-            tokenizer:transformers.PreTrainedTokenizer,
-            model,df,positions="all",
-            num_interventions = 1,
-            nonstop = True,
-            share_weights = True,
-            max_length = 512
-        ):
-            all_base_input_ids, all_intervention_locations = [], []
-            for row in df:
-                base_prompt = row["input"]
-                base_prompt_ids = tokenizer(
-                    base_prompt, max_length=max_length, truncation=True, return_tensors="pt")["input_ids"][0]
-                base_prompt_length = len(base_prompt_ids)
-                if positions == "all_prompt":
-                    intervention_locations = torch.tensor([[i for i in range(base_prompt_length)]])
-                else:
-                    first_n, last_n = pyreft.parse_positions(positions)
-                    intervention_locations = pyreft.get_intervention_locations(
-                        last_position=base_prompt_length,
-                        first_n=first_n,
-                        last_n=last_n,
-                        pad_mode="first",
-                        num_interventions=num_interventions,
-                        share_weights=share_weights,
-                    )
-                all_base_input_ids.append(base_prompt_ids)
-                all_intervention_locations.append(intervention_locations)
-            eval_dataset = datasets.Dataset.from_dict({
-                "input_ids": all_base_input_ids,
-                "intervention_locations": all_intervention_locations,
-            })
-            eval_dataset.set_format(
-                type='torch', columns=[
-                    'input_ids', 'intervention_locations',])
-            data_collator_fn = transformers.DefaultDataCollator(
-                return_tensors="pt"
-            )
-            data_collator = InterventionEvalDataCollator(tokenizer=tokenizer, data_collator=data_collator_fn)
-            return dict(train_dataset=None, eval_dataset=eval_dataset, data_collator=data_collator)
-        hparams = self.hparams_dict["loreft"]
-        eval_dataset = load_reft_eval_data(dataset,None, self.tokenizer,"You are a helpful assistant.", use_chat_template=True)
-        batch_size = 1 # generation currently breaks when batch_size is not 1
-        eval_output_length = hparams.max_length
-        temperature = hparams.temperature if hasattr(hparams, "temperature") else 1.0
-        reft_layers = hparams.reft_layers
-        number_of_interventions = len(reft_layers)
-        position = "l1" if not hasattr(hparams,"position") else hparams.position
-        data_module = make_eval_data_module(
-            tokenizer=self.tokenizer,
-            model=self.model.model,
-            df=eval_dataset,
-            positions=position,
-            num_interventions=number_of_interventions,
-            nonstop=True,
-            share_weights=True,
-            max_length=hparams.max_length
-        )
-        eval_dataloader = DataLoader(
-            data_module["eval_dataset"],shuffle=False,
-            batch_size=batch_size,
-            collate_fn = data_module["data_collator"],
-        )
-        self.reft_model.set_device("cuda")
-        result_generations = []
-        result_origins = []
-        result_complete = []
-        for j in range(num_responses):
-            all_generations = []
-            all_origins = []
-            all_complete = []
-            for i, batch in enumerate(eval_dataloader):
-                inputs = {k: v.to(self.device) for k, v in batch.items()}
-                if "intervention_locations" in inputs:
-                    if inputs["intervention_locations"].dim() == 3:
-                        unit_locations={"sources->base": (
-                            None,
-                            inputs["intervention_locations"].permute(1, 0, 2).tolist()
-                        )}
-                    else:
-                        # this is dummy for lora only baseline
-                        unit_locations={"sources->base": (None, 0)}
-                origin_outputs, intervention_outputs = self.reft_model.generate(
-                    {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]},
-                    unit_locations=unit_locations, intervene_on_prompt=True,
-                    subspaces=[{"idx":[0]}] * number_of_interventions,
-                    max_new_tokens=eval_output_length, do_sample=True,
-                    temperature=temperature,output_original_output = True
-                )
-                # Decode and print only the generated text without prompt tokens
inputs["input_ids"]] - generated_texts = [ - self.tokenizer.decode(generation[input_length:], skip_special_tokens=True) - for generation, input_length in zip(intervention_outputs, input_lengths) - ] - origin_texts = [ - self.tokenizer.decode(generation[input_length:], skip_special_tokens=True) - for generation, input_length in zip(origin_outputs, input_lengths) - ] - complete_output = [ - self.tokenizer.decode(generation, skip_special_tokens=False) - for generation in intervention_outputs - ] - all_generations += generated_texts - all_origins += origin_texts - all_complete += complete_output - for idx in range(len(all_generations)): - if j == 0: - result_generations.append([all_generations[idx]]) - result_origins.append([all_origins[idx]]) - result_complete.append([all_complete[idx]]) - else: - result_generations[idx].append(all_generations[idx]) - result_origins[idx].append(all_origins[idx]) - result_complete[idx].append(all_complete[idx]) - return result_generations, result_origins, result_complete \ No newline at end of file