diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..69091448
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,13 @@
+.DS_Store
+__pycache__/
+data/
+err.txt
+finetune.py
+models_check_tok_15M/
+models_check_tok_260K/
+out/
+resources.txt
+run
+runq
+stdout.txt
+test/
diff --git a/TODO b/TODO
new file mode 100644
index 00000000..c1ac4fcd
--- /dev/null
+++ b/TODO
@@ -0,0 +1,2 @@
+TODO
+- LoRA for the embedding
\ No newline at end of file
diff --git a/model.py b/model.py
index 9e4ce220..fd8410ae 100644
--- a/model.py
+++ b/model.py
@@ -1,13 +1,18 @@
 import math
 import struct
 import inspect
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Optional, Tuple

 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import nn
+
+@dataclass
+class LoraArgs:
+    lora_r: int
+    lora_alpha: float

 @dataclass
 class ModelArgs:
@@ -22,7 +27,7 @@ class ModelArgs:
     norm_eps: float = 1e-5
     max_seq_len: int = 2048
     dropout: float = 0.0
-
+    lora_modules: Optional[dict[str, LoraArgs]] = field(default_factory=dict)

 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float):
@@ -91,6 +96,37 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
         .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
     )

+class LoraLinear(nn.Module):
+    def __init__(self, base_layer, args: LoraArgs, init=True):
+        super().__init__()
+        self.base_layer = base_layer
+        self.lora_A = nn.Linear(self.base_layer.in_features, args.lora_r, bias=False)
+        self.lora_B = nn.Linear(args.lora_r, self.base_layer.out_features, bias=False)
+        self.scaling = args.lora_alpha / args.lora_r
+        self.base_layer.requires_grad_(False) # freeze the pretrained weight; only the LoRA factors are trained
+
+        if init:
+            self._init_weights()
+            self.initialized = True
+
+    def _init_weights(self):
+        nn.init.zeros_(self.lora_A.weight) # zero A so the initial delta B @ A is zero
+        nn.init.normal_(self.lora_B.weight)
+
+    def get_delta_weights(self) -> torch.Tensor:
+        output_tensor = (self.lora_B.weight @ self.lora_A.weight) * self.scaling
+        return output_tensor
+
+    def merge_weights(self):
+        # fold the LoRA delta into the base weight, so the finetuned weights do not have to be saved separately
+        delta_weight = self.get_delta_weights()
+        self.base_layer.weight.data = self.base_layer.weight.data + delta_weight
+
+    def forward(self,
+                x: torch.Tensor
+                ) -> torch.Tensor:
+        return self.base_layer(x) + self.lora_B(self.lora_A(x)) # (tunable) dropout?
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -165,16 +201,16 @@ def forward(


 class FeedForward(nn.Module):
-    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
+    def __init__(self, args: ModelArgs):
         super().__init__()
-        if hidden_dim is None:
-            hidden_dim = 4 * dim
-            hidden_dim = int(2 * hidden_dim / 3)
-            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
-        self.dropout = nn.Dropout(dropout)
+        if args.hidden_dim is None:
+            args.hidden_dim = 4 * args.dim
+            args.hidden_dim = int(2 * args.hidden_dim / 3)
+            args.hidden_dim = args.multiple_of * ((args.hidden_dim + args.multiple_of - 1) // args.multiple_of)
+        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
+        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
+        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
+        self.dropout = nn.Dropout(args.dropout)

     def forward(self, x):
         return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
@@ -187,12 +223,7 @@ def __init__(self, layer_id: int, args: ModelArgs):
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
         self.attention = Attention(args)
-        self.feed_forward = FeedForward(
-            dim=args.dim,
-            hidden_dim=args.hidden_dim,
-            multiple_of=args.multiple_of,
-            dropout=args.dropout,
-        )
+        self.feed_forward = FeedForward(args)
         self.layer_id = layer_id
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
@@ -245,6 +276,8 @@ def _init_weights(self, module):
                 torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        # elif isinstance(module, LoraLinear): # needed?
+        #     module._init_weights()

     def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
         _bsz, seqlen = tokens.shape
@@ -341,3 +374,19 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
             idx = torch.cat((idx, idx_next), dim=1)

         return idx
+
+class LoraTransformer(Transformer):
+    def __init__(self, params: ModelArgs):
+        super().__init__(params)
+
+    def add_lora(self, args: dict[str, LoraArgs]):
+        for lk in args.keys(): # e.g. layers.0.attention.wq
+            for dk in self.state_dict(): # e.g. layers.0.attention.wq.weight
+                if dk.startswith(lk): # as per the e.g. comments above
+                    new_module = LoraLinear(self.get_submodule(lk), args[lk])
+                    parent, child = '.'.join(lk.split('.')[:-1]), lk.split('.')[-1]
+                    setattr(self.get_submodule(parent), child, new_module)
+
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
\ No newline at end of file
diff --git a/train.py b/train.py
index e8321d8f..3b98ba78 100644
--- a/train.py
+++ b/train.py
@@ -24,7 +24,7 @@
 from functools import partial

 import torch
-from model import Transformer, ModelArgs
+from model import Transformer, ModelArgs, LoraArgs, LoraTransformer
 from torch.distributed import destroy_process_group, init_process_group
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -39,7 +39,7 @@
 eval_iters = 100
 eval_only = False # if True, script exits right after the first eval
 always_save_checkpoint = False # if True, always save a checkpoint after each eval
-init_from = "scratch" # 'scratch' or 'resume'
+init_from = "finetune" # 'scratch' or 'resume' or 'finetune'
 # wandb logging
 wandb_log = False # disabled by default
 wandb_project = "llamac"
@@ -56,10 +56,12 @@
 n_kv_heads = 6
 multiple_of = 32
 dropout = 0.0
+# lora specification
+lora_modules = {f'layers.{l_id}.attention.{w}': LoraArgs(2, 1) for w in ['wq', 'wk', 'wv', 'wo'] for l_id in range(5)}
 # adamw optimizer
 gradient_accumulation_steps = 4 # used to simulate larger batch sizes
 learning_rate = 5e-4 # max learning rate
-max_iters = 100000 # total number of training iterations
+max_iters = 100000 #298100 # total number of training iterations
 weight_decay = 1e-1
 beta1 = 0.9
 beta2 = 0.95
@@ -153,14 +155,15 @@
     multiple_of=multiple_of,
     max_seq_len=max_seq_len,
     dropout=dropout,
+    lora_modules=lora_modules,
 )  # start with model_args from command line
 if init_from == "scratch":
     # init a new model from scratch
     print("Initializing a new model from scratch")
     gptconf = ModelArgs(**model_args)
     model = Transformer(gptconf)
-elif init_from == "resume":
-    print(f"Resuming training from {out_dir}")
+elif init_from == "resume" or init_from == "finetune":
+    print(f"{init_from.capitalize()} from {out_dir}")
     # resume training from a checkpoint.
     ckpt_path = os.path.join(out_dir, "ckpt.pt")
     checkpoint = torch.load(ckpt_path, map_location=device)
@@ -171,7 +174,10 @@
         model_args[k] = checkpoint_model_args[k]
     # create the model
     gptconf = ModelArgs(**model_args)
-    model = Transformer(gptconf)
+    if init_from == 'finetune':
+        model = LoraTransformer(gptconf)
+    else:
+        model = Transformer(gptconf)
     state_dict = checkpoint["model"]
     # fix the keys of the state dictionary :(
     # honestly no idea how checkpoints sometimes get this prefix, have to debug more
@@ -180,6 +186,9 @@
         if k.startswith(unwanted_prefix):
             state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
     model.load_state_dict(state_dict)
+    if init_from == 'finetune':
+        model.freeze()
+        model.add_lora(model_args['lora_modules'])
     iter_num = checkpoint["iter_num"]
     best_val_loss = checkpoint["best_val_loss"]
 model.to(device)
@@ -189,7 +198,7 @@

 # optimizer
 optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
-if init_from == "resume" and "optimizer" in checkpoint:
+if init_from == "resume" and "optimizer" in checkpoint: # finetune starts a fresh optimizer: the LoRA params are not in the pretraining optimizer state
     optimizer.load_state_dict(checkpoint["optimizer"])
 checkpoint = None # free up memory

@@ -226,7 +235,9 @@ def estimate_loss():
     return out

 # learning rate decay scheduler (cosine with warmup)
-def get_lr(it):
+def get_lr(_it, starting_it: int = 0):
+    assert _it >= starting_it
+    it = _it - starting_it # trick to restart the lr schedule for finetuning
     # 1) linear warmup for warmup_iters steps
     if it < warmup_iters:
         return learning_rate * it / warmup_iters
@@ -251,9 +262,13 @@ def get_lr(it):
 local_iter_num = 0 # number of iterations in the lifetime of this process
 raw_model = model.module if ddp else model # unwrap DDP container if needed
 running_mfu = -1.0
+if init_from == 'finetune':
+    starting_it = iter_num # shift the schedule for finetuning without touching the get_lr implementation
+else:
+    starting_it = 0
 while True:
     # determine and set the learning rate for this iteration
-    lr = get_lr(iter_num) if decay_lr else learning_rate
+    lr = get_lr(iter_num, starting_it=starting_it) if decay_lr else learning_rate
     for param_group in optimizer.param_groups:
         param_group["lr"] = lr

@@ -287,7 +302,7 @@ def get_lr(it):
                 "config": config,
             }
             print(f"saving checkpoint to {out_dir}")
-            torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
+            torch.save(checkpoint, os.path.join(out_dir, f"ckpt_from_{init_from}.pt"))
             model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0)
     if iter_num == 0 and eval_only:
         break
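
Reviewer note (not part of the diff): a minimal usage sketch of the pieces this PR adds, assuming the patched model.py is importable. The tiny hyperparameters and the two-layer lora_modules mapping below are made up for illustration; train.py builds the same kind of mapping before its finetune branch calls freeze() and add_lora().

import torch
from model import ModelArgs, LoraArgs, LoraTransformer

# module path -> LoraArgs(lora_r, lora_alpha), same shape as the mapping in train.py
lora_modules = {
    f"layers.{l}.attention.{w}": LoraArgs(lora_r=2, lora_alpha=1.0)
    for l in range(2)
    for w in ["wq", "wk", "wv", "wo"]
}

# hypothetical toy config, just large enough to exercise the code path
args = ModelArgs(dim=64, n_layers=2, n_heads=4, n_kv_heads=4,
                 vocab_size=256, max_seq_len=32, lora_modules=lora_modules)
model = LoraTransformer(args)

model.freeze()                     # every pretrained parameter: requires_grad = False
model.add_lora(args.lora_modules)  # wrap the listed projections in LoraLinear

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable} / {total}")  # only the lora_A / lora_B factors remain trainable

# forward still works; with lora_A zero-initialized the low-rank delta starts at zero,
# so the wrapped projections initially match the frozen base layers
tokens = torch.randint(0, args.vocab_size, (1, args.max_seq_len))
logits = model(tokens)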