diff --git a/hparams/Steer/loreft_hparams/apply_loreft.yaml b/hparams/Steer/loreft_hparams/apply_loreft.yaml new file mode 100644 index 00000000..7dccb2df --- /dev/null +++ b/hparams/Steer/loreft_hparams/apply_loreft.yaml @@ -0,0 +1,10 @@ +# Model related +alg_name: loreft +reft_layers: + - 5 + - 10 + - 15 + - 20 +max_length: 512 +low_rank_dimension: 4 +position: "f5+l5" \ No newline at end of file diff --git a/hparams/Steer/loreft_hparams/generate_loreft.yaml b/hparams/Steer/loreft_hparams/generate_loreft.yaml new file mode 100644 index 00000000..931adbee --- /dev/null +++ b/hparams/Steer/loreft_hparams/generate_loreft.yaml @@ -0,0 +1,15 @@ +# Model related +alg_name: loreft +reft_layers: + - 5 + - 10 + - 15 + - 20 +lr: 0.0009 +weight_decay: 0.00 +max_length: 512 +low_rank_dimension: 4 +n_epochs: 24 +batch_size: 1 +gradient_accumulation_steps: 2 +position: "f5+l5" \ No newline at end of file diff --git a/steer/datasets/loreft_data.py b/steer/datasets/loreft_data.py new file mode 100644 index 00000000..dae41f6e --- /dev/null +++ b/steer/datasets/loreft_data.py @@ -0,0 +1,35 @@ +from ..utils import build_model_input + + +def load_loreft_data(train_data, subset, tokenizer, system_prompt = None, use_chat_template = False): + dataset = [] + new_dataset = [] + pos_data = [{"input":item["question"],"output": item["matching"], "label": 1} for item in train_data] + if subset is not None: + pos_data = pos_data[:subset] + dataset = pos_data # loreft just needs pos data + for _datum in dataset: + if type(_datum['input']) != str: + _datum['input'] = str(_datum['input']) + if type(_datum['output']) != str: + _datum['output'] = str(_datum['output']) + inputs = build_model_input(_datum["input"], tokenizer, system_prompt, use_chat_template, _datum["output"]) + prompts = build_model_input(_datum["input"], tokenizer, system_prompt, use_chat_template) + + new_dataset.append({"input": inputs, "prompt": prompts,"output": _datum["output"], "label": _datum["label"]}) + + return new_dataset + +def load_reft_eval_data(eval_data, subset, tokenizer, system_prompt = None, use_chat_template = False): + dataset = [] + new_dataset = [] + data = [{"input":item["input"]} for item in eval_data] + if subset is not None: + data = data[:subset] + dataset = data + for _datum in dataset: + if type(_datum['input']) != str: + _datum['input'] = str(_datum['input']) + inputs = build_model_input(_datum["input"], tokenizer, system_prompt, use_chat_template) + new_dataset.append({"input": inputs}) + return new_dataset \ No newline at end of file diff --git a/steer/models/model_wrapper.py b/steer/models/model_wrapper.py index 74bc7ca6..bd02d8be 100644 --- a/steer/models/model_wrapper.py +++ b/steer/models/model_wrapper.py @@ -120,15 +120,21 @@ def forward(self, *args, **kwargs): self.dot_products.append((top_token, dot_product.cpu().item())) if self.add_activations_dict: augmented_output = output[0] - for activations in self.add_activations_dict.values(): - if activations is not None: - position_ids = kwargs.get("position_ids", None) - augmented_output = add_vector_from_position( - matrix=augmented_output, - vector=activations, - position_ids=position_ids, - from_pos=self.from_position, - ) + method_names = self.add_activations_dict.keys() + if "loreft" in method_names: + intervention_cls = self.add_activations_dict["loreft"].get("intervention_cls", None) + if intervention_cls is not None: + augmented_output = intervention_cls.forward(augmented_output) + else: + for activations in self.add_activations_dict.values(): + if activations is 
not None: + position_ids = kwargs.get("position_ids", None) + augmented_output = add_vector_from_position( + matrix=augmented_output, + vector=activations, + position_ids=position_ids, + from_pos=self.from_position, + ) output = (augmented_output,) + output[1:] if not self.save_internal_decodings: diff --git a/steer/utils/alg_dict.py b/steer/utils/alg_dict.py index 82e94226..c9268bc1 100644 --- a/steer/utils/alg_dict.py +++ b/steer/utils/alg_dict.py @@ -4,7 +4,8 @@ VectorPromptHyperParams, CAAHyperParams, LmSteerHyperParams, - MergeVectorHyperParams + MergeVectorHyperParams, + LoReFTHyperParams ) from ..vector_appliers import( ApplySaeFeatureHyperParams, @@ -14,6 +15,7 @@ ApplyLmSteerHyperParams, ApplyPromptHyperParams, ApplyMergeVectorHyperParams, + ApplyLoReFTHyperParams ) from ..vector_generators import ( @@ -23,6 +25,7 @@ generate_sae_feature_vectors, generate_sta_vectors, generate_merge_vector, + generate_LoReFT_vectors ) from ..vector_appliers import ( apply_lm_steer, @@ -32,6 +35,7 @@ apply_sta, apply_prompt, apply_merge_vector, + apply_loreft ) import torch DTYPES_DICT ={ @@ -52,7 +56,8 @@ 'vector_prompt': {'train': VectorPromptHyperParams, 'apply': ApplyVectorPromptHyperParams}, 'caa': {'train': CAAHyperParams, 'apply': ApplyCAAHyperParams}, 'prompt': {'apply': ApplyPromptHyperParams}, - 'merge_vector': {'train': MergeVectorHyperParams, 'apply': ApplyMergeVectorHyperParams} + 'merge_vector': {'train': MergeVectorHyperParams, 'apply': ApplyMergeVectorHyperParams}, + 'loreft':{'train':LoReFTHyperParams,'apply':ApplyLoReFTHyperParams} } METHODS_CLASS_DICT = { @@ -62,5 +67,6 @@ 'sae_feature': {'train': generate_sae_feature_vectors, 'apply': apply_sae_feature}, 'sta': {'train': generate_sta_vectors, 'apply': apply_sta}, 'prompt': {'apply': apply_prompt}, - 'merge_vector': {'train': generate_merge_vector, 'apply': apply_merge_vector} + 'merge_vector': {'train': generate_merge_vector, 'apply': apply_merge_vector}, + 'loreft':{'train':generate_LoReFT_vectors,'apply': apply_loreft} } \ No newline at end of file diff --git a/steer/vector_appliers/__init__.py b/steer/vector_appliers/__init__.py index e60e64af..392ca590 100644 --- a/steer/vector_appliers/__init__.py +++ b/steer/vector_appliers/__init__.py @@ -6,4 +6,5 @@ from .prompt import * from .sta import * from .sae_feature import * -from .vector_applier import * \ No newline at end of file +from .vector_applier import * +from .loreft import * \ No newline at end of file diff --git a/steer/vector_appliers/loreft/__init__.py b/steer/vector_appliers/loreft/__init__.py new file mode 100644 index 00000000..a066a6f9 --- /dev/null +++ b/steer/vector_appliers/loreft/__init__.py @@ -0,0 +1,2 @@ +from .apply_loreft_intervention import * +from .apply_loreft_intervention_hparam import * \ No newline at end of file diff --git a/steer/vector_appliers/loreft/apply_loreft_intervention.py b/steer/vector_appliers/loreft/apply_loreft_intervention.py new file mode 100644 index 00000000..d208b22c --- /dev/null +++ b/steer/vector_appliers/loreft/apply_loreft_intervention.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass +import torch +from .apply_loreft_intervention_hparam import ApplyLoReFTHyperParams +from torch import nn +class LoReFTIntervention(): + """Phi(h) = h + R^T(Wh + b - Rh)""" + def __init__(self, layer, device,weight_proj,weight_source,bias_source): + self.layer = layer + self.device = device + self.W_proj = weight_proj.to(device) + self.W_source = weight_source.to(device) + self.b_source = bias_source.to(device) + + def forward(self, h): 
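+        # Editorial note (sketch of intent, not part of the original patch): the body below
+        # implements the LoReFT edit Phi(h) = h + R^T(W h + b - R h) from the class docstring,
+        # with self.W_proj holding the low-rank projection R, self.W_source the learned W and
+        # self.b_source the bias b; h is the block output, expected to be [batch, seq_len, embed_dim].
+        # torch.bmm does not broadcast, so the saved parameter tensors are assumed to carry a
+        # leading batch dimension compatible with h (batch_size 1 in the shipped configs).
+        # Note that only the R h term is computed via h.float(); the W h term keeps h's dtype.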
+ rotated_base = torch.bmm(h.float(), self.W_proj) + output = h + torch.bmm( + ((torch.bmm(h, self.W_source) + self.b_source) - rotated_base), # batch_size, seq_len, low_rank_dimension + self.W_proj.transpose(-1, -2) + ) + return output.to(h.dtype) + + +def apply_loreft(hparams: ApplyLoReFTHyperParams,model = None): + from ...models import get_model + # make_model + dump_dir = hparams.steer_vector_load_dir + model_name = "loreft" + weight = torch.load( + f"{dump_dir}/{model_name}_weight.pt",weights_only= True + ) + bias = torch.load( + f"{dump_dir}/{model_name}_bias.pt",weights_only= True + ) + device = hparams.device + weight_keys = list(weight.keys()) + model, _ = get_model(hparams) + reft_layers = hparams.reft_layers + method = "loreft" + W_projs = {} + W_sources = {} + b_sources = {} + weight_keys = list(weight.keys()) + bias_keys = list(bias.keys()) + for layer in reft_layers: + for weight_key in weight_keys: + if weight_key.startswith(f"layer_{layer}"): + if(weight_key.endswith("proj_weight")): + W_projs[layer] = nn.Parameter(weight[weight_key]) + elif(weight_key.endswith("source_weight")): + W_sources[layer] = nn.Parameter(weight[weight_key]) + for bias_key in bias_keys: + if bias_key.startswith(f"layer_{layer}"): + b_sources[layer] = nn.Parameter(bias[bias_key]) + for layer in reft_layers: + intervention_cls = LoReFTIntervention( + layer=layer, + device=device, + weight_proj=W_projs[layer], + weight_source=W_sources[layer], + bias_source=b_sources[layer] + ) + activations = { + "intervention_cls": intervention_cls, + } + model.set_add_activations(layer=layer, activations=activations,method_name=method) + return model + \ No newline at end of file diff --git a/steer/vector_appliers/loreft/apply_loreft_intervention_hparam.py b/steer/vector_appliers/loreft/apply_loreft_intervention_hparam.py new file mode 100644 index 00000000..c10ef7b0 --- /dev/null +++ b/steer/vector_appliers/loreft/apply_loreft_intervention_hparam.py @@ -0,0 +1,34 @@ +import yaml +from typing import List +from ...utils import HyperParams +from dataclasses import dataclass, field + + + +@dataclass +class ApplyLoReFTHyperParams(HyperParams): + # Method (with predefined values) + alg_name: str = 'loreft' + steer_vector_load_dir: str = None + reft_layers: List[int] = field(default_factory=lambda: [20,21]) + max_length: int = 512 + batch_size: int = 1 + device: str = 'cuda' + low_rank_dimension: int = 1 + position: str = "l1" + temperature: float = 1.0 + + + @classmethod + def from_hparams(cls, hparams_name_or_path: str): + + if '.yaml' not in hparams_name_or_path: + hparams_name_or_path = hparams_name_or_path + '.yaml' + + with open(hparams_name_or_path, "r") as stream: + config = yaml.safe_load(stream) + config = super().construct_float_from_scientific_notation(config) + + assert (config and config['alg_name'] == 'loReFT') or print(f'LoReFTHyperParams can not load from {hparams_name_or_path}, ' + f'alg_name is {config["alg_name"]} ') + return cls(**config) diff --git a/steer/vector_appliers/vector_applier.py b/steer/vector_appliers/vector_applier.py index 92cca5e6..144bd0a7 100644 --- a/steer/vector_appliers/vector_applier.py +++ b/steer/vector_appliers/vector_applier.py @@ -18,12 +18,17 @@ def __init__(self, top_cfg: DictConfig): self.model = None self.tokenizer = None self.device = None + # for loreft generation + self.reft_model = None def _load_model(self): if self.model is None: from ..models import get_model self.model, self.tokenizer = get_model(self.config) self.device = self.model.device + def 
reset_loreft(self): + if self.hparams_dict.get('loreft') is not None: + self.reft_model = None def apply_steering(self, hparams_dict, model=None, vectors=None): from ..utils.alg_dict import METHODS_CLASS_DICT @@ -33,6 +38,8 @@ def apply_steering(self, hparams_dict, model=None, vectors=None): # print(f"Applying {alg_name} vectors to model ...") if alg_name == 'prompt': model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name] , model) + elif alg_name == "loreft": + model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name]) elif vectors is None or vectors.get(alg_name) is None: assert hparams_dict[alg_name].steer_vector_load_dir is not None, f"Steer vector load path {hparams_dict[alg_name].steer_vector_load_dir} does not exist !" model = METHODS_CLASS_DICT[alg_name]['apply'](hparams_dict[alg_name] , model) @@ -84,7 +91,7 @@ def generate(self, datasets, save_results=True, **kwargs): if generation_data_size is None: generation_data_size = -1 dataset = datasets[generation_data_name][:generation_data_size] if generation_data_size > 0 else datasets[generation_data_name] - + num_responses = self.config.get('num_responses', 1) for item in tqdm(dataset, desc=f"Evaluating dataset {generation_data_name}"): if not item.get('input'): continue @@ -92,8 +99,7 @@ def generate(self, datasets, save_results=True, **kwargs): current_output = [] input_text = self._process_input_text(item['input']) inputs = self.tokenizer(input_text, return_tensors="pt", add_special_tokens = not self.config.use_chat_template).to(self.device) - - num_responses = self.config.get('num_responses', 1) + for j in range(num_responses): if num_responses > 1: set_seed(j) @@ -115,7 +121,6 @@ def generate(self, datasets, save_results=True, **kwargs): output=output[0][inputs['input_ids'].shape[1]:] text = self.tokenizer.decode(output, skip_special_tokens=True) orig_preds.append([text]) - formatted_results = self._format_result(dataset, orig_preds=orig_preds,preds=preds, complete_output=complete_output) if save_results: self.save_results(formatted_results, generation_data_name) diff --git a/steer/vector_generators/LoReFT/__init__.py b/steer/vector_generators/LoReFT/__init__.py new file mode 100644 index 00000000..631b4556 --- /dev/null +++ b/steer/vector_generators/LoReFT/__init__.py @@ -0,0 +1,2 @@ +from .generate_LoReFT_hparams import * +from .generate_LoReFT_vectors import * \ No newline at end of file diff --git a/steer/vector_generators/LoReFT/generate_LoReFT_hparams.py b/steer/vector_generators/LoReFT/generate_LoReFT_hparams.py new file mode 100644 index 00000000..95ed4544 --- /dev/null +++ b/steer/vector_generators/LoReFT/generate_LoReFT_hparams.py @@ -0,0 +1,42 @@ +import yaml +from typing import List +from ...utils import HyperParams +from dataclasses import dataclass, field +import torch + + +@dataclass +class LoReFTHyperParams(HyperParams): + # Method (with predefined values) + alg_name: str = 'loreft' + steer_vector_output_dir: str = None + steer_train_dataset: str=None + reft_layers: List[int] = field(default_factory=lambda: [20,21]) + lr: float = 0.0001 + n_epochs: int = 3 + max_length: int = 256 + batch_size: int = 3 + gradient_accumulation_steps: int = 1 + device: str = 'cuda' if torch.cuda.is_available() else 'cpu' + seed: int = 42 + subset: int = None + low_rank_dimension: int = 1 + position: str = "l1" + weight_decay: float = 0.01 + save_vectors: bool = True + use_cache : bool = True + + + @classmethod + def from_hparams(cls, hparams_name_or_path: str): + + if '.yaml' not in hparams_name_or_path: + 
hparams_name_or_path = hparams_name_or_path + '.yaml' + + with open(hparams_name_or_path, "r") as stream: + config = yaml.safe_load(stream) + config = super().construct_float_from_scientific_notation(config) + + assert (config and config['alg_name'] == 'loReFT') or print(f'LoReFTHyperParams can not load from {hparams_name_or_path}, ' + f'alg_name is {config["alg_name"]} ') + return cls(**config) diff --git a/steer/vector_generators/LoReFT/generate_LoReFT_vectors.py b/steer/vector_generators/LoReFT/generate_LoReFT_vectors.py new file mode 100644 index 00000000..1389c654 --- /dev/null +++ b/steer/vector_generators/LoReFT/generate_LoReFT_vectors.py @@ -0,0 +1,186 @@ +import sys + +import torch.utils.data.dataset + + + + + + +sys.path.append("../../") +sys.path.append("../") +sys.path.append("./") +import os +import torch +import argparse +from tqdm import tqdm +from dotenv import load_dotenv +import random +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import DataCollatorForSeq2Seq, TrainingArguments, set_seed +import pyreft +import matplotlib.pyplot as plt +from copy import deepcopy + +from .generate_LoReFT_hparams import LoReFTHyperParams +IGNORE_INDEX = -100 + +load_dotenv() + +HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN") + +def get_lr(optimizer): + for param_group in optimizer.param_groups: + return param_group["lr"] +def save_interventions(reft_model, dump_dir,model_name): + proj_weights = [] + source_weights = [] + source_biases = [] + intervention_names = [] + for intervention_name, intervention in reft_model.interventions.items(): + intervention_names.append(intervention_name) + intervention_state_dict = intervention.state_dict() + proj_weight = intervention_state_dict["rotate_layer"] # [embed_dim, low_rank_dimension] + source_weight = intervention_state_dict["weight"].T # [embed_dim, low_rank_dimension] + source_bias = intervention_state_dict["bias"] # [low_rank_dimension] + proj_weights.append(proj_weight) + source_weights.append(source_weight) + source_biases.append(source_bias) + weight_file = dump_dir / f"{model_name}_weight.pt" + if weight_file.exists(): + existing_weight = torch.load(weight_file) + for i, intervention_name in enumerate(intervention_names): + existing_weight[f"{intervention_name}.proj_weight"] = torch.cat( + [existing_weight[f"{intervention_name}.proj_weight"], proj_weights[i].cpu().unsqueeze(dim=0)], dim=0) + existing_weight[f"{intervention_name}.source_weight"] = torch.cat( + [existing_weight[f"{intervention_name}.source_weight"], source_weights[i].cpu().unsqueeze(dim=0)], dim=0) + else: + existing_weight = {} + for i, intervention_name in enumerate(intervention_names): + existing_weight[f"{intervention_name}.proj_weight"] = proj_weights[i].cpu().unsqueeze(dim=0) + existing_weight[f"{intervention_name}.source_weight"] = source_weights[i].cpu().unsqueeze(dim=0) + torch.save(existing_weight, weight_file) + bias_file = dump_dir / f"{model_name}_bias.pt" + if bias_file.exists(): + existing_bias = torch.load(bias_file) + for i, intervention_name in enumerate(intervention_names): + existing_bias[f"{intervention_name}.bias"] = torch.cat( + [existing_bias[f"{intervention_name}.bias"], source_biases[i].cpu().unsqueeze(dim=0)], dim=0) + else: + existing_bias = {} + for i, intervention_name in enumerate(intervention_names): + existing_bias[f"{intervention_name}.bias"] = source_biases[i].cpu().unsqueeze(dim=0) + torch.save(existing_bias, bias_file) + +def generate_LoReFT_vectors(hparams:LoReFTHyperParams, dataset, model = None): + 
from ...models.get_model import get_model + from ...datasets.loreft_data import load_loreft_data + from transformers import get_scheduler + import transformers + from pathlib import Path + del_model = True + if model is None: + model, tokenizer = get_model(hparams) + else: + del_model = False + model, tokenizer = model, model.tokenizer + model.hparams = hparams + system_prompt = "" if hparams.system_prompt is None else hparams.system_prompt + use_chat_template = True if hparams.use_chat_template is None else hparams.use_chat_template + subset = None if hparams.subset is None else hparams.subset + tokenizer.model_max_length = hparams.max_length + model.model.eval() + device = hparams.device + model.model.to(device) + train_data = load_loreft_data(dataset,subset,tokenizer,system_prompt,use_chat_template) + torch_dtype = model.torch_dtype + batch_size = hparams.batch_size + low_rank_dimension = hparams.low_rank_dimension + if hparams.steer_vector_output_dir is None: + assert "Need Steer Vector Output Dir" + position = hparams.position + reft_layers = hparams.reft_layers + num_interventions = len(reft_layers) + data_module = pyreft.make_multiple_position_supervised_data_module( + tokenizer=tokenizer,model=model.model, + inputs = [item["prompt"] for item in train_data], + outputs= [item["output"] for item in train_data], + positions = position, + num_interventions= num_interventions, + nonstop=True, + share_weights=True + ) + + train_dataloader = DataLoader( + data_module["train_dataset"],shuffle=True, + batch_size=batch_size,collate_fn=data_module["data_collator"] + ) + # make_model + intervention_cls = pyreft.LoreftIntervention + + reft_config = pyreft.ReftConfig(representations=[{ + "layer": l, "component": "block_output", + "low_rank_dimension": low_rank_dimension, + "intervention": intervention_cls(embed_dim=model.model.config.hidden_size, + low_rank_dimension=low_rank_dimension,dtype = torch_dtype)} for l in reft_layers]) + reft_model = pyreft.get_reft_model(model.model, reft_config) + lr = hparams.lr + weight_decay = hparams.weight_decay + optimizer = torch.optim.AdamW( + reft_model.parameters(), lr=lr,weight_decay=weight_decay, + betas=(0.9,0.999),eps = 1e-8 + ) + n_epochs= hparams.n_epochs + gradient_accumulation_steps = 1 if hparams.gradient_accumulation_steps is None else hparams.gradient_accumulation_steps + num_training_steps = n_epochs * max(1, len(train_dataloader))//gradient_accumulation_steps + lr_scheduler = get_scheduler( + "linear", optimizer=optimizer, + num_warmup_steps=0,num_training_steps=num_training_steps + ) + progress_bar, curr_step = tqdm(range(num_training_steps), leave=True), 0 + + reft_model.print_trainable_parameters() + losses = [] + for epoch in range(n_epochs): + for step, batch in enumerate(train_dataloader): + inputs = {k : v.to(device) for k,v in batch.items()} + unit_locations={"sources->base": ( + None, + inputs["intervention_locations"].permute(1, 0, 2).tolist() + )} + _, cf_outputs = reft_model.forward( + base = { + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + },unit_locations=unit_locations, + labels=inputs["labels"], + use_cache=False + ) + loss = cf_outputs.loss.mean() + loss.backward() + loss = loss.mean() + if gradient_accumulation_steps > 1: + loss = loss / gradient_accumulation_steps + if(step + 1) % gradient_accumulation_steps == 0 or ( step + 1) == len(train_dataloader): + torch.nn.utils.clip_grad_norm_(reft_model.parameters(), 1.0) + curr_step += 1 + curr_lr = get_lr(optimizer) + optimizer.step() + 
lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + losses.append(loss.item()) + progress_bar.set_description( + "lr %.6f || loss %.6f " % (curr_lr, loss)) + progress_bar.close() + save_directory = os.path.join(hparams.steer_vector_output_dir, hparams.alg_name + '_vector') + if not os.path.exists(save_directory): + os.makedirs(save_directory) + save_interventions(reft_model,Path(save_directory),"loreft") + if del_model: + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + diff --git a/steer/vector_generators/__init__.py b/steer/vector_generators/__init__.py index 4f3b8d79..10706191 100644 --- a/steer/vector_generators/__init__.py +++ b/steer/vector_generators/__init__.py @@ -4,4 +4,5 @@ from .sta import * from .sae_feature import * from .merge import * -from .vector_generators import * \ No newline at end of file +from .vector_generators import * +from .LoReFT import * \ No newline at end of file diff --git a/steer/vector_generators/vector_generators.py b/steer/vector_generators/vector_generators.py index 28ffb7ac..a1e5bca3 100644 --- a/steer/vector_generators/vector_generators.py +++ b/steer/vector_generators/vector_generators.py @@ -39,6 +39,9 @@ def generate_vectors(self, datasets = None,): print(f"Generating {key} vectors ...") if alg_name in ['lm_steer', 'caa', 'vector_prompt', 'sta']: vectors = METHODS_CLASS_DICT[alg_name]['train'](hparams, datasets[dataset_name]) + elif alg_name == "loreft": + vectors = None + METHODS_CLASS_DICT[alg_name]['train'](hparams, datasets[dataset_name]) else: vectors = METHODS_CLASS_DICT[alg_name]['train'](hparams) generated_vectors[dataset_name][key] = vectors diff --git a/tutorial-notebooks/EasyEdit_Example_LoReFT_translate.ipynb b/tutorial-notebooks/EasyEdit_Example_LoReFT_translate.ipynb new file mode 100644 index 00000000..1b5e467e --- /dev/null +++ b/tutorial-notebooks/EasyEdit_Example_LoReFT_translate.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "nnsight is not detected. 
Please install via 'pip install nnsight' for nnsight backend.\n" + ] + }, + { + "data": { + "text/plain": [ + "{'model_name_or_path': 'C:/git/axbench-main/google/gemma-2-2b-it/', 'torch_dtype': 'bfloat16', 'device': 'cuda', 'use_chat_template': True, 'system_prompt': '', 'steer_train_hparam_paths': ['../hparams/Steer/loreft_hparams/generate_loreft.yaml'], 'steer_train_dataset': ['translate'], 'steer_vector_output_dirs': 'vectors/gemma-2-2b-it/', 'apply_steer_hparam_paths': ['../hparams/Steer/loreft_hparams/apply_loreft.yaml'], 'steer_vector_load_dir': ['vectors/gemma-2-2b-it/translate/loreft_vector'], 'generation_data': ['nontoxic'], 'generation_data_size': 5, 'generation_output_dir': 'vectors/gemma-2-2b-it/translate_loreft_results/', 'num_responses': 1, 'steer_from_end_position': False, 'generate_orig_output': True, 'generation_params': {'max_new_tokens': 100, 'temperature': 0.9}}" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "from omegaconf import OmegaConf, DictConfig\n", + "from steer.vector_generators.vector_generators import BaseVectorGenerator\n", + "from steer.datasets import prepare_train_dataset\n", + "from steer.vector_appliers.vector_applier import BaseVectorApplier\n", + "from steer.datasets import prepare_generation_datasets\n", + "\n", + "top_cfg = OmegaConf.load(\"./config_loreft_translate.yaml\")\n", + "# top_cfg.model_name_or_path = model_path\n", + "# top_cfg.device = \"cuda:0\"\n", + "top_cfg" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generate Steering Vector" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading loreft hparams from ../hparams/Steer/loreft_hparams/generate_loreft.yaml ...\n", + "LOREFT_0 Generator Hyperparameters:\n", + "LoReFTHyperParams(use_chat_template=True, system_prompt='', torch_dtype='bfloat16', seed=42, model_name_or_path='C:/git/axbench-main/google/gemma-2-2b-it/', device='cuda', use_cache=True, generate_orig_output=True, alg_name='loreft', steer_vector_output_dir='vectors/gemma-2-2b-it/', steer_train_dataset=['translate'], reft_layers=[5, 10, 15, 20], lr=0.0009, n_epochs=24, max_length=512, batch_size=1, gradient_accumulation_steps=2, subset=None, low_rank_dimension=4, position='f5+l5', weight_decay=0.0, save_vectors=True)\n", + "Generating loreft_0 vectors ...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2123b85af7da4fdebe47b3e7cfb21f1a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00