
Commit 0d076e0

Merge pull request #14 from codertimo/alpha0.0.1a4
alpha-0.0.1a4 version released
2 parents 7efd2b5 + 427373c commit 0d076e0

File tree

9 files changed: +137 / -48 lines changed

9 files changed

+137
-48
lines changed

LICENSE

Lines changed: 1 addition & 1 deletion

@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.
 
-   Copyright 2018 Junseong Kim, Scatter Labs, BERT contributors
+   Copyright 2018 Junseong Kim, Scatter Lab, BERT contributors
 
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.

README.md

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ bert-vocab -c data/corpus.small -o data/vocab.small
 
 ### 2. Train your own BERT model
 ```shell
-bert -c data/dataset.small -v data/vocab.small -o output/bert.model
+bert -c data/corpus.small -v data/vocab.small -o output/bert.model
 ```
 
 ## Language Model Pre-training

bert_pytorch/__main__.py

Lines changed: 29 additions & 25 deletions

@@ -10,27 +10,30 @@
 def train():
     parser = argparse.ArgumentParser()
 
-    parser.add_argument("-c", "--train_dataset", required=True, type=str)
-    parser.add_argument("-t", "--test_dataset", type=str, default=None)
-    parser.add_argument("-v", "--vocab_path", required=True, type=str)
-    parser.add_argument("-o", "--output_path", required=True, type=str)
-
-    parser.add_argument("-hs", "--hidden", type=int, default=256)
-    parser.add_argument("-l", "--layers", type=int, default=8)
-    parser.add_argument("-a", "--attn_heads", type=int, default=8)
-    parser.add_argument("-s", "--seq_len", type=int, default=20)
-
-    parser.add_argument("-b", "--batch_size", type=int, default=64)
-    parser.add_argument("-e", "--epochs", type=int, default=10)
-    parser.add_argument("-w", "--num_workers", type=int, default=5)
-    parser.add_argument("--with_cuda", type=bool, default=True)
-    parser.add_argument("--log_freq", type=int, default=10)
-    parser.add_argument("--corpus_lines", type=int, default=None)
-
-    parser.add_argument("--lr", type=float, default=1e-3)
-    parser.add_argument("--adam_weight_decay", type=float, default=0.01)
-    parser.add_argument("--adam_beta1", type=float, default=0.9)
-    parser.add_argument("--adam_beta2", type=float, default=0.999)
+    parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset for train bert")
+    parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set for evaluate train set")
+    parser.add_argument("-v", "--vocab_path", required=True, type=str, help="built vocab model path with bert-vocab")
+    parser.add_argument("-o", "--output_path", required=True, type=str, help="ex)output/bert.model")
+
+    parser.add_argument("-hs", "--hidden", type=int, default=256, help="hidden size of transformer model")
+    parser.add_argument("-l", "--layers", type=int, default=8, help="number of layers")
+    parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads")
+    parser.add_argument("-s", "--seq_len", type=int, default=20, help="maximum sequence len")
+
+    parser.add_argument("-b", "--batch_size", type=int, default=64, help="number of batch_size")
+    parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
+    parser.add_argument("-w", "--num_workers", type=int, default=5, help="dataloader worker size")
+
+    parser.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false")
+    parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n")
+    parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
+    parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
+    parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false")
+
+    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam")
+    parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam")
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam first beta value")
 
     args = parser.parse_args()

@@ -39,11 +42,12 @@ def train():
     print("Vocab Size: ", len(vocab))
 
     print("Loading Train Dataset", args.train_dataset)
-    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len, corpus_lines=args.corpus_lines)
+    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len,
+                                corpus_lines=args.corpus_lines, on_memory=args.on_memory)
 
     print("Loading Test Dataset", args.test_dataset)
-    test_dataset = BERTDataset(args.test_dataset, vocab,
-                               seq_len=args.seq_len) if args.test_dataset is not None else None
+    test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \
+        if args.test_dataset is not None else None
 
     print("Creating Dataloader")
     train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)

@@ -56,7 +60,7 @@ def train():
     print("Creating BERT Trainer")
     trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                           lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
-                          with_cuda=args.with_cuda, log_freq=args.log_freq)
+                          with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq)
 
     print("Training Start")
     for epoch in range(args.epochs):
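The new `--cuda_devices` option is declared with `nargs='+'`, so several GPU ids can be given on one command line and arrive as a Python list, which `__main__.py` then forwards to `BERTTrainer`. A minimal sketch of how the two new flags parse (a stand-in parser with only these flags, not the full one above):

```python
import argparse

# Stand-in parser reproducing only the two flags added in this commit.
parser = argparse.ArgumentParser()
parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false")

args = parser.parse_args(["--cuda_devices", "0", "1", "2"])
print(args.cuda_devices)  # [0, 1, 2] -- later handed to nn.DataParallel(device_ids=...)
print(args.on_memory)     # True; note argparse's type=bool treats any non-empty string as True
```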

bert_pytorch/dataset/dataset.py

Lines changed: 59 additions & 11 deletions

@@ -5,19 +5,37 @@
 
 
 class BERTDataset(Dataset):
-    def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None):
+    def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
         self.vocab = vocab
         self.seq_len = seq_len
 
+        self.on_memory = on_memory
+        self.corpus_lines = corpus_lines
+        self.corpus_path = corpus_path
+        self.encoding = encoding
+
         with open(corpus_path, "r", encoding=encoding) as f:
-            self.datas = [line[:-1].split("\t")
-                          for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]
+            if self.corpus_lines is None and not on_memory:
+                for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines):
+                    self.corpus_lines += 1
+
+            if on_memory:
+                self.lines = [line[:-1].split("\t")
+                              for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)]
+                self.corpus_lines = len(self.lines)
+
+        if not on_memory:
+            self.file = open(corpus_path, "r", encoding=encoding)
+            self.random_file = open(corpus_path, "r", encoding=encoding)
+
+            for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):
+                self.random_file.__next__()
 
     def __len__(self):
-        return len(self.datas)
+        return self.corpus_lines
 
     def __getitem__(self, item):
-        t1, (t2, is_next_label) = self.datas[item][0], self.random_sent(item)
+        t1, t2, is_next_label = self.random_sent(item)
         t1_random, t1_label = self.random_word(t1)
         t2_random, t2_label = self.random_word(t2)
 

@@ -49,16 +67,18 @@ def random_word(self, sentence):
         for i, token in enumerate(tokens):
             prob = random.random()
             if prob < 0.15:
-                # 80% randomly change token to make token
-                if prob < prob * 0.8:
+                prob /= 0.15
+
+                # 80% randomly change token to mask token
+                if prob < 0.8:
                     tokens[i] = self.vocab.mask_index
 
                 # 10% randomly change token to random token
-                elif prob * 0.8 <= prob < prob * 0.9:
+                elif prob < 0.9:
                     tokens[i] = random.randrange(len(self.vocab))
 
                 # 10% randomly change token to current token
-                elif prob >= prob * 0.9:
+                else:
                     tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
 
                 output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))

@@ -70,8 +90,36 @@ def random_word(self, sentence):
         return tokens, output_label
 
     def random_sent(self, index):
+        t1, t2 = self.get_corpus_line(index)
+
         # output_text, label(isNotNext:0, isNext:1)
         if random.random() > 0.5:
-            return self.datas[index][1], 1
+            return t1, t2, 1
+        else:
+            return t1, self.get_random_line(), 0
+
+    def get_corpus_line(self, item):
+        if self.on_memory:
+            return self.lines[item][0], self.lines[item][1]
         else:
-            return self.datas[random.randrange(len(self.datas))][1], 0
+            line = self.file.__next__()
+            if line is None:
+                self.file.close()
+                self.file = open(self.corpus_path, "r", encoding=self.encoding)
+                line = self.file.__next__()
+
+            t1, t2 = line[:-1].split("\t")
+            return t1, t2
+
+    def get_random_line(self):
+        if self.on_memory:
+            return self.lines[random.randrange(len(self.lines))][1]
+
+        line = self.file.__next__()
+        if line is None:
+            self.file.close()
+            self.file = open(self.corpus_path, "r", encoding=self.encoding)
+            for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):
+                self.random_file.__next__()
+            line = self.random_file.__next__()
+        return line[:-1].split("\t")[1]
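The rewritten masking branch first rescales `prob` by dividing out the 0.15, so the 0.8 and 0.9 cut-offs really do give the 80% / 10% / 10% split over the selected tokens (the old comparison `prob < prob * 0.8` can never hold for a positive `prob`). A self-contained sketch of the same decision rule, using a hypothetical mask id and vocabulary size in place of the real `vocab` object:

```python
import random

MASK_INDEX = 4       # hypothetical mask token id; the real value comes from the vocab
VOCAB_SIZE = 30000   # hypothetical vocabulary size

def mask_decision(token_id):
    """Return the (possibly corrupted) token id for one position, BERT-style."""
    prob = random.random()
    if prob >= 0.15:
        return token_id                       # 85% of positions are left untouched
    prob /= 0.15                              # rescale to [0, 1) within the selected 15%
    if prob < 0.8:
        return MASK_INDEX                     # 80% of selected positions -> [MASK]
    elif prob < 0.9:
        return random.randrange(VOCAB_SIZE)   # 10% -> a random token
    else:
        return token_id                       # 10% -> keep the original token

# Rough empirical check: fraction of masked positions over all calls.
out = [mask_decision(7) for _ in range(100000)]
print(sum(t == MASK_INDEX for t in out) / len(out))  # ~0.15 * 0.8 = 0.12
```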

bert_pytorch/model/embedding/position.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ def __init__(self, d_model, max_len=512):
         pe.require_grad = False
 
         position = torch.arange(0, max_len).float().unsqueeze(1)
-        div_term = (torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)).float().exp()
+        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
 
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
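The fix moves the `.float()` cast onto the `torch.arange` result itself, so the exponent is computed on a floating-point tensor rather than an integer one. A small sketch of the sinusoidal table this code builds, with toy sizes and without the module's buffer registration:

```python
import math
import torch

d_model, max_len = 8, 16   # small illustrative sizes

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).float().unsqueeze(1)            # shape (max_len, 1)
div_term = (torch.arange(0, d_model, 2).float() *                   # cast first, then scale
            -(math.log(10000.0) / d_model)).exp()                   # 10000^(-2i/d_model)
pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions: sine
pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions: cosine
print(pe.shape)  # torch.Size([16, 8])
```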
bert_pytorch/trainer/optim_schedule.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+'''A wrapper class for optimizer '''
+import numpy as np
+
+
+class ScheduledOptim():
+    '''A simple wrapper class for learning rate scheduling'''
+
+    def __init__(self, optimizer, d_model, n_warmup_steps):
+        self._optimizer = optimizer
+        self.n_warmup_steps = n_warmup_steps
+        self.n_current_steps = 0
+        self.init_lr = np.power(d_model, -0.5)
+
+    def step_and_update_lr(self):
+        "Step with the inner optimizer"
+        self._update_learning_rate()
+        self._optimizer.step()
+
+    def zero_grad(self):
+        "Zero out the gradients by the inner optimizer"
+        self._optimizer.zero_grad()
+
+    def _get_lr_scale(self):
+        return np.min([
+            np.power(self.n_current_steps, -0.5),
+            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])
+
+    def _update_learning_rate(self):
+        ''' Learning rate scheduling per step '''
+
+        self.n_current_steps += 1
+        lr = self.init_lr * self._get_lr_scale()
+
+        for param_group in self._optimizer.param_groups:
+            param_group['lr'] = lr
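The new `ScheduledOptim` wrapper implements the warm-up schedule from "Attention Is All You Need": lr = d_model^-0.5 * min(step^-0.5, step * n_warmup_steps^-1.5), i.e. a linear ramp for `n_warmup_steps` steps followed by inverse-square-root decay. A hedged usage sketch, assuming the class above is in scope and using a tiny linear model in place of BERT:

```python
import torch
import torch.nn as nn
from torch.optim import Adam

# Hypothetical stand-ins: a small linear model instead of BERT, a short warm-up for illustration.
model = nn.Linear(256, 256)
optim = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)
schedule = ScheduledOptim(optim, d_model=256, n_warmup_steps=100)

for step in range(300):
    loss = model(torch.randn(8, 256)).pow(2).mean()
    schedule.zero_grad()
    loss.backward()
    schedule.step_and_update_lr()   # updates the lr, then calls optim.step()

# The lr peaks around step == n_warmup_steps at roughly d_model**-0.5 * n_warmup_steps**-0.5.
print(optim.param_groups[0]["lr"])
```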

bert_pytorch/trainer/pretrain.py

Lines changed: 8 additions & 6 deletions

@@ -4,6 +4,7 @@
 from torch.utils.data import DataLoader
 
 from ..model import BERTLM, BERT
+from .optim_schedule import ScheduledOptim
 
 import tqdm

@@ -21,8 +22,8 @@ class BERTTrainer:
 
     def __init__(self, bert: BERT, vocab_size: int,
                  train_dataloader: DataLoader, test_dataloader: DataLoader = None,
-                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
-                 with_cuda: bool = True, log_freq: int = 10):
+                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
+                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10):
         """
         :param bert: BERT model which you want to train
         :param vocab_size: total word vocab size

@@ -45,16 +46,17 @@ def __init__(self, bert: BERT, vocab_size: int,
         self.model = BERTLM(bert, vocab_size).to(self.device)
 
         # Distributed GPU training if CUDA can detect more than 1 GPU
-        if torch.cuda.device_count() > 1:
+        if with_cuda and torch.cuda.device_count() > 1:
             print("Using %d GPUS for BERT" % torch.cuda.device_count())
-            self.model = nn.DataParallel(self.model)
+            self.model = nn.DataParallel(self.model, device_ids=cuda_devices)
 
         # Setting the train and test data loader
         self.train_data = train_dataloader
         self.test_data = test_dataloader
 
         # Setting the Adam optimizer with hyper-param
         self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
+        self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
 
         # Using Negative Log Likelihood Loss function for predicting the masked_token
         self.criterion = nn.NLLLoss(ignore_index=0)

@@ -110,9 +112,9 @@ def iteration(self, epoch, data_loader, train=True):
 
         # 3. backward and optimization only in train
         if train:
-            self.optim.zero_grad()
+            self.optim_schedule.zero_grad()
             loss.backward()
-            self.optim.step()
+            self.optim_schedule.step_and_update_lr()
 
         # next sentence prediction accuracy
         correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
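Two behavioural changes show up in the trainer: `DataParallel` is now only engaged when `with_cuda` is set, and it can be pinned to an explicit list of devices. A minimal sketch of that guard, with a tiny stand-in module instead of `BERTLM`:

```python
import torch
import torch.nn as nn

# Sketch of the new device-pinning behaviour, using a small stand-in module.
model = nn.Linear(16, 16)
with_cuda, cuda_devices = True, [0, 1]          # hypothetical values from the CLI flags

if with_cuda and torch.cuda.device_count() > 1:
    print("Using %d GPUS for BERT" % torch.cuda.device_count())
    # device_ids=None keeps the old behaviour (all visible GPUs);
    # a list such as [0, 1] restricts DataParallel to those devices.
    model = nn.DataParallel(model, device_ids=cuda_devices)
```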

requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 tqdm
 numpy
-torch>=0.4.0
+torch>=0.4.0

setup.py

Lines changed: 2 additions & 2 deletions

@@ -3,10 +3,10 @@
 import os
 import sys
 
-__version__ = "0.0.1a3"
+__version__ = "0.0.1a4"
 
 with open("requirements.txt") as f:
-    require_packages = [line[:-1] for line in f]
+    require_packages = [line[:-1] if line[-1] == "\n" else line for line in f]
 
 with open("README.md", "r", encoding="utf-8") as f:
     long_description = f.read()
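The old comprehension chopped the last character off every line, which truncated the final requirement whenever requirements.txt did not end with a newline; the new form strips a character only when it really is a trailing "\n". A tiny sketch of the difference, with an in-memory list standing in for the file:

```python
# Lines as read from a requirements file whose last line has no trailing newline.
lines = ["tqdm\n", "numpy\n", "torch>=0.4.0"]

old = [line[:-1] for line in lines]
new = [line[:-1] if line[-1] == "\n" else line for line in lines]

print(old)  # ['tqdm', 'numpy', 'torch>=0.4.'] -- last entry truncated
print(new)  # ['tqdm', 'numpy', 'torch>=0.4.0']
```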
