Adding LoRA #523

Open · wants to merge 4 commits into master

13 changes: 13 additions & 0 deletions .gitignore
@@ -0,0 +1,13 @@
.DS_Store
__pycache__/
data/
err.txt
finetune.py
models_check_tok_15M/
models_check_tok_260K/
out/
resources.txt
run
runq
stdout.txt
test/
2 changes: 2 additions & 0 deletions TODO
@@ -0,0 +1,2 @@
TODO
- Lora for the embedding
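The TODO above defers LoRA for the token embedding. As a reference point, here is a minimal sketch of what such a wrapper could look like, mirroring the LoraLinear module added to model.py below; the LoraEmbedding name and its wiring are illustrative only and not part of this PR:

import torch
from torch import nn

from model import LoraArgs  # the dataclass added in model.py below


class LoraEmbedding(nn.Module):
    # Hypothetical LoRA wrapper for nn.Embedding, analogous to LoraLinear.
    def __init__(self, base_layer: nn.Embedding, args: LoraArgs):
        super().__init__()
        self.base_layer = base_layer
        # low-rank factors: a (vocab_size x r) lookup followed by an (r x dim) projection
        self.lora_A = nn.Embedding(base_layer.num_embeddings, args.lora_r)
        self.lora_B = nn.Linear(args.lora_r, base_layer.embedding_dim, bias=False)
        self.scaling = args.lora_alpha / args.lora_r
        self.base_layer.requires_grad_(False)
        nn.init.zeros_(self.lora_A.weight)   # zero factor, so the initial delta is zero
        nn.init.normal_(self.lora_B.weight)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.base_layer(tokens) + self.lora_B(self.lora_A(tokens)) * self.scaling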
85 changes: 67 additions & 18 deletions model.py
@@ -1,13 +1,18 @@
import math
import struct
import inspect
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

@dataclass
class LoraArgs:
lora_r: int
lora_alpha: float

@dataclass
class ModelArgs:
@@ -22,7 +27,7 @@ class ModelArgs:
norm_eps: float = 1e-5
max_seq_len: int = 2048
dropout: float = 0.0

lora_modules: Optional[dict[str, LoraArgs]] = field(default_factory = dict)

class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float):
@@ -91,6 +96,37 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)

class LoraLinear(nn.Module):
    def __init__(self, base_layer: nn.Linear, args: LoraArgs, init=True):
        super().__init__()
        self.base_layer = base_layer
        self.lora_A = nn.Linear(self.base_layer.in_features, args.lora_r, bias=False)
        self.lora_B = nn.Linear(args.lora_r, self.base_layer.out_features, bias=False)
        self.scaling = args.lora_alpha / args.lora_r
        self.base_layer.requires_grad_(False)

        if init:
            self._init_weights()
            self.initialized = True
        # expose the merged weight; note that the fine-tuned weights are not saved as a separate checkpoint!
        self.weight = self.merge_weights()

    def _init_weights(self):
        # one factor starts at zero, so the initial delta B @ A is zero and the wrapped layer matches the base layer
        nn.init.zeros_(self.lora_A.weight)
        nn.init.normal_(self.lora_B.weight)

    def get_delta_weights(self) -> torch.Tensor:
        output_tensor = (self.lora_B.weight @ self.lora_A.weight) * self.scaling
        return output_tensor

    def merge_weights(self) -> torch.Tensor:
        delta_weight = self.get_delta_weights()
        self.base_layer.weight.data = self.base_layer.weight.data + delta_weight
        return self.base_layer.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # scale the low-rank update so training matches the merged weights; (tunable) dropout?
        return self.base_layer(x) + self.lora_B(self.lora_A(x)) * self.scaling

class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@@ -165,16 +201,16 @@ def forward(


class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
def __init__(self, args: ModelArgs):
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.dropout = nn.Dropout(dropout)
if args.hidden_dim is None:
args.hidden_dim = 4 * args.dim
args.hidden_dim = int(2 * args.hidden_dim / 3)
args.hidden_dim = args.multiple_of * ((args.hidden_dim + args.multiple_of - 1) // args.multiple_of)
self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
self.dropout = nn.Dropout(args.dropout)

def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
@@ -187,12 +223,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=args.hidden_dim,
multiple_of=args.multiple_of,
dropout=args.dropout,
)
self.feed_forward = FeedForward(args)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
@@ -245,6 +276,8 @@ def _init_weights(self, module):
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
#elif isinstance(module, LoraLinear): #needed?
# module._init_weights()

def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
_bsz, seqlen = tokens.shape
@@ -341,3 +374,19 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
idx = torch.cat((idx, idx_next), dim=1)

return idx

class LoraTransformer(Transformer):
    def __init__(self, params: ModelArgs):
        super().__init__(params)

    def add_lora(self, args: dict[str, LoraArgs]):
        for lk in args.keys():  # e.g. layers.0.attention.wq
            for dk in self.state_dict():  # e.g. layers.0.attention.wq.weight
                if dk.startswith(lk):  # the module named lk owns this parameter
                    new_module = LoraLinear(self.get_submodule(lk), args[lk])
                    parent, child = '.'.join(lk.split('.')[:-1]), lk.split('.')[-1]
                    setattr(self.get_submodule(parent), child, new_module)
                    break  # wrap each targeted module only once

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False
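
Taken together, the intended flow appears to be: build a LoraTransformer, load the pretrained weights, freeze everything, then swap the targeted projections for LoraLinear wrappers. A small usage sketch, with placeholder ModelArgs values rather than a real checkpoint config:

from model import LoraArgs, LoraTransformer, ModelArgs

# Placeholder config; a real run takes these values from the loaded checkpoint.
args = ModelArgs(dim=288, n_layers=6, n_heads=6, n_kv_heads=6, vocab_size=32000)
lora_cfg = {f"layers.{i}.attention.{w}": LoraArgs(lora_r=2, lora_alpha=1.0)
            for i in range(args.n_layers) for w in ("wq", "wk", "wv", "wo")}

model = LoraTransformer(args)
# model.load_state_dict(checkpoint["model"])  # train.py loads the pretrained weights here
model.freeze()            # base weights stop receiving gradients
model.add_lora(lora_cfg)  # each targeted nn.Linear is swapped for a LoraLinear

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(len(trainable))     # only the lora_A / lora_B weights remain trainable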
35 changes: 25 additions & 10 deletions train.py
@@ -24,7 +24,7 @@
from functools import partial

import torch
from model import Transformer, ModelArgs
from model import Transformer, ModelArgs, LoraArgs, LoraTransformer
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

@@ -39,7 +39,7 @@
eval_iters = 100
eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = False # if True, always save a checkpoint after each eval
init_from = "scratch" # 'scratch' or 'resume'
init_from = "finetune" # 'scratch' or 'resume' or 'finetune'
# wandb logging
wandb_log = False # disabled by default
wandb_project = "llamac"
@@ -56,10 +56,12 @@
n_kv_heads = 6
multiple_of = 32
dropout = 0.0
# lora specification
lora_modules = {f'layers.{l_id}.attention.{w}': LoraArgs(2, 1) for w in ['wq', 'wk', 'wv', 'wo'] for l_id in range(5)}
# adamw optimizer
gradient_accumulation_steps = 4 # used to simulate larger batch sizes
learning_rate = 5e-4 # max learning rate
max_iters = 100000 # total number of training iterations
max_iters = 100000 #298100 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
@@ -153,14 +155,15 @@
multiple_of=multiple_of,
max_seq_len=max_seq_len,
dropout=dropout,
lora_modules=lora_modules,
) # start with model_args from command line
if init_from == "scratch":
# init a new model from scratch
print("Initializing a new model from scratch")
gptconf = ModelArgs(**model_args)
model = Transformer(gptconf)
elif init_from == "resume":
print(f"Resuming training from {out_dir}")
elif init_from == "resume" or init_from == "finetune":
print(f"{init_from.capitalize()} from {out_dir}")
# resume training from a checkpoint.
ckpt_path = os.path.join(out_dir, "ckpt.pt")
checkpoint = torch.load(ckpt_path, map_location=device)
@@ -171,7 +174,10 @@
model_args[k] = checkpoint_model_args[k]
# create the model
gptconf = ModelArgs(**model_args)
model = Transformer(gptconf)
if init_from == 'finetune':
model = LoraTransformer(gptconf)
else:
model = Transformer(gptconf)
state_dict = checkpoint["model"]
# fix the keys of the state dictionary :(
# honestly no idea how checkpoints sometimes get this prefix, have to debug more
@@ -180,6 +186,9 @@
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
model.load_state_dict(state_dict)
if init_from == 'finetune':
model.freeze()
model.add_lora(model_args['lora_modules'])
iter_num = checkpoint["iter_num"]
best_val_loss = checkpoint["best_val_loss"]
model.to(device)
@@ -189,7 +198,7 @@

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == "resume" and "optimizer" in checkpoint:
if init_from == "resume" and "optimizer" in checkpoint:  # a LoRA finetune optimizes a different parameter set, so the saved optimizer state is not loaded
optimizer.load_state_dict(checkpoint["optimizer"])
checkpoint = None # free up memory

@@ -226,7 +235,9 @@ def estimate_loss():
return out

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
def get_lr(_it, starting_it:int = 0):
assert _it >= starting_it
it = _it - starting_it # trick to restart lr schedule for finetuning
# 1) linear warmup for warmup_iters steps
if it < warmup_iters:
return learning_rate * it / warmup_iters
@@ -251,9 +262,13 @@ def get_lr(it):
local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0
if init_from == 'finetune':
    starting_it = iter_num  # shift it for finetuning without touching the get_lr implementation
else:
starting_it = 0
while True:
# determine and set the learning rate for this iteration
lr = get_lr(iter_num) if decay_lr else learning_rate
lr = get_lr(iter_num, starting_it=starting_it) if decay_lr else learning_rate
for param_group in optimizer.param_groups:
param_group["lr"] = lr

@@ -287,7 +302,7 @@ def get_lr(it):
"config": config,
}
print(f"saving checkpoint to {out_dir}")
torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
torch.save(checkpoint, os.path.join(out_dir, f"ckpt_from_{init_from}.pt"))
model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0)
if iter_num == 0 and eval_only:
break
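
The starting_it shift above restarts the warmup and decay schedule at the fine-tuning start without changing get_lr itself. A standalone sketch of the idea; the constants are illustrative, and the cosine portion is reconstructed from the usual llama2.c schedule rather than copied from this diff:

import math

# Illustrative constants; the real values come from train.py's config section.
learning_rate, min_lr = 5e-4, 5e-5
warmup_iters, lr_decay_iters = 1000, 100000

def get_lr(_it, starting_it: int = 0):
    it = _it - starting_it                    # restart the schedule at the fine-tuning start
    if it < warmup_iters:                     # 1) linear warmup
        return learning_rate * it / warmup_iters
    if it > lr_decay_iters:                   # 2) past the decay horizon: floor at min_lr
        return min_lr
    ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))
    return min_lr + coeff * (learning_rate - min_lr)   # 3) cosine decay to min_lr

# A fine-tune resumed at iter_num = 298100 warms up again from zero:
print(get_lr(298100, starting_it=298100))        # 0.0
print(get_lr(298100 + 500, starting_it=298100))  # half-way through warmup: 2.5e-04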