28 | 28 | from typing import Dict, NamedTuple, Optional, Tuple, Union
29 | 29 |
30 | 30 | import numpy as np
| 31 | +from paddle.optimizer.lr import LambdaDecay
31 | 32 |
32 | 33 | __all__ = [
33 | 34 |     "TrainOutput",
38 | 39 |     "set_seed",
39 | 40 |     "speed_metrics",
40 | 41 |     "get_last_checkpoint",
| 42 | +    "get_scheduler",
41 | 43 | ]
42 | 44 |
43 | 45 |
@@ -178,12 +180,170 @@ def speed_metrics(split, start_time, num_samples=None, num_steps=None):
178 | 180 | class SchedulerType(ExplicitEnum):
179 | 181 |     LINEAR = "linear"
180 | 182 |     COSINE = "cosine"
181 | | -    COSINE_WITH_RESTARTS = "cosine_with_restarts"
182 | | -    POLYNOMIAL = "polynomial"
183 | 183 |     CONSTANT = "constant"
184 | 184 |     CONSTANT_WITH_WARMUP = "constant_with_warmup"
185 | 185 |
186 | 186 |
| 187 | +def get_constant_schedule(learning_rate: float, last_epoch: int = -1):
| 188 | +    """
| 189 | +    Create a schedule with a constant learning rate, using the learning rate set in the optimizer.
| 190 | +    Args:
| 191 | +        learning_rate (`float`):
| 192 | +            The initial learning rate. It is a Python float number.
| 193 | +        last_epoch (`int`, *optional*, defaults to -1):
| 194 | +            The index of the last epoch when resuming training.
| 195 | +    Return:
| 196 | +        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
| 197 | +    """
| 198 | +    return LambdaDecay(learning_rate, lambda _: 1, last_epoch=last_epoch)
| 199 | +
| 200 | +
| 201 | +def get_constant_schedule_with_warmup(learning_rate: float,
| 202 | +                                      num_warmup_steps: int,
| 203 | +                                      last_epoch: int = -1):
| 204 | +    """
| 205 | +    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
| 206 | +    increases linearly between 0 and the initial lr set in the optimizer.
| 207 | +    Args:
| 208 | +        learning_rate (`float`):
| 209 | +            The initial learning rate. It is a Python float number.
| 210 | +        num_warmup_steps (`int`):
| 211 | +            The number of steps for the warmup phase.
| 212 | +        last_epoch (`int`, *optional*, defaults to -1):
| 213 | +            The index of the last epoch when resuming training.
| 214 | +    Return:
| 215 | +        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
| 216 | +    """
| 217 | +
| 218 | +    def lr_lambda(current_step: int):
| 219 | +        if current_step < num_warmup_steps:
| 220 | +            return float(current_step) / float(max(1.0, num_warmup_steps))
| 221 | +        return 1.0
| 222 | +
| 223 | +    return LambdaDecay(learning_rate, lr_lambda, last_epoch=last_epoch)
| 224 | +
| 225 | +
| 226 | +def get_linear_schedule_with_warmup(learning_rate: float,
| 227 | +                                    num_warmup_steps,
| 228 | +                                    num_training_steps,
| 229 | +                                    last_epoch=-1):
| 230 | +    """
| 231 | +    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
| 232 | +    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
| 233 | +    Args:
| 234 | +        learning_rate (`float`):
| 235 | +            The initial learning rate. It is a Python float number.
| 236 | +        num_warmup_steps (`int`):
| 237 | +            The number of steps for the warmup phase.
| 238 | +        num_training_steps (`int`):
| 239 | +            The total number of training steps.
| 240 | +        last_epoch (`int`, *optional*, defaults to -1):
| 241 | +            The index of the last epoch when resuming training.
| 242 | +    Return:
| 243 | +        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
| 244 | +    """
| 245 | +
| 246 | +    def lr_lambda(current_step: int):
| 247 | +        if current_step < num_warmup_steps:
| 248 | +            return float(current_step) / float(max(1, num_warmup_steps))
| 249 | +        return max(
| 250 | +            0.0,
| 251 | +            float(num_training_steps - current_step) /
| 252 | +            float(max(1, num_training_steps - num_warmup_steps)))
| 253 | +
| 254 | +    return LambdaDecay(learning_rate, lr_lambda, last_epoch)
| 255 | +
| 256 | +
| 257 | +def get_cosine_schedule_with_warmup(learning_rate: float,
| 258 | +                                    num_warmup_steps: int,
| 259 | +                                    num_training_steps: int,
| 260 | +                                    num_cycles: float = 0.5,
| 261 | +                                    last_epoch: int = -1):
| 262 | +    """
| 263 | +    Create a schedule with a learning rate that decreases following the values of the cosine function between the
| 264 | +    initial lr set in the optimizer and 0, after a warmup period during which it increases linearly between 0 and the
| 265 | +    initial lr set in the optimizer.
| 266 | +    Args:
| 267 | +        learning_rate (`float`):
| 268 | +            The initial learning rate. It is a Python float number.
| 269 | +        num_warmup_steps (`int`):
| 270 | +            The number of steps for the warmup phase.
| 271 | +        num_training_steps (`int`):
| 272 | +            The total number of training steps.
| 273 | +        num_cycles (`float`, *optional*, defaults to 0.5):
| 274 | +            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
| 275 | +            following a half-cosine).
| 276 | +        last_epoch (`int`, *optional*, defaults to -1):
| 277 | +            The index of the last epoch when resuming training.
| 278 | +    Return:
| 279 | +        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
| 280 | +    """
| 281 | +
| 282 | +    def lr_lambda(current_step):
| 283 | +        if current_step < num_warmup_steps:
| 284 | +            return float(current_step) / float(max(1, num_warmup_steps))
| 285 | +        progress = float(current_step - num_warmup_steps) / float(
| 286 | +            max(1, num_training_steps - num_warmup_steps))
| 287 | +        return max(
| 288 | +            0.0, 0.5 *
| 289 | +            (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
| 290 | +
| 291 | +    return LambdaDecay(learning_rate, lr_lambda, last_epoch)
| 292 | +
| 293 | +
| 294 | +TYPE_TO_SCHEDULER_FUNCTION = {
| 295 | +    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
| 296 | +    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
| 297 | +    SchedulerType.CONSTANT: get_constant_schedule,
| 298 | +    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
| 299 | +}
| 300 | +
| 301 | +
| 302 | +def get_scheduler(
| 303 | +    name: Union[str, SchedulerType],
| 304 | +    learning_rate: float,
| 305 | +    num_warmup_steps: Optional[int] = None,
| 306 | +    num_training_steps: Optional[int] = None,
| 307 | +):
| 308 | +    """
| 309 | +    Unified API to get any scheduler from its name.
| 310 | +    Args:
| 311 | +        name (`str` or `SchedulerType`):
| 312 | +            The name of the scheduler to use.
| 313 | +        learning_rate (`float`):
| 314 | +            The initial learning rate. It is a Python float number.
| 315 | +        num_warmup_steps (`int`, *optional*):
| 316 | +            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
| 317 | +            optional); the function will raise an error if it's unset and the scheduler type requires it.
| 318 | +        num_training_steps (`int`, *optional*):
| 319 | +            The number of training steps to do. This is not required by all schedulers (hence the argument being
| 320 | +            optional); the function will raise an error if it's unset and the scheduler type requires it.
| 321 | +    """
| 322 | +    name = SchedulerType(name)
| 323 | +    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
| 324 | +    if name == SchedulerType.CONSTANT:
| 325 | +        return schedule_func(learning_rate)
| 326 | +
| 327 | +    # All other schedulers require `num_warmup_steps`
| 328 | +    if num_warmup_steps is None:
| 329 | +        raise ValueError(
| 330 | +            f"{name} requires `num_warmup_steps`, please provide that argument."
| 331 | +        )
| 332 | +
| 333 | +    if name == SchedulerType.CONSTANT_WITH_WARMUP:
| 334 | +        return schedule_func(learning_rate, num_warmup_steps=num_warmup_steps)
| 335 | +
| 336 | +    # All other schedulers require `num_training_steps`
| 337 | +    if num_training_steps is None:
| 338 | +        raise ValueError(
| 339 | +            f"{name} requires `num_training_steps`, please provide that argument."
| 340 | +        )
| 341 | +
| 342 | +    return schedule_func(learning_rate,
| 343 | +                         num_warmup_steps=num_warmup_steps,
| 344 | +                         num_training_steps=num_training_steps)
| 345 | +
| 346 | +
187 | 347 | def _secs2timedelta(secs):
188 | 348 | """
189 | 349 | convert seconds to hh:mm:ss.msec, msecs rounded to 2 decimals
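For reference, once warmup ends the cosine schedule added above scales the base learning rate by 0.5 * (1 + cos(pi * num_cycles * 2 * progress)), where progress runs from 0 to 1 over the remaining steps; with the default num_cycles=0.5 this is a single half-cosine from the full learning rate down to 0. Below is a minimal standalone check of that multiplier in plain Python (the helper name is illustrative and not part of this diff; it just mirrors the lr_lambda in get_cosine_schedule_with_warmup).

```python
import math

def cosine_multiplier(current_step, num_warmup_steps, num_training_steps, num_cycles=0.5):
    # Mirrors the lr_lambda used by get_cosine_schedule_with_warmup above.
    if current_step < num_warmup_steps:
        return current_step / max(1, num_warmup_steps)
    progress = (current_step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

# With 100 warmup steps and 1000 total steps:
print(cosine_multiplier(0, 100, 1000))     # 0.0  (start of warmup)
print(cosine_multiplier(100, 100, 1000))   # 1.0  (warmup finished, full learning rate)
print(cosine_multiplier(550, 100, 1000))   # 0.5  (halfway through the decay)
print(cosine_multiplier(1000, 100, 1000))  # 0.0  (end of training)
```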
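A minimal usage sketch of the new get_scheduler API follows, assuming it is imported from the trainer-utils module this diff touches; the import path, toy model, and step counts are illustrative, not taken from the diff.

```python
import paddle

# Assumed import path for the module modified above; adjust to your tree.
from paddlenlp.trainer.trainer_utils import get_scheduler

model = paddle.nn.Linear(4, 2)  # toy model standing in for a real network

# Warm up for 100 steps, then decay linearly to 0 over 1000 total steps.
lr_scheduler = get_scheduler(
    "linear",
    learning_rate=5e-5,
    num_warmup_steps=100,
    num_training_steps=1000,
)

# Paddle optimizers accept an LRScheduler instance as `learning_rate`.
optimizer = paddle.optimizer.AdamW(learning_rate=lr_scheduler,
                                   parameters=model.parameters())

for step in range(1000):
    # ... forward, backward, optimizer.step(), optimizer.clear_grad() ...
    lr_scheduler.step()  # advance the schedule once per optimization step
```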