
Commit a93bdff

[Trainer] Support constant and cosine lr scheduler (#2511)

* support constant and cosine lr scheduler
* fix doc
* delete
* add doc
1 parent 6ccc0be commit a93bdff

File tree

4 files changed (+170, -12 lines)


docs/trainer.md

Lines changed: 1 addition & 1 deletion
@@ -285,7 +285,7 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and
 --lr_scheduler_type
 The learning rate scheduler strategy to use. (`str`, optional, defaults to `"linear"`)

-    The scheduler type to use. (default: linear)
+    The scheduler type to use. (default: linear) Supports linear, cosine, constant, constant_with_warmup.

 --warmup_ratio
 The proportion of total training steps used for a linear warmup from 0 to `learning_rate`. (`float`, optional, defaults to 0.0)
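
As a quick illustration of the two options documented above, here is a minimal Python sketch. It assumes `TrainingArguments` can be imported from `paddlenlp.trainer` and that `output_dir` is a required argument; only `lr_scheduler_type`, `warmup_ratio`, and their accepted values come from this commit.

# Hedged sketch: configuring the new scheduler options programmatically.
# Import path and `output_dir` are assumptions; the scheduler fields are from this commit.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",   # assumed required argument
    learning_rate=5e-5,
    lr_scheduler_type="cosine",   # linear / cosine / constant / constant_with_warmup
    warmup_ratio=0.1,             # warm up over 10% of the total training steps
)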

paddlenlp/trainer/trainer_base.py

Lines changed: 3 additions & 8 deletions
@@ -66,6 +66,7 @@
     OptimizerNames,
     PREFIX_CHECKPOINT_DIR,
     get_last_checkpoint,
+    get_scheduler,
 )
 from .trainer_callback import (
     CallbackHandler,
@@ -919,14 +920,8 @@ def create_scheduler(self, num_training_steps: int):
         Args:
             num_training_steps (int): The number of training steps to do.
         """
-
-        def get_scheduler(lr_scheduler_type, learning_rate, num_warmup_steps,
-                          num_training_steps):
-            # TODO @ZHUI support others
-            return LinearDecayWithWarmup(learning_rate, num_training_steps,
-                                         num_warmup_steps)
-
-        warmup = self.args.warmup_steps if self.args.warmup_steps > 0 else self.args.warmup_ratio
+        warmup = self.args.warmup_steps if self.args.warmup_steps > 0 else int(
+            self.args.warmup_ratio * num_training_steps)

         if self.lr_scheduler is None:
             self.lr_scheduler = get_scheduler(
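
The behavioural change in `create_scheduler` is that `warmup_ratio` is now converted into an absolute number of warmup steps before being passed to the shared `get_scheduler` helper. A standalone sketch of that resolution logic (a hypothetical helper that mirrors the new lines above):

# Hypothetical helper illustrating the warmup resolution used above:
# an explicit warmup_steps wins; otherwise warmup_ratio is scaled by the
# total number of training steps.
def resolve_warmup(warmup_steps: int, warmup_ratio: float,
                   num_training_steps: int) -> int:
    if warmup_steps > 0:
        return warmup_steps
    return int(warmup_ratio * num_training_steps)

# e.g. warmup_ratio=0.1 over 1000 steps -> 100 warmup steps
assert resolve_warmup(0, 0.1, 1000) == 100
assert resolve_warmup(50, 0.1, 1000) == 50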

paddlenlp/trainer/trainer_utils.py

Lines changed: 162 additions & 2 deletions
@@ -28,6 +28,7 @@
 from typing import Dict, NamedTuple, Optional, Tuple, Union

 import numpy as np
+from paddle.optimizer.lr import LambdaDecay

 __all__ = [
     "TrainOutput",
@@ -38,6 +39,7 @@
     "set_seed",
     "speed_metrics",
     "get_last_checkpoint",
+    "get_scheduler",
 ]


@@ -178,12 +180,170 @@ def speed_metrics(split, start_time, num_samples=None, num_steps=None):
 class SchedulerType(ExplicitEnum):
     LINEAR = "linear"
     COSINE = "cosine"
-    COSINE_WITH_RESTARTS = "cosine_with_restarts"
-    POLYNOMIAL = "polynomial"
     CONSTANT = "constant"
     CONSTANT_WITH_WARMUP = "constant_with_warmup"


+def get_constant_schedule(learning_rate: float, last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate, using the learning rate set in the optimizer.
+    Args:
+        learning_rate (`float`):
+            The initial learning rate. It is a python float number.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
+    """
+    return LambdaDecay(learning_rate, lambda _: 1, last_epoch=last_epoch)
+
+
+def get_constant_schedule_with_warmup(learning_rate: float,
+                                      num_warmup_steps: int,
+                                      last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+    increases linearly between 0 and the initial lr set in the optimizer.
+    Args:
+        learning_rate (`float`):
+            The initial learning rate. It is a python float number.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1.0, num_warmup_steps))
+        return 1.0
+
+    return LambdaDecay(learning_rate, lr_lambda, last_epoch=last_epoch)
+
+
+def get_linear_schedule_with_warmup(learning_rate: float,
+                                    num_warmup_steps,
+                                    num_training_steps,
+                                    last_epoch=-1):
+    """
+    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+    Args:
+        learning_rate (`float`):
+            The initial learning rate. It is a python float number.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        return max(
+            0.0,
+            float(num_training_steps - current_step) /
+            float(max(1, num_training_steps - num_warmup_steps)))
+
+    return LambdaDecay(learning_rate, lr_lambda, last_epoch)
+
+
+def get_cosine_schedule_with_warmup(learning_rate: float,
+                                    num_warmup_steps: int,
+                                    num_training_steps: int,
+                                    num_cycles: float = 0.5,
+                                    last_epoch: int = -1):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer and 0, after a warmup period during which it increases linearly between 0 and the
+    initial lr set in the optimizer.
+    Args:
+        learning_rate (`float`):
+            The initial learning rate. It is a python float number.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    Return:
+        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(
+            max(1, num_training_steps - num_warmup_steps))
+        return max(
+            0.0, 0.5 *
+            (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+    return LambdaDecay(learning_rate, lr_lambda, last_epoch)
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+    SchedulerType.CONSTANT: get_constant_schedule,
+    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+}
+
+
+def get_scheduler(
+        name: Union[str, SchedulerType],
+        learning_rate: float,
+        num_warmup_steps: Optional[int] = None,
+        num_training_steps: Optional[int] = None,
+):
+    """
+    Unified API to get any scheduler from its name.
+    Args:
+        name (`str` or `SchedulerType`):
+            The name of the scheduler to use.
+        learning_rate (`float`):
+            The initial learning rate. It is a python float number.
+        num_warmup_steps (`int`, *optional*):
+            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+            optional); the function will raise an error if it's unset and the scheduler type requires it.
+        num_training_steps (`int`, *optional*):
+            The number of training steps to do. This is not required by all schedulers (hence the argument being
+            optional); the function will raise an error if it's unset and the scheduler type requires it.
+    """
+    name = SchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    if name == SchedulerType.CONSTANT:
+        return schedule_func(learning_rate)
+
+    # All other schedulers require `num_warmup_steps`
+    if num_warmup_steps is None:
+        raise ValueError(
+            f"{name} requires `num_warmup_steps`, please provide that argument."
+        )
+
+    if name == SchedulerType.CONSTANT_WITH_WARMUP:
+        return schedule_func(learning_rate, num_warmup_steps=num_warmup_steps)
+
+    # All other schedulers require `num_training_steps`
+    if num_training_steps is None:
+        raise ValueError(
+            f"{name} requires `num_training_steps`, please provide that argument."
+        )
+
+    return schedule_func(learning_rate,
+                         num_warmup_steps=num_warmup_steps,
+                         num_training_steps=num_training_steps)
+
+
 def _secs2timedelta(secs):
     """
     convert seconds to hh:mm:ss.msec, msecs rounded to 2 decimals
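
A minimal usage sketch of the new `get_scheduler` API (the import path is an assumption; the function, its arguments, and their semantics are taken from the diff above). The returned object is a `paddle.optimizer.lr.LambdaDecay`, so it can be stepped and inspected like any Paddle LR scheduler:

# Hedged usage sketch of get_scheduler; the import path is assumed.
from paddlenlp.trainer.trainer_utils import get_scheduler

num_training_steps = 1000
scheduler = get_scheduler(
    "cosine",                    # or "linear", "constant", "constant_with_warmup"
    learning_rate=5e-5,
    num_warmup_steps=100,
    num_training_steps=num_training_steps,
)

# Linear warmup for the first 100 steps, then half-cosine decay towards 0.
for step in range(num_training_steps):
    if step % 250 == 0:
        print(step, scheduler.get_lr())
    scheduler.step()

Passing an unsupported name raises a `ValueError` from `SchedulerType(name)`, and omitting `num_warmup_steps` or `num_training_steps` when the chosen schedule needs them raises the explicit errors shown in the diff.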

paddlenlp/trainer/training_args.py

Lines changed: 4 additions & 1 deletion
@@ -322,7 +322,10 @@ class TrainingArguments:
     )
     lr_scheduler_type: str = field(
         default="linear",
-        metadata={"help": "The scheduler type to use."},
+        metadata={
+            "help":
+            "The scheduler type to use. Supports linear, cosine, constant, constant_with_warmup."
+        },
     )
     warmup_ratio: float = field(
         default=0.0,
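
Downstream, whichever `lr_scheduler_type` is configured, the Trainer's `create_scheduler` produces a Paddle LR scheduler that is handed to the optimizer as its `learning_rate`. A rough sketch of that wiring outside the Trainer (the toy model and import paths are assumptions; only `get_scheduler` comes from this commit):

# Rough sketch: attaching a scheduler from get_scheduler to a Paddle optimizer.
import paddle
from paddlenlp.trainer.trainer_utils import get_scheduler

model = paddle.nn.Linear(10, 2)                      # toy model (assumption)
scheduler = get_scheduler("constant_with_warmup",
                          learning_rate=3e-5,
                          num_warmup_steps=50)
# A Paddle LRScheduler instance can be passed directly as the learning_rate.
optimizer = paddle.optimizer.AdamW(learning_rate=scheduler,
                                   parameters=model.parameters())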
