From e5bb59b89ce03bb912ad95414652f7602efae21a Mon Sep 17 00:00:00 2001 From: SyedaAnshrahGillani Date: Wed, 23 Jul 2025 17:49:06 +0500 Subject: [PATCH] Add comprehensive input validation for training hyperparameters - Add extensive validation for training parameters like learning_rate, batch_size, epochs, etc. - Validate that numeric parameters are within valid ranges (e.g., positive values, betas in [0,1)) - Add reasonable bounds checking for resolution (64-4096 pixels) - Include helpful warnings for potentially problematic parameter combinations - Improve user experience by catching invalid configurations early with clear error messages This prevents runtime errors and training failures caused by invalid hyperparameters, making the training script more robust and user-friendly. --- examples/controlnet/train_controlnet_flux.py | 100 +++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index cde1c4d0be3d..4c54f149ef21 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -686,6 +686,106 @@ def parse_args(input_args=None): "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." ) + # Additional comprehensive parameter validation + if args.learning_rate <= 0: + raise ValueError("`--learning_rate` must be positive") + + if args.train_batch_size <= 0: + raise ValueError("`--train_batch_size` must be positive") + + if args.num_train_epochs <= 0: + raise ValueError("`--num_train_epochs` must be positive") + + if args.gradient_accumulation_steps <= 0: + raise ValueError("`--gradient_accumulation_steps` must be positive") + + if args.max_train_steps is not None and args.max_train_steps <= 0: + raise ValueError("`--max_train_steps` must be positive when specified") + + if args.checkpointing_steps <= 0: + raise ValueError("`--checkpointing_steps` must be positive") + + if args.validation_steps <= 0: + raise ValueError("`--validation_steps` must be positive") + + if args.num_validation_images <= 0: + raise ValueError("`--num_validation_images` must be positive") + + if args.lr_warmup_steps < 0: + raise ValueError("`--lr_warmup_steps` must be non-negative") + + if args.lr_num_cycles <= 0: + raise ValueError("`--lr_num_cycles` must be positive") + + if args.lr_power <= 0: + raise ValueError("`--lr_power` must be positive") + + if args.dataloader_num_workers < 0: + raise ValueError("`--dataloader_num_workers` must be non-negative") + + if not (0.0 <= args.adam_beta1 < 1.0): + raise ValueError("`--adam_beta1` must be in the range [0.0, 1.0)") + + if not (0.0 <= args.adam_beta2 < 1.0): + raise ValueError("`--adam_beta2` must be in the range [0.0, 1.0)") + + if args.adam_weight_decay < 0: + raise ValueError("`--adam_weight_decay` must be non-negative") + + if args.adam_epsilon <= 0: + raise ValueError("`--adam_epsilon` must be positive") + + if args.max_grad_norm <= 0: + raise ValueError("`--max_grad_norm` must be positive") + + if args.max_train_samples is not None and args.max_train_samples <= 0: + raise ValueError("`--max_train_samples` must be positive when specified") + + if args.num_double_layers <= 0: + raise ValueError("`--num_double_layers` must be positive") + + if args.num_single_layers <= 0: + raise ValueError("`--num_single_layers` must be positive") + + if args.guidance_scale < 0: + raise ValueError("`--guidance_scale` must be non-negative") + + if args.logit_std <= 0: + raise ValueError("`--logit_std` must be positive") + + if args.mode_scale <= 0: + raise ValueError("`--mode_scale` must be positive") + + if args.checkpoints_total_limit is not None and args.checkpoints_total_limit <= 0: + raise ValueError("`--checkpoints_total_limit` must be positive when specified") + + # Validate resolution is reasonable (not too small or absurdly large) + if args.resolution < 64: + raise ValueError("`--resolution` must be at least 64 pixels") + + if args.resolution > 4096: + raise ValueError("`--resolution` should not exceed 4096 pixels for memory efficiency") + + # Validate crop coordinates are non-negative + if args.crops_coords_top_left_h < 0: + raise ValueError("`--crops_coords_top_left_h` must be non-negative") + + if args.crops_coords_top_left_w < 0: + raise ValueError("`--crops_coords_top_left_w` must be non-negative") + + # Warn about potentially problematic combinations + if args.gradient_accumulation_steps > 1 and args.train_batch_size > 32: + logger.warning( + f"Large batch size ({args.train_batch_size}) with gradient accumulation ({args.gradient_accumulation_steps}) " + "may cause memory issues. Consider reducing batch size or gradient accumulation steps." + ) + + if args.learning_rate > 1e-2: + logger.warning( + f"Learning rate ({args.learning_rate}) is quite high. This may cause training instability. " + "Consider using a lower learning rate (e.g., 1e-4 to 1e-5)." + ) + return args