diff --git a/.github/ISSUE_TEMPLATE/bug-report.yaml b/.github/ISSUE_TEMPLATE/bug-report.yaml deleted file mode 100644 index d8360533..00000000 --- a/.github/ISSUE_TEMPLATE/bug-report.yaml +++ /dev/null @@ -1,56 +0,0 @@ -name: "πŸ› Bug Report" -description: Submit a bug report to help us improve EasyEdit -labels: [ "bug" ] -body: - - type: textarea - id: describe-bug - validations: - required: true - attributes: - label: Describe the bug - description: | - Please provide a clear and concise description of the bug you encountered. - placeholder: | - e.g., When editing a sentence using MEND, the output is unexpectedly unchanged. - - - type: textarea - id: reproduction - validations: - required: true - attributes: - label: To Reproduce - description: | - Before reporting a bug, please ensure you have thoroughly reviewed the following: - - - [README.md (EasyEdit1)](https://github.com/zjunlp/EasyEdit/blob/main/README.md) - - [README_2.md (EasyEdit2)](https://github.com/zjunlp/EasyEdit/blob/main/README_2.md) - - Then, provide **clear reproduction steps**, including: - - Your environment (Python version, OS, dependency versions, etc.) - - The full script or commands used to reproduce the issue - - βœ… Please submit code using correctly formatted [code blocks](https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting) - πŸ“· Long error messages can be provided as screenshots if needed. - - These help developers reproduce and resolve your issue more efficiently. - placeholder: | - Steps to reproduce the behavior: - - 1. Code and Environment - 2. Run: - ```bash - ``` - 3. Observe error: - ```text - Traceback (most recent call last): - ... - ``` - - - type: textarea - id: expected-behavior - validations: - required: true - attributes: - label: Expected behavior - description: | - What did you expect to happen instead? diff --git a/.github/ISSUE_TEMPLATE/config.yaml b/.github/ISSUE_TEMPLATE/config.yaml deleted file mode 100644 index 4b69b225..00000000 --- a/.github/ISSUE_TEMPLATE/config.yaml +++ /dev/null @@ -1,3 +0,0 @@ -# Picked from https://github.com/huggingface/transformers/blob/main/.github/ISSUE_TEMPLATE/config.yml -blank_issues_enabled: true -version: 2.1 \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature-report.yaml b/.github/ISSUE_TEMPLATE/feature-report.yaml deleted file mode 100644 index 1f5efb13..00000000 --- a/.github/ISSUE_TEMPLATE/feature-report.yaml +++ /dev/null @@ -1,66 +0,0 @@ -name: "✨ Feature Request" -description: Suggest a new feature or improvement for EasyEdit -labels: [ "enhancement" ] -body: - - type: checkboxes - id: context - attributes: - label: Use context - description: "This feature relates to:" - options: - - label: "A specific editing method (e.g., MEND, SERAC, ROME)" - - label: "Support for a new model or architecture" - - label: "Improved APIs or usability" - - label: "Evaluation metrics or logging" - - label: "Other (please describe below)" - - - type: textarea - id: problem - validations: - required: true - attributes: - label: What problem are you facing? - description: | - Describe the limitation, missing capability, or difficulty you encountered while using EasyEdit. - Try to include relevant context or examples that motivated your request. - placeholder: | - e.g., It's hard to apply MEND to LLaMA-style models due to the current implementation relying on HuggingFace-compatible architectures. 
- - - type: textarea - id: proposal - validations: - required: true - attributes: - label: Describe the feature you'd like to see - description: | - Describe your proposed feature or enhancement clearly. Focus on what you'd like to achieve. - If applicable, suggest interface designs or example usage. - placeholder: | - e.g., Add support for LLaMA through a wrapper module so that MEND can be applied without changing core code. - - - type: textarea - id: alternatives - attributes: - label: Have you considered any alternatives? - description: | - Describe any workarounds or alternative solutions you've considered, and why they may not be ideal. - placeholder: | - e.g., I tried manually converting model weights, but it's error-prone and incompatible with downstream evaluation scripts. - - - type: textarea - id: additional-context - attributes: - label: Additional context - description: | - Add any other information, references, or related issues/PRs that might help us evaluate your request. - placeholder: | - e.g., This feature would enable applying model editing to non-HF models like Baichuan, InternLM, etc., which are increasingly popular in the community. - - - type: checkboxes - id: code-of-conduct - attributes: - label: Code of Conduct - description: By submitting this issue, you agree to follow our Code of Conduct. - options: - - label: I agree to follow this project's Code of Conduct - required: true diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 7e4b5643..00000000 --- a/.gitignore +++ /dev/null @@ -1,19 +0,0 @@ -.idea -hugging_cache -data -**/__pycache__ -examples/WikiBio -examples/wiki_recent -.vscode -logs -results -output -scripts -vectors -**/.cache/ -*.log -delta.txt -.history -**/.hydra/ -**/.gradio/ -**/temp/ \ No newline at end of file diff --git a/README.md b/README.md index f424c143..88498072 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ ![](https://img.shields.io/badge/PRs-Welcome-red) --- - +

Installation β€’ QuickStart β€’
@@ -65,6 +65,7 @@
 - [Other Related Projects](#other-related-projects)

 ## πŸ””News
+- 2025-06-07, πŸ‘‘ [UltraEdit](https://github.com/XiaojieGu/UltraEdit) has arrived β€” powered by a lifelong normalization strategy that continuously updates feature statistics across turns, it can edit 20K samples on a 7B model in just 5 minutes and scales stably to millions!
 - 2025-06-05, 🌟🌟the EasyEdit has added a new model editing algorithm [CORE](https://arxiv.org/abs/2505.23026), designed to strengthen context robustness by minimizing context-sensitive variance in hidden states of the model for edited knowledge.
 - 2025-05-28, 🌟🌟the EasyEdit has added a new model editing algorithm [NAMET](https://arxiv.org/abs/2505.11876), which introduces noise during memory extraction via a one-line modification to MEMIT. Thanks to [@ybdai7](https://github.com/ybdai7) for contribution!
 - 2025-05-15, πŸš€πŸš€We released a new blog [Reflection on Knowledge Editing: Charting the Next Steps](https://fish-sorrel-a54.notion.site/Reflection-on-Knowledge-Editing-Charting-the-Next-Steps-1e6ca8e41f3a8098bd14c85ac1db8da6) discussing the next step for knowledge editing research.
@@ -311,6 +312,7 @@ You can choose different editing methods according to your specific needs.
 | [WISE](https://github.com/zjunlp/EasyEdit/blob/main/examples/WISE.md) | |βœ… |βœ… | βœ… | | |βœ… | |
 | Defer | | βœ… |βœ… | | | | | βœ… |
 | [AlphaEdit](https://github.com/zjunlp/EasyEdit/blob/main/easyeditor/models/alphaedit/README.md) | |βœ… |βœ… | | | | | |
+| [UltraEdit](https://github.com/zjunlp/EasyEdit/blob/main/examples/UltraEdit.md) | |βœ… |βœ… | | | |βœ… | βœ… |

 > ❗️❗️ If you intend to use Mistral, please update the `transformers` library to version 4.34.0 manually. You can use the following code: `pip install transformers==4.34.0`.

@@ -1169,6 +1171,7 @@ We thank all the contributors to this project, more contributors are welcome!
#### Other Related Projects

- [AlphaEdit](https://github.com/jianghoucheng/AlphaEdit)
+- [UltraEdit](https://github.com/XiaojieGu/UltraEdit)
- [ROME](https://github.com/kmeng01/rome)
- [FastEdit](https://github.com/hiyouga/FastEdit)
- [GRACE](https://github.com/Thartvigsen/GRACE)
diff --git a/easyeditor/editors/batch_editor.py b/easyeditor/editors/batch_editor.py
index 0069d0c5..8430f5b7 100644
--- a/easyeditor/editors/batch_editor.py
+++ b/easyeditor/editors/batch_editor.py
@@ -14,8 +14,9 @@ class BatchEditor(Enum):
     DPO = "DPO"
     EMMET = "EMMET"
     ALPHAEDIT = "AlphaEdit"
     CORE = "CORE"
+    ULTRAEDIT = "ULTRAEDIT"

     @staticmethod
     def is_batchable_method(alg_name: str):
@@ -29,7 +30,8 @@ class BatchEditor(Enum):
             or alg_name == BatchEditor.QLoRA.value \
             or alg_name == BatchEditor.LoRA.value \
             or alg_name == BatchEditor.DPO.value \
             or alg_name == BatchEditor.EMMET.value \
             or alg_name == BatchEditor.ALPHAEDIT.value \
-            or alg_name == BatchEditor.CORE.value
+            or alg_name == BatchEditor.CORE.value \
+            or alg_name == BatchEditor.ULTRAEDIT.value
diff --git a/easyeditor/models/__init__.py b/easyeditor/models/__init__.py
index 8bef7ce7..e8ee6fbc 100644
--- a/easyeditor/models/__init__.py
+++ b/easyeditor/models/__init__.py
@@ -19,4 +19,5 @@
 from .deco import *
 from .dola import *
 from .deepedit_api import *
-from .defer import *
\ No newline at end of file
+from .defer import *
+from .ultraedit import *
\ No newline at end of file
diff --git a/easyeditor/models/ultraedit/ULTRAEDIT.py b/easyeditor/models/ultraedit/ULTRAEDIT.py
new file mode 100644
index 00000000..35ada13c
--- /dev/null
+++ b/easyeditor/models/ultraedit/ULTRAEDIT.py
@@ -0,0 +1,180 @@
+import logging
+from collections import Counter
+from typing import Dict
+
+import torch
+import torch.nn as nn
+
+from ...trainer.algs.editable_model import EditableModel
+from ...trainer.algs.malmen.util import (
+    get_module,
+    get_shape,
+    TracerDict,
+    cross_entropy,
+)
+from ...trainer.algs.malmen.nets import RunningMeanStd
+
+LOG = logging.getLogger(__name__)
+
+
+def pad_tensor(tensor, target_length, dim=0, padding_value=0):
+    """Truncate or right-pad `tensor` along `dim` to exactly `target_length`."""
+    tensor_length = tensor.size(dim)
+    if tensor_length >= target_length:
+        return tensor.narrow(dim, 0, target_length)
+    pad_shape = list(tensor.shape)
+    pad_shape[dim] = target_length - tensor_length
+    padding = torch.full(pad_shape, padding_value, dtype=tensor.dtype, device=tensor.device)
+    return torch.cat([tensor, padding], dim=dim)
+
+
+class ULTRAEDIT(EditableModel):
+
+    def __init__(
+        self, model: nn.Module, config, model_constructor
+    ):
+        super().__init__(model, config, model_constructor)
+
+        # Decoder-only LMs predict the next token, so logits must be shifted
+        # one position against the labels when computing the loss.
+        self.shift = any(
+            name in config.model_name.lower()
+            for name in ('gpt', 'llama', 'internlm', 'chatglm', 'qwen', 'mistral')
+        )
+
+        if not str(self.config.device).startswith('cuda'):
+            self.config.device = f'cuda:{self.config.device}'
+
+        if config.half:
+            self.model.half()
+
+        # Freeze everything except the modules selected for editing.
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+        for i in range(len(config.inner_params)):
+            if config.inner_params[i].endswith(".weight"):
+                config.inner_params[i] = config.inner_params[i].replace(".weight", "")
+        self.config.inner_params = config.inner_params
+
+        for module_name in config.inner_params:
+            module = get_module(self.model, module_name)
+            module.weight.requires_grad = True
+
+        # Modules with the same (in_features, out_features) shape share one normalizer.
+        shape_counter = Counter()
+        self.name2idx = {}
+        for module_name in config.inner_params:
+            shape = get_shape(get_module(model, module_name))
+            self.name2idx[module_name] = shape_counter[shape]
+            shape_counter[shape] += 1
+
+        self.lifelong_normalizer = nn.ModuleDict({
+            str(k): RunningMeanStd(k[0] + k[1])
+            for k in shape_counter
+        }).to(self.config.device)
+
+    def edit_model(
+        self,
+        param_shifts: Dict[str, torch.FloatTensor],
+        is_reverse: bool
+    ):
+        # Apply (or undo, if is_reverse) the precomputed weight deltas in place.
+        for module_name, param_shift in param_shifts.items():
+            module = get_module(self.model, module_name)
+            if isinstance(module, nn.Linear):
+                param_shift = param_shift.T
+            if is_reverse:
+                param_shift = -param_shift
+            module.weight.data += param_shift.to(module.weight.data.dtype)
+
+    def cache(self, batch) -> Dict[int, Dict[int, Dict[str, torch.Tensor]]]:
+        # One backward pass per mini-batch: trace keys (module inputs) and
+        # value gradients (grads w.r.t. module outputs) for every edited module,
+        # while updating the lifelong running mean/std statistics.
+        module_kv_map = {}
+        for idx, t in enumerate(batch):
+            with TracerDict(
+                self.model,
+                self.config,
+                t
+            ) as tr:
+                logits = self.model(input_ids=t['input_ids'], attention_mask=t['attention_mask'])["logits"]
+                cross_entropy(logits, t["labels"], self.shift).backward()
+            for module_idx, module_name in enumerate(self.config.inner_params):
+                shape = get_shape(get_module(self.model, module_name))
+                keys = tr[module_name].keys.to(torch.float32).to(self.config.device)
+                values_grad = tr[module_name].values_grad.to(torch.float32).to(self.config.device)
+                self.lifelong_normalizer[str(shape)].update(torch.cat((keys, values_grad), -1))
+                module_kv_map.setdefault(module_idx, {}).update({idx: {'keys': keys, 'values_grad': values_grad}})
+        return module_kv_map
+
+    def predict_param_shifts(self, module_kv_map) -> Dict[str, torch.FloatTensor]:
+        param_shifts = {}
+        for module_idx, module_name in enumerate(self.config.inner_params):
+
+            shape = get_shape(get_module(self.model, module_name))
+
+            lifelong_normalizer = self.lifelong_normalizer[str(shape)]
+            hidden_states = torch.cat([
+                module_kv_map[module_idx][idx]["keys"]
+                for idx in range(len(module_kv_map[module_idx]))
+            ])
+            values_grad = torch.cat([
+                module_kv_map[module_idx][idx]["values_grad"]
+                for idx in range(len(module_kv_map[module_idx]))
+            ])
+            v_feature = torch.empty((0, shape[1]), device=self.config.device)
+            for start_idx in range(0, hidden_states.shape[0], self.config.editor_batch_size):
+                end_idx = start_idx + self.config.editor_batch_size
+                hidden_states_once = pad_tensor(hidden_states[start_idx:end_idx], self.config.editor_batch_size, 0)
+                values_grad_once = pad_tensor(values_grad[start_idx:end_idx], self.config.editor_batch_size, 0)
+                with torch.no_grad():
+                    z_feature = torch.cat((hidden_states_once, values_grad_once), -1)
+                    # Normalize with the lifelong running statistics.
+                    z_feature = lifelong_normalizer(z_feature)
+                    (hidden_states_hat, pseudo_values_hat) = z_feature.split([shape[0], shape[1]], -1)
+                    # Per-sample step size: lr scaled by the squared norm of the normalized key.
+                    coeffs = -self.config.lr * (hidden_states_hat * hidden_states_hat).sum(-1).unsqueeze(-1)
+                    v_feature = torch.cat((v_feature, coeffs * pseudo_values_hat))
+            with torch.no_grad():
+                # Closed-form ridge solve: (K^T K + I) delta_W = K^T V.
+                mat = hidden_states.T @ hidden_states + torch.eye(shape[0], device=self.config.device)
+                # Drop rows that correspond to padding in the final chunk.
+                v_feature = v_feature[:hidden_states.shape[0], :]
+                param_shift = torch.linalg.solve(mat, hidden_states.T @ v_feature)
+            param_shifts[module_name] = param_shift.to(next(self.model.parameters()).device)
+
+        return param_shifts
+
+    def to(self, device):
+        super().to(device)
+        self.lifelong_normalizer.to(device)
+        self.model.to(device)
diff --git a/easyeditor/models/ultraedit/__init__.py b/easyeditor/models/ultraedit/__init__.py
new file mode 100644
index 00000000..a980937b
--- /dev/null
+++ b/easyeditor/models/ultraedit/__init__.py
@@ -0,0 +1,3 @@
+from .ultraedit_main import UltraEditRewriteExecutor
+from .ultraedit_hparams import UltraEditHyperParams
+from .ULTRAEDIT import ULTRAEDIT
\ No newline at end of file
diff --git a/easyeditor/models/ultraedit/ultraedit_hparams.py b/easyeditor/models/ultraedit/ultraedit_hparams.py
new file mode 100644
index 00000000..9d350165
--- /dev/null
+++ b/easyeditor/models/ultraedit/ultraedit_hparams.py
@@ -0,0 +1,51 @@
+from dataclasses import dataclass
+from ...util.hparams import HyperParams
+from typing import Optional, Any, List
+import yaml
+
+
+@dataclass
+class UltraEditHyperParams(HyperParams):
+    alg_name: str
+
+    # Model
+    model_name: str
+    model_class: str
+    tokenizer_class: str
+    tokenizer_name: str
+    inner_params: List[str]
+    device: int
+
+    # Method
+    alg: str
+    dropout: float
+    no_grad_layers: Any
+    batch_size: int
+    lr: float
+    token: str
+    batch_size_once: int
+    editor_batch_size: int
+    silent: bool
+
+    # Output
+    results_dir: str
+
+    # Fields with defaults must come after all required fields.
+    max_length: int = 40
+    half: Optional[bool] = False
+    model_parallel: bool = False
+
+    @classmethod
+    def from_hparams(cls, hparams_name_or_path: str):
+
+        if '.yaml' not in hparams_name_or_path:
+            hparams_name_or_path = hparams_name_or_path + '.yaml'
+
+        with open(hparams_name_or_path, "r") as stream:
+            config = yaml.safe_load(stream)
+        config = super().construct_float_from_scientific_notation(config)
+
+        assert config and config['alg_name'] == 'ULTRAEDIT', \
+            f'UltraEditHyperParams can not load from {hparams_name_or_path}, alg_name is {config["alg_name"]}'
+        return cls(**config)
diff --git a/easyeditor/models/ultraedit/ultraedit_main.py b/easyeditor/models/ultraedit/ultraedit_main.py
new file mode 100644
index 00000000..ee2edc28
--- /dev/null
+++ b/easyeditor/models/ultraedit/ultraedit_main.py
@@ -0,0 +1,120 @@
+import os
+from collections import deque
+from copy import deepcopy
+from typing import Dict, List
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from ...util.globals import *
+
+from .ULTRAEDIT import ULTRAEDIT
+from .ultraedit_hparams import UltraEditHyperParams
+
+
+class UltraEditRewriteExecutor:
+    def __init__(self):
+        self.is_init = False
+
+    def init_model(self, model, tok, params: UltraEditHyperParams):
+
+        self.model = model
+        self.tokenizer = tok
+
+        # Wrap the model with the UltraEdit editor.
+        self.alg = ULTRAEDIT(self.model, params, lambda: deepcopy(self.model))
+        if params.model_parallel:
+            self.alg.lifelong_normalizer.to(deque(self.alg.model.parameters(), maxlen=1)[0].device)
+        else:
+            self.alg.to(torch.device(f'cuda:{params.device}'))
+
+    def reset_model(self):
+        self.is_init = False
+        del self.model, self.tokenizer, self.alg
+
+    def apply_to_model(
+        self,
+        model: AutoModelForCausalLM,
+        tok: AutoTokenizer,
+        requests: List[Dict],
+        hparams: UltraEditHyperParams,
+        copy=False,
+        return_orig_weights=False,
+        keep_original_weight=False,
+        **kwargs
+    ):
+        """
+        Given a request, for example
+        {'prompt': '{} has the position of',
+         'subject': 'Charles Herman Helmsing',
+         'relation_id': 'P39',
+         'target_new': {'str': 'President', 'id': 'Q11696'},
+         'target_true': {'str': 'bishop', 'id': 'Q29182'}}
+        Returns the edited model together with a copy of the original weights
+        for the edited parameters (when return_orig_weights is set).
+        """
+
+        if not self.is_init:
+            self.init_model(model, tok, hparams)
+
+        weights_copy = {}
+        model = deepcopy(self.model) if copy else self.model
+        assert len(requests) >= hparams.batch_size, "The number of requests must be greater than or equal to the value of batch_size."
+        # Split the requests into mini-batches of size batch_size_once.
+        requests = requests[:hparams.batch_size]
+        batches = []
+        for i in range(hparams.batch_size // hparams.batch_size_once):
+            batch = requests[i * hparams.batch_size_once : (i + 1) * hparams.batch_size_once]
+            targets = [
+                (" " if request["target_new"][0] != " " else "")
+                + request["target_new"]
+                for request in batch
+            ]
+            sentences = [
+                request["prompt"] + targets[i]
+                for i, request in enumerate(batch)
+            ]
+
+            # Tokenize
+            sent_tok = self.tokenizer(sentences, padding=True, return_tensors="pt").to(
+                f"cuda:{hparams.device}"
+            )
+            target_tok = self.tokenizer(targets, padding=True, return_tensors="pt").to(
+                f"cuda:{hparams.device}"
+            )
+
+            # Define labels: only the target tokens contribute to the loss.
+            edit_inner = dict(
+                input_ids=sent_tok["input_ids"],
+                attention_mask=sent_tok["attention_mask"],
+                labels=target_tok['input_ids'],
+            )
+
+            batches.append(edit_inner)
+        # Sort mini-batches by total token count so similar lengths are processed
+        # together, then run the editor: trace keys and value gradients, predict
+        # the weight deltas, and apply them in place.
+        batches = sorted(
+            batches,
+            key=lambda x: torch.sum(x['attention_mask']).item(),
+            reverse=True
+        )
+        module_kv_map = self.alg.cache(batches)
+        param_shifts = self.alg.predict_param_shifts(module_kv_map)
+        with torch.no_grad():
+            for n, p in self.model.named_parameters():
+                if n in hparams.inner_params:
+                    if return_orig_weights and n not in weights_copy:
+                        weights_copy[n] = p.detach().clone()
+        self.alg.edit_model(param_shifts, False)
+
+        return self.alg.model, weights_copy
\ No newline at end of file
diff --git a/easyeditor/trainer/algs/__init__.py b/easyeditor/trainer/algs/__init__.py
index 3b2e402e..2b6c76a2 100644
--- a/easyeditor/trainer/algs/__init__.py
+++ b/easyeditor/trainer/algs/__init__.py
@@ -2,3 +2,4 @@
 from .MEND import *
 from .SERAC import *
 from .MALMEN import *
+from ...models.ultraedit.ULTRAEDIT import ULTRAEDIT
\ No newline at end of file
diff --git a/easyeditor/util/alg_dict.py b/easyeditor/util/alg_dict.py
index cdd4ae1b..72c2d012 100644
--- a/easyeditor/util/alg_dict.py
+++ b/easyeditor/util/alg_dict.py
@@ -20,6 +20,7 @@
 from ..models.core import COREHyperParams, apply_core_to_model
 from ..models.deepedit_api import DeepEditApiHyperParams, apply_deepedit_api_to_model
+from ..models.ultraedit import UltraEditHyperParams, UltraEditRewriteExecutor

 ALG_DICT = {
     'ROME': apply_rome_to_model,
@@ -41,8 +42,9 @@
     'R-ROME': apply_r_rome_to_model,
     "EMMET": apply_emmet_to_model,
     "AlphaEdit": apply_AlphaEdit_to_model,
     "CORE": apply_core_to_model,
     "DeepEdit-Api": apply_deepedit_api_to_model,
+    "ULTRAEDIT": UltraEditRewriteExecutor().apply_to_model,
 }

 ALG_MULTIMODAL_DICT = {
diff --git a/examples/UltraEdit.md b/examples/UltraEdit.md
new file mode 100644
index 00000000..2c2303a7
--- /dev/null
+++ b/examples/UltraEdit.md
@@ -0,0 +1,67 @@
+<div align="center">
+<h2>UltraEdit: Training-, Subject-, and Memory-Free Lifelong Editing in Large Language Models</h2>
+</div>
+
+We released our paper *UltraEdit: Training-, Subject-, and Memory-Free Lifelong Editing in Large Language Models* β€” πŸ“– [UltraEdit on arXiv](https://arxiv.org/abs/2505.14679) | πŸ€— [UltraEditBench on HuggingFace](https://huggingface.co/datasets/XiaojieGu/UltraEditBench). If our project helps you, please give us a star ⭐ on [UltraEdit](https://github.com/XiaojieGu/UltraEdit) to support us. πŸ˜‰πŸ˜‰
+
+## πŸ“¦ Data & Model Preparation
+
+1️⃣ Create a new directory `EasyEdit/data/ultraedit` and download the files from [Google Drive](https://drive.google.com/drive/folders/1wsxG5Ybf6hT9QUlccvzTuJSfL_TFNyKQ?usp=sharing) into this folder.
+
+2️⃣ Download the [UltraEditBench](https://huggingface.co/datasets/XiaojieGu/UltraEditBench) and save it under `EasyEdit/data/ultraedit`.
+
+3️⃣ Specify the path to the model weights by setting the `model_name` and `tokenizer_name` fields in `EasyEdit/hparams/UltraEdit`.
+
+If you need to use locate-then-edit methods, we provide precomputed covariance matrices on Hugging Face for several models: [GPT-J 6B](https://huggingface.co/XiaojieGu/gpt-j-6b_CovarianceMatrix), [Qwen2.5-7B-Instruct](https://huggingface.co/XiaojieGu/Qwen2.5-7B-Instruct_CovarianceMatrix), [Mistral-7B-v0.3](https://huggingface.co/XiaojieGu/Mistral-7B-v0.3_CovarianceMatrix), [LLaMA-3-8B-Instruct](https://huggingface.co/XiaojieGu/Llama-3-8B-Instruct_CovarianceMatrix), and [LLaMA-2-7B-hf](https://huggingface.co/XiaojieGu/Llama-2-7b-hf_CovarianceMatrix).
+
+## πŸš€ Setup
+
+πŸ’‘ If you want to try editing a Mistral-7B model, even a **24GB consumer GPU** is enough β€” model editing for everyone!
+
+Run the main experiment with:
+
+```bash
+sh run_ultraedit_editing.sh
+```
+
+The `run_ultraedit_editing.sh` script includes a sample command like the one below, where `--ds_size` is the number of samples to edit and `--batch_size` is the number of edits per turn (a minimal programmatic sketch of the same workflow is shown after the contact section):
+
+```bash
+python run_ultraedit_editing.py \
+    --editing_method=UltraEdit \
+    --hparams_dir=../hparams/UltraEdit/mistral-7b.yaml \
+    --data_dir=../data/ultraedit \
+    --ds_size=20000 \
+    --batch_size=100 \
+    --data_type=zsre \
+    --sequential_edit
+```
+
+πŸ’‘ Just try editing **20K samples** on Mistral-7B in **under 5 minutes** β€” ultra-efficient!
+
+## πŸ“« Contact
+
+For any inquiries or possible collaboration, feel free to reach out at **peettherapynoys@gmail.com** or **guangxuc42@gmail.com** β€” we're open to connecting!
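+
+For reference, below is a minimal sketch of the same zsRE editing run driven directly from Python. It mirrors `run_ultraedit_editing.py` from this repository; the data path, the working directory (the EasyEdit root), and the slice size are assumptions to adapt to your setup:
+
+```python
+import json
+
+from easyeditor import BaseEditor, UltraEditHyperParams
+
+# Load the shipped Mistral-7B config (set model_name/tokenizer_name first).
+hparams = UltraEditHyperParams.from_hparams('./hparams/UltraEdit/mistral-7b.yaml')
+
+# zsre_eval_20k.json is the file expected by run_ultraedit_editing.py.
+with open('./data/ultraedit/zsre_eval_20k.json', 'r') as f:
+    edit_data = json.load(f)[:hparams.batch_size]
+
+editor = BaseEditor.from_hparams(hparams)
+metrics, edited_model, _ = editor.batch_edit(
+    prompts=[d['src'] for d in edit_data],
+    rephrase_prompts=[d['rephrase'] for d in edit_data],
+    target_new=[d['ans'] for d in edit_data],
+    subject=[d['subject'] for d in edit_data],
+    sequential_edit=True,
+)
+```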
+ + +## πŸ“‘ Citation +If you find UltraEdit useful for your research and applications, please cite using this BibTeX: +```bibtex +@article{gu2025ultraedit, + title={UltraEdit: Training-, Subject-, and Memory-Free Lifelong Editing in Large Language Models}, + author={Gu, Xiaojie and Chen, Guangxu and Li, Jungang and Gu, Jia-Chen and Hu, Xuming and Zhang, Kai}, + journal={arXiv preprint arXiv:2505.14679}, + year={2025} +} +``` + diff --git a/examples/run_ultraedit_editing.py b/examples/run_ultraedit_editing.py new file mode 100644 index 00000000..61cbf416 --- /dev/null +++ b/examples/run_ultraedit_editing.py @@ -0,0 +1,142 @@ +import os.path +import sys +import json +import argparse +import torch +from pathlib import Path + +sys.path.append('..') +from easyeditor import ( + AlphaEditHyperParams, + FTHyperParams, + WISEHyperParams, + UltraEditHyperParams, + BaseEditor, + summary_metrics, +) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--editing_method', required=True, type=str) + parser.add_argument('--hparams_dir', required=True, type=str) + parser.add_argument('--data_dir', required=True, type=str) + parser.add_argument('--data_type', required=True, type=str, + choices=['wikibigedit', 'ultraeditbench', 'zsre']) + parser.add_argument('--output_dir', default='./outputs', type=str) + parser.add_argument('--ds_size', default=100, type=int) + parser.add_argument('--batch_size', default=100, type=int) + parser.add_argument('--sequential_edit', action="store_true") + args = parser.parse_args() + + if args.editing_method == 'FT': + editing_hparams = FTHyperParams + elif args.editing_method == 'WISE': + editing_hparams = WISEHyperParams + elif args.editing_method == 'AlphaEdit': + editing_hparams = AlphaEditHyperParams + elif args.editing_method == 'UltraEdit': + editing_hparams = UltraEditHyperParams + else: + raise NotImplementedError + + data_dir = Path(args.data_dir) + + if args.data_type == 'zsre': + zsre_dir = data_dir / 'zsre_eval_20k.json' + with open(zsre_dir, "r") as f: + raw = json.load(f) + edit_data = raw[:args.ds_size] + prompts = [edit_data_['src'] for edit_data_ in edit_data] + subject = [edit_data_['subject'] for edit_data_ in edit_data] + rephrase_prompts = [edit_data_['rephrase'] for edit_data_ in edit_data] + target_new = [edit_data_['ans'] for edit_data_ in edit_data] + locality_prompts = [edit_data_['loc'] for edit_data_ in edit_data] + locality_ans = [edit_data_['loc_ans'] for edit_data_ in edit_data] + locality_inputs = { + 'neighborhood':{ + 'prompt': locality_prompts, + 'ground_truth': locality_ans + }, + } + elif args.data_type == 'wikibigedit': + wiki_dir = data_dir / 'wikibigedit_eval_17k.json' + with open(wiki_dir, "r") as f: + raw = json.load(f) + edit_data = raw[:args.ds_size] + prompts = [edit_data_['update'] for edit_data_ in edit_data] + subject = [edit_data_['subject'] for edit_data_ in edit_data] + rephrase_prompts = [edit_data_['rephrase'] for edit_data_ in edit_data] + target_new = [edit_data_['ans'] for edit_data_ in edit_data] + portability_personas_prompts = [[data['personas']] if isinstance(data['personas'], str) else None for data in edit_data] + portability_personas_answers = [[data['ans']] for data in edit_data] + portability_hop_prompts = [[data['mhop']] if isinstance(data['mhop'], str) else None for data in edit_data] + portability_hop_answers = [[data['mhop_ans']] if isinstance(data['mhop_ans'], str) else None for data in edit_data] + locality_prompts = [edit_data_['loc'] for edit_data_ in edit_data] + 
+        locality_ans = [edit_data_['loc_ans'] for edit_data_ in edit_data]
+
+        locality_inputs = {
+            'neighborhood':{
+                'prompt': locality_prompts,
+                'ground_truth': locality_ans
+            },
+        }
+
+        portability_inputs = {
+            'personas':{
+                'prompt': portability_personas_prompts,
+                'ground_truth': portability_personas_answers
+            },
+            'mhop':{
+                'prompt': portability_hop_prompts,
+                'ground_truth': portability_hop_answers
+            }
+        }
+
+    elif args.data_type == 'ultraeditbench':
+        ultraeditbench_dir = data_dir / 'UltraEditBench_2M.json'
+        with open(ultraeditbench_dir, "r") as f:
+            raw = json.load(f)
+        edit_data = raw[:args.ds_size]
+        prompts = [edit_data_['prompt'] for edit_data_ in edit_data]
+        subject = [edit_data_['subject'] for edit_data_ in edit_data]
+        rephrase_prompts = [edit_data_['rephrase_prompt'] for edit_data_ in edit_data]
+        target_new = [edit_data_['ans'] for edit_data_ in edit_data]
+        locality_prompts = [edit_data_['loc'] for edit_data_ in edit_data]
+        locality_ans = [edit_data_['loc_ans'] for edit_data_ in edit_data]
+        locality_inputs = {
+            'neighborhood':{
+                'prompt': locality_prompts,
+                'ground_truth': locality_ans
+            },
+        }
+
+    hparams = editing_hparams.from_hparams(f'{args.hparams_dir}')
+    hparams.batch_size = args.batch_size
+
+    os.makedirs(args.output_dir, exist_ok=True)
+    output_file = os.path.join(
+        args.output_dir,
+        f'{hparams.model_name.split("/")[-1]}_{args.editing_method}_N={args.ds_size}_Sequential={args.sequential_edit}.json'
+    )
+
+    print("See results at: ", output_file)
+
+    editor = BaseEditor.from_hparams(hparams)
+    metrics, edited_model, _ = editor.batch_edit(
+        prompts=prompts,
+        rephrase_prompts=rephrase_prompts,
+        target_new=target_new,
+        subject=subject,
+        locality_inputs=locality_inputs,
+        portability_inputs=portability_inputs if args.data_type == 'wikibigedit' else None,
+        keep_original_weight=True,
+        sequential_edit=args.sequential_edit,
+    )
+
+    with open(output_file, 'w') as f:
+        json.dump(metrics, f, indent=4)
+
+    if len(metrics) > 0:
+        summary_metrics(metrics)
diff --git a/examples/run_ultraedit_editing.sh b/examples/run_ultraedit_editing.sh
new file mode 100644
index 00000000..b3fc8820
--- /dev/null
+++ b/examples/run_ultraedit_editing.sh
@@ -0,0 +1,8 @@
+python run_ultraedit_editing.py \
+    --editing_method=UltraEdit \
+    --hparams_dir=../hparams/UltraEdit/mistral-7b.yaml \
+    --data_dir=../data/ultraedit \
+    --ds_size=20000 \
+    --batch_size=100 \
+    --data_type=zsre \
+    --sequential_edit
diff --git a/hparams/UltraEdit/gemma-3-27b.yaml b/hparams/UltraEdit/gemma-3-27b.yaml
new file mode 100644
index 00000000..a45e35d5
--- /dev/null
+++ b/hparams/UltraEdit/gemma-3-27b.yaml
@@ -0,0 +1,46 @@
+alg_name: "ULTRAEDIT"
+device: 0
+
+model_name: your/path/gemma-3-27b-it
+model_class: AutoModelForCausalLM
+tokenizer_class: AutoTokenizer
+tokenizer_name: your/path/gemma-3-27b-it
+inner_params:
+#wikibigedit ultraeditbench
+- language_model.model.layers.52.mlp.gate_proj
+- language_model.model.layers.53.mlp.gate_proj
+- language_model.model.layers.54.mlp.gate_proj
+- language_model.model.layers.55.mlp.gate_proj
+- language_model.model.layers.56.mlp.gate_proj
+- language_model.model.layers.57.mlp.gate_proj
+- language_model.model.layers.58.mlp.gate_proj
+- language_model.model.layers.59.mlp.gate_proj
+- language_model.model.layers.60.mlp.gate_proj
+- language_model.model.layers.52.mlp.up_proj
+- language_model.model.layers.53.mlp.up_proj
+- language_model.model.layers.54.mlp.up_proj
+- language_model.model.layers.55.mlp.up_proj
+- language_model.model.layers.56.mlp.up_proj
+- 
language_model.model.layers.57.mlp.up_proj +- language_model.model.layers.58.mlp.up_proj +- language_model.model.layers.59.mlp.up_proj +- language_model.model.layers.60.mlp.up_proj + +# Method +alg: UltraEdit +dropout: 0.0 +no_grad_layers: null + +lr: 1e-6 +token: mask + +batch_size: 100 +batch_size_once: 10 +editor_batch_size: 1024 +silent: False +half: true + +model_parallel: false + +# Output +results_dir: ./results diff --git a/hparams/UltraEdit/gpt-j-6B.yaml b/hparams/UltraEdit/gpt-j-6B.yaml new file mode 100644 index 00000000..3512e84d --- /dev/null +++ b/hparams/UltraEdit/gpt-j-6B.yaml @@ -0,0 +1,49 @@ +alg_name: "ULTRAEDIT" +device: 0 +# Model +model_name: your/path/gpt-j-6B +model_class: AutoModelForCausalLM +tokenizer_class: AutoTokenizer +tokenizer_name: your/path/gpt-j-6B +inner_params: +# zsre or ultra +- transformer.h.18.mlp.fc_out +- transformer.h.19.mlp.fc_out +- transformer.h.20.mlp.fc_out +- transformer.h.21.mlp.fc_out +- transformer.h.22.mlp.fc_out +- transformer.h.23.mlp.fc_out +- transformer.h.24.mlp.fc_out +- transformer.h.25.mlp.fc_out +- transformer.h.26.mlp.fc_out +# wikibigedit +# - transformer.h.19.mlp.fc_out +# - transformer.h.20.mlp.fc_out +# - transformer.h.21.mlp.fc_out +# - transformer.h.22.mlp.fc_out +# - transformer.h.23.mlp.fc_out +# - transformer.h.24.mlp.fc_out +# - transformer.h.25.mlp.fc_out +# - transformer.h.26.mlp.fc_out +# fever +# - transformer.h.25.mlp.fc_out +# - transformer.h.26.mlp.fc_out + +# Method +alg: UltraEdit +dropout: 0.0 +no_grad_layers: null + +lr: 1e-6 +token: mask + +batch_size: 10 +batch_size_once: 10 +editor_batch_size: 1024 +silent: False +half: true + +model_parallel: false + +# Output +results_dir: ./results diff --git a/hparams/UltraEdit/llama3-8b-instruct.yaml b/hparams/UltraEdit/llama3-8b-instruct.yaml new file mode 100644 index 00000000..8861325a --- /dev/null +++ b/hparams/UltraEdit/llama3-8b-instruct.yaml @@ -0,0 +1,59 @@ +alg_name: "ULTRAEDIT" +device: 0 + +model_name: your/path/Llama3-8b-instruct +model_class: AutoModelForCausalLM +tokenizer_class: AutoTokenizer +tokenizer_name: your/path/Llama3-8b-instruct +inner_params: +# zsre wiki ultraeditbench +- model.layers.11.mlp.gate_proj +- model.layers.12.mlp.gate_proj +- model.layers.13.mlp.gate_proj +- model.layers.14.mlp.gate_proj +- model.layers.15.mlp.gate_proj +- model.layers.18.mlp.up_proj +- model.layers.19.mlp.up_proj +- model.layers.20.mlp.up_proj +- model.layers.21.mlp.up_proj +- model.layers.22.mlp.up_proj +- model.layers.23.mlp.up_proj +- model.layers.24.mlp.up_proj +# fever +# - model.layers.22.mlp.gate_proj +# - model.layers.23.mlp.gate_proj +# - model.layers.24.mlp.gate_proj +# - model.layers.25.mlp.gate_proj +# - model.layers.26.mlp.gate_proj +# - model.layers.27.mlp.gate_proj +# - model.layers.28.mlp.gate_proj +# - model.layers.29.mlp.gate_proj +# - model.layers.30.mlp.gate_proj +# - model.layers.22.mlp.up_proj +# - model.layers.23.mlp.up_proj +# - model.layers.24.mlp.up_proj +# - model.layers.25.mlp.up_proj +# - model.layers.26.mlp.up_proj +# - model.layers.27.mlp.up_proj +# - model.layers.28.mlp.up_proj +# - model.layers.29.mlp.up_proj +# - model.layers.30.mlp.up_proj + +# Method +alg: UltraEdit +dropout: 0.0 +no_grad_layers: null + +lr: 1e-6 +token: mask + +batch_size: 100 +batch_size_once: 10 +editor_batch_size: 1024 +silent: False +half: true + +model_parallel: false + +# Output +results_dir: ./results diff --git a/hparams/UltraEdit/mistral-7b.yaml b/hparams/UltraEdit/mistral-7b.yaml new file mode 100644 index 00000000..5e7190b9 --- /dev/null +++ 
b/hparams/UltraEdit/mistral-7b.yaml
@@ -0,0 +1,31 @@
+alg_name: "ULTRAEDIT"
+device: 0
+
+model_name: your/path/mistral-7b-v0.3
+model_class: AutoModelForCausalLM
+tokenizer_class: AutoTokenizer
+tokenizer_name: your/path/mistral-7b-v0.3
+
+inner_params:
+# zsre wikibigedit fever ultraeditbench
+- model.layers.29.mlp.down_proj
+- model.layers.30.mlp.down_proj
+
+# Method
+alg: UltraEdit
+dropout: 0.0
+no_grad_layers: null
+
+lr: 1e-6
+token: mask
+
+batch_size: 100
+batch_size_once: 10
+editor_batch_size: 1024
+silent: False
+half: true
+
+model_parallel: false
+
+# Output
+results_dir: ./results
diff --git a/hparams/UltraEdit/phi-4.yaml b/hparams/UltraEdit/phi-4.yaml
new file mode 100644
index 00000000..4adace98
--- /dev/null
+++ b/hparams/UltraEdit/phi-4.yaml
@@ -0,0 +1,37 @@
+alg_name: "ULTRAEDIT"
+device: 0
+
+model_name: your/path/phi-4
+model_class: AutoModelForCausalLM
+tokenizer_class: AutoTokenizer
+tokenizer_name: your/path/phi-4
+inner_params:
+# wikibigedit ultraeditbench
+- model.layers.30.mlp.down_proj
+- model.layers.31.mlp.down_proj
+- model.layers.32.mlp.down_proj
+- model.layers.33.mlp.down_proj
+- model.layers.34.mlp.down_proj
+- model.layers.35.mlp.down_proj
+- model.layers.36.mlp.down_proj
+- model.layers.37.mlp.down_proj
+- model.layers.38.mlp.down_proj
+
+# Method
+alg: UltraEdit
+dropout: 0.0
+no_grad_layers: null
+
+lr: 1e-6
+token: mask
+
+batch_size: 100
+batch_size_once: 10
+editor_batch_size: 1024
+silent: False
+half: true
+
+model_parallel: false
+
+# Output
+results_dir: ./results
diff --git a/hparams/UltraEdit/qwen2.5-7b.yaml b/hparams/UltraEdit/qwen2.5-7b.yaml
new file mode 100644
index 00000000..4e5674df
--- /dev/null
+++ b/hparams/UltraEdit/qwen2.5-7b.yaml
@@ -0,0 +1,64 @@
+alg_name: "ULTRAEDIT"
+device: 0
+
+model_name: your/path/Qwen2.5-7B-Instruct
+model_class: AutoModelForCausalLM
+tokenizer_class: AutoTokenizer
+tokenizer_name: your/path/Qwen2.5-7B-Instruct
+inner_params:
+# zsre fever ultraeditbench
+- model.layers.18.mlp.gate_proj
+- model.layers.19.mlp.gate_proj
+- model.layers.20.mlp.gate_proj
+- model.layers.21.mlp.gate_proj
+- model.layers.22.mlp.gate_proj
+- model.layers.23.mlp.gate_proj
+- model.layers.24.mlp.gate_proj
+- model.layers.25.mlp.gate_proj
+- model.layers.26.mlp.gate_proj
+- model.layers.18.mlp.up_proj
+- model.layers.19.mlp.up_proj
+- model.layers.20.mlp.up_proj
+- model.layers.21.mlp.up_proj
+- model.layers.22.mlp.up_proj
+- model.layers.23.mlp.up_proj
+- model.layers.24.mlp.up_proj
+- model.layers.25.mlp.up_proj
+- model.layers.26.mlp.up_proj
+# wikibigedit
+# - model.layers.19.mlp.gate_proj
+# - model.layers.20.mlp.gate_proj
+# - model.layers.21.mlp.gate_proj
+# - model.layers.22.mlp.gate_proj
+# - model.layers.23.mlp.gate_proj
+# - model.layers.24.mlp.gate_proj
+# - model.layers.25.mlp.gate_proj
+# - model.layers.26.mlp.gate_proj
+# - model.layers.18.mlp.up_proj
+# - model.layers.19.mlp.up_proj
+# - model.layers.20.mlp.up_proj
+# - model.layers.21.mlp.up_proj
+# - model.layers.22.mlp.up_proj
+# - model.layers.23.mlp.up_proj
+# - model.layers.24.mlp.up_proj
+# - model.layers.25.mlp.up_proj
+# - model.layers.26.mlp.up_proj
+
+# Method
+alg: UltraEdit
+dropout: 0.0
+no_grad_layers: null
+
+lr: 1e-6
+token: mask
+
+batch_size: 100
+batch_size_once: 10
+editor_batch_size: 1024
+silent: False
+half: true
+
+model_parallel: false
+
+# Output
+results_dir: ./results
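As a quick sanity check for any of the configs above, the following minimal sketch loads one through `UltraEditHyperParams.from_hparams` (added in this PR); the qwen2.5-7b path is an assumption, so point it at whichever yaml you actually edited:

```python
from easyeditor import UltraEditHyperParams

# Hypothetical path; substitute the config you edited above.
hparams = UltraEditHyperParams.from_hparams('hparams/UltraEdit/qwen2.5-7b.yaml')

# Every yaml key maps 1:1 onto a dataclass field, so a stray key such as
# `class_name` instead of `model_class` fails fast here with a TypeError.
print(hparams.alg_name, hparams.lr, hparams.inner_params[:2])
```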