[Example] Add STAFNet Model for Air Quality Prediction #1070

Open
wants to merge 57 commits into
base: develop
Changes from 11 commits
Commits
f79a3f9
Add files via upload
dylan-yin Feb 7, 2025
68b23d1
Add files via upload
dylan-yin Feb 7, 2025
d9d2b54
Update __init__.py
dylan-yin Feb 7, 2025
bfa3e69
Add files via upload
dylan-yin Feb 7, 2025
2d9dc85
Merge branch 'PaddlePaddle:develop' into dev_model
dylan-yin Feb 7, 2025
57dc7c2
Update __init__.py
dylan-yin Feb 7, 2025
fa1cdee
Merge branch 'PaddlePaddle:develop' into dev_model
dylan-yin Feb 11, 2025
ab1ae03
Update demo.py
dylan-yin Feb 11, 2025
d257a49
Update stafnet.yaml
dylan-yin Feb 11, 2025
b43c7f5
Update stafnet.py
dylan-yin Feb 11, 2025
757477a
Update stafnet_dataset.py
dylan-yin Feb 11, 2025
2b46497
Merge branch 'PaddlePaddle:develop' into dev_model
dylan-yin Feb 14, 2025
711cd36
Update stafnet.yaml
dylan-yin Feb 14, 2025
a79ad1d
Update demo.py
dylan-yin Feb 14, 2025
af96434
Update stafnet.yaml
dylan-yin Feb 14, 2025
86a9c0b
Update demo.py
dylan-yin Feb 14, 2025
e2d8b60
Merge branch 'PaddlePaddle:develop' into dev_model
dylan-yin May 12, 2025
66eeb9b
Merge branch 'PaddlePaddle:develop' into dev_model
dylan-yin Jun 5, 2025
792796a
Delete examples/demo directory
dylan-yin Jun 5, 2025
db2f093
Add files via upload
dylan-yin Jun 5, 2025
862cbfb
Add files via upload
dylan-yin Jun 5, 2025
726a026
Add files via upload
dylan-yin Jun 5, 2025
5ccc908
Add files via upload
dylan-yin Jun 5, 2025
0319c71
Add files via upload
dylan-yin Jun 6, 2025
ff226a7
Add files via upload
dylan-yin Jun 6, 2025
a5cda23
Add files via upload
dylan-yin Jun 6, 2025
79e5e6b
Add files via upload
dylan-yin Jun 6, 2025
d7e6c8a
Add files via upload
dylan-yin Jun 6, 2025
cc71ffc
Add files via upload
dylan-yin Jun 6, 2025
56fe67e
Add files via upload
dylan-yin Jun 6, 2025
296e58c
Add files via upload
dylan-yin Jun 6, 2025
99b1d1b
Add files via upload
dylan-yin Jun 6, 2025
b6f55a4
Add files via upload
dylan-yin Jun 6, 2025
3332954
Add files via upload
dylan-yin Jun 7, 2025
f5557b9
Add files via upload
dylan-yin Jun 7, 2025
878cc52
Add files via upload
dylan-yin Jun 7, 2025
c2e88bd
Add files via upload
dylan-yin Jun 7, 2025
bc748cd
Update README.md
dylan-yin Jun 10, 2025
c7d4426
Add files via upload
dylan-yin Jun 10, 2025
a7b2eb1
Update stafnet.md
dylan-yin Jun 11, 2025
c025ab2
Update stafnet.md
dylan-yin Jun 12, 2025
ed22102
Merge branch 'develop' into dev_model
dylan-yin Jun 12, 2025
8d2433d
Update docs/zh/examples/stafnet.md
HydrogenSulfate Jun 13, 2025
3a87833
Update examples/stafnet/stafnet.py
HydrogenSulfate Jun 13, 2025
88c2957
Update examples/stafnet/stafnet.py
HydrogenSulfate Jun 13, 2025
ec92f69
Update stafnet.py
dylan-yin Jun 16, 2025
454729f
Update dataset.md
dylan-yin Jun 16, 2025
b1b4ad3
Update arch.md
dylan-yin Jun 16, 2025
9c658dc
Update arch.md
dylan-yin Jun 16, 2025
9195278
Update arch.md
dylan-yin Jun 16, 2025
3e35a0d
Merge branch 'develop' into dev_model
dylan-yin Jun 20, 2025
803eb28
Update stafnet_dataset.py
dylan-yin Jun 23, 2025
8599724
Update stafnet.py
dylan-yin Jun 23, 2025
59cd4e0
Update stafnet_dataset.py
dylan-yin Jun 23, 2025
84f91bf
Merge branch 'PaddlePaddle:develop' into dev_model
dylan-yin Jun 23, 2025
5039b14
style: Apply automatic formatting via pre-commit
dylan-yin Jun 23, 2025
d38fbaa
Add files via upload
dylan-yin Jun 24, 2025
85 changes: 85 additions & 0 deletions examples/demo/conf/stafnet.yaml
Collaborator

I suggest formatting this file with the VS Code YAML extension, or running pre-commit on it before committing: https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/development/#1

pre-commit run --files xxx.yaml

@@ -0,0 +1,85 @@
defaults:
  - ppsci_default
  - TRAIN: train_default
  - TRAIN/ema: ema_default
  - TRAIN/swa: swa_default
  - EVAL: eval_default
  - INFER: infer_default
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
  - _self_
hydra:
Collaborator

Please add the following fields at the top of the config file:

defaults:
  - ppsci_default
  - TRAIN: train_default
  - TRAIN/ema: ema_default
  - TRAIN/swa: swa_default
  - EVAL: eval_default
  - INFER: infer_default
  - hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
  - _self_

  run:
    # dynamic output directory according to running time and override name
    dir: outputs_chip_heat/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname}
  job:
    name: ${mode} # name of logfile
    chdir: false # keep current working directory unchanged
  callbacks:
    init_callback:
      _target_: ppsci.utils.callbacks.InitCallback
  sweep:
    # output directory for multirun
    dir: ${hydra.run.dir}
    subdir: ./

# general settings
mode: train # running mode: train/eval
seed: 42
output_dir: ${hydra:run.dir}
log_freq: 20
# dataset setting
STAFNet_DATA_PATH: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" #
Collaborator

  1. Could this path be changed to a relative path, for example ./dataset/train_data.pkl? The same applies to the other path fields: please switch them to relative paths and remove the username.
  2. Should STAFNet_DATA_PATH be placed under the DATASET field?
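
A minimal sketch of how hard-coded absolute paths like the one flagged here could be caught before committing. It assumes the config has already been loaded into a plain nested dict and that paths are POSIX-style; `find_absolute_paths` is a hypothetical helper, not part of ppsci or Hydra.

```python
from pathlib import PurePosixPath


def find_absolute_paths(cfg, prefix=""):
    """Return dotted keys whose string values are absolute POSIX paths."""
    hits = []
    for key, value in cfg.items():
        dotted = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse into nested sections such as DATASET or EVAL.
            hits.extend(find_absolute_paths(value, dotted))
        elif isinstance(value, str) and PurePosixPath(value).is_absolute():
            hits.append(dotted)
    return hits


# Illustrative config: one absolute path from the PR, two relative ones
# in the form the reviewer suggests.
example_cfg = {
    "STAFNet_DATA_PATH": "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl",
    "DATASET": {"data_dir": "./dataset/train_data.pkl"},
    "EVAL": {"eval_data_path": "./dataset/val_data.pkl"},
}
flagged = find_absolute_paths(example_cfg)
print(flagged)
```

Only the key holding the machine-specific path is reported; the relative paths pass the check.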

DATASET:
  label_keys: ["label"]
  data_dir: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl"
Collaborator

  1. Why is data_dir the path of a specific file rather than the path of a directory?
  2. Does this path duplicate STAFNet_DATA_PATH?
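
One way to read the first question is that `data_dir` should name a directory and the file name should be derived from it. The sketch below illustrates that split under stated assumptions: `split_file` is a hypothetical helper, and the `<split>_data.pkl` naming is inferred from the `train_data.pkl` and `val_data.pkl` files that appear in this config.

```python
from pathlib import PurePosixPath


def split_file(data_dir, split):
    """Build the path of one split file inside a dataset directory."""
    # Naming assumption: files are called "<split>_data.pkl".
    return PurePosixPath(data_dir) / f"{split}_data.pkl"


train_path = split_file("./dataset", "train")
val_path = split_file("./dataset", "val")
print(train_path)
print(val_path)
```

With this layout a single directory entry in the config would cover both the training and the validation file.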



MODEL:
  input_keys: ["aq_train_data","mete_train_data",]
Collaborator

Suggested change
-  input_keys: ["aq_train_data","mete_train_data",]
+  input_keys: [aq_train_data, mete_train_data]

  output_keys: ["label"]
Collaborator

Suggested change
-  output_keys: ["label"]
+  output_keys: [label]

  output_attention: True
  seq_len: 72
  pred_len: 48
  aq_gat_node_features: 7
  aq_gat_node_num: 35
  mete_gat_node_features: 7
  mete_gat_node_num: 18
  gat_hidden_dim: 32
  gat_edge_dim: 3
  e_layers: 1
  enc_in: 7
  dec_in: 7
  c_out: 7
  d_model: 16
  embed: "fixed"
  freq: "t"
  dropout: 0.05
  factor: 3
  n_heads: 4
  d_ff: 32
  num_kernels: 6
  top_k: 4

# training settings
TRAIN:
  epochs: 100
  iters_per_epoch: 400
  save_freq: 10
  eval_during_train: true
  eval_freq: 10
  batch_size: 1
  lr_scheduler:
    epochs: ${TRAIN.epochs}
    iters_per_epoch: ${TRAIN.iters_per_epoch}
    learning_rate: 0.001
    step_size: 10
    gamma: 0.9
  pretrained_model_path: null
  checkpoint_path: null

EVAL:
  eval_data_path: "/data6/home/yinhang2021/dataset/chongqing_1921/val_data.pkl"
Collaborator

Suggested change
-  eval_data_path: "/data6/home/yinhang2021/dataset/chongqing_1921/val_data.pkl"
+  eval_data_path: ./dataset/val_data.pkl

  pretrained_model_path: null
  compute_metric_by_batch: false
  eval_with_no_grad: true
  batch_size: 1
153 changes: 153 additions & 0 deletions examples/demo/demo.py
@@ -0,0 +1,153 @@
import ppsci
from ppsci.utils import logger
from omegaconf import DictConfig
import hydra
import paddle
from ppsci.data.dataset.stafnet_dataset import gat_lstmcollate_fn
import multiprocessing

def train(cfg: DictConfig):
    # set model
    model = ppsci.arch.STAFNet(**cfg.MODEL)
    train_dataloader_cfg = {
        "dataset": {
            "name": "STAFNetDataset",
            "file_path": cfg.DATASET.data_dir,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.MODEL.output_keys,
            "seq_len": cfg.MODEL.seq_len,
            "pred_len": cfg.MODEL.pred_len,

Collaborator

Remove the extra blank line.

        },
        "batch_size": cfg.TRAIN.batch_size,
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
        "collate_fn": gat_lstmcollate_fn,
    }
    eval_dataloader_cfg= {
        "dataset": {
            "name": "STAFNetDataset",
            "file_path": cfg.EVAL.eval_data_path,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.MODEL.output_keys,
            "seq_len": cfg.MODEL.seq_len,
            "pred_len": cfg.MODEL.pred_len,
        },
        "batch_size": cfg.TRAIN.batch_size,
Collaborator

Should this be EVAL?

        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
Collaborator

Can this "sampler" field be removed? Evaluation should not need shuffling.

        "collate_fn": gat_lstmcollate_fn,
    }
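
The two review remarks on this dictionary (take the batch size from EVAL, and drop the shuffling sampler) can be sketched together as follows. This is an illustration on a plain nested dict, not the PR's final code: `build_eval_dataloader_cfg` is a hypothetical helper, and whether the dataloader builder accepts a config without a "sampler" entry is assumed from the reviewer's suggestion rather than verified here.

```python
def build_eval_dataloader_cfg(cfg, collate_fn=None):
    """Assemble an evaluation dataloader config from a nested dict config."""
    dataloader_cfg = {
        "dataset": {
            "name": "STAFNetDataset",
            "file_path": cfg["EVAL"]["eval_data_path"],
            "input_keys": cfg["MODEL"]["input_keys"],
            "label_keys": cfg["MODEL"]["output_keys"],
            "seq_len": cfg["MODEL"]["seq_len"],
            "pred_len": cfg["MODEL"]["pred_len"],
        },
        # Batch size is read from EVAL, not TRAIN.
        "batch_size": cfg["EVAL"]["batch_size"],
        # No "sampler" entry: evaluation does not need shuffling.
    }
    if collate_fn is not None:
        dataloader_cfg["collate_fn"] = collate_fn
    return dataloader_cfg


# Illustrative values; the batch sizes differ on purpose to show which one is used.
example_cfg = {
    "MODEL": {
        "input_keys": ["aq_train_data", "mete_train_data"],
        "output_keys": ["label"],
        "seq_len": 72,
        "pred_len": 48,
    },
    "TRAIN": {"batch_size": 1},
    "EVAL": {"eval_data_path": "./dataset/val_data.pkl", "batch_size": 4},
}
eval_cfg = build_eval_dataloader_cfg(example_cfg)
print(eval_cfg["batch_size"], "sampler" in eval_cfg)
```

Building the dictionary in one helper would also remove the duplication between `train` and `evaluate`, which currently construct the same evaluation config twice.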

    sup_constraint = ppsci.constraint.SupervisedConstraint(
        train_dataloader_cfg,
        loss=ppsci.loss.MSELoss("mean"),
        name="STAFNet_Sup",
    )
    constraint = {sup_constraint.name: sup_constraint}
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        loss=ppsci.loss.MSELoss("mean"),
        metric={"MSE": ppsci.metric.MSE()},
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # set optimizer
    lr_scheduler = ppsci.optimizer.lr_scheduler.Step(**cfg.TRAIN.lr_scheduler)()
    LEARNING_RATE = cfg.TRAIN.lr_scheduler.learning_rate
    optimizer = ppsci.optimizer.Adam(LEARNING_RATE)(model)
Collaborator

In Paddle, if the learning rate is driven by an lr_scheduler, the optimizer's learning_rate must be set to that lr_scheduler rather than to the initial learning rate.
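
To see why this matters, the sketch below compares a step-decayed learning rate with the constant value the optimizer currently receives. It is plain Python arithmetic under the usual step-decay rule (multiply by `gamma` every `step_size` epochs), using the values from `TRAIN.lr_scheduler` in stafnet.yaml; `step_lr` is a hypothetical helper, and the real scheduler may step per iteration rather than per epoch.

```python
def step_lr(base_lr, gamma, step_size, epoch):
    """Step decay: multiply the learning rate by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


# Values taken from TRAIN.lr_scheduler in stafnet.yaml.
BASE_LR, GAMMA, STEP_SIZE = 0.001, 0.9, 10

for epoch in (0, 10, 50, 99):
    scheduled = step_lr(BASE_LR, GAMMA, STEP_SIZE, epoch)
    # "constant" is what the optimizer uses when it is given LEARNING_RATE
    # instead of the scheduler object.
    print(f"epoch {epoch:3d}: scheduled={scheduled:.6f} constant={BASE_LR:.6f}")
```

With the optimizer built from the float, the scheduler is created but never applied. The fix the reviewer describes is to pass the scheduler object itself as the optimizer's learning rate, which in PaddleScience examples is commonly written as `ppsci.optimizer.Adam(lr_scheduler)(model)`.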

    output_dir = cfg.output_dir
Collaborator

Suggested change
-    output_dir = cfg.output_dir

    ITERS_PER_EPOCH = len(sup_constraint.data_loader)

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        constraint,
        output_dir,
Collaborator

Suggested change
-        output_dir,
+        cfg.output_dir,

        optimizer,
        lr_scheduler,
        cfg.TRAIN.epochs,
        ITERS_PER_EPOCH,
        eval_during_train=cfg.TRAIN.eval_during_train,
        seed=cfg.seed,
Collaborator

Suggested change
-        seed=cfg.seed,

        validator=validator,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )

    # train model
    solver.train()

def evaluate(cfg: DictConfig):
    """
    Validate after training an epoch

    :param epoch: Integer, current training epoch.
    :return: A log that contains information about validation
    """
Collaborator

Remove this comment.

Collaborator

Suggested change
-    """
-    Validate after training an epoch
-    :param epoch: Integer, current training epoch.
-    :return: A log that contains information about validation
-    """

    model = ppsci.arch.STAFNet(**cfg.MODEL)
    eval_dataloader_cfg= {
        "dataset": {
            "name": "STAFNetDataset",
            "file_path": cfg.EVAL.eval_data_path,
            "input_keys": cfg.MODEL.input_keys,
            "label_keys": cfg.MODEL.output_keys,
            "seq_len": cfg.MODEL.seq_len,
            "pred_len": cfg.MODEL.pred_len,
        },
        "batch_size": cfg.TRAIN.batch_size,
        "sampler": {
            "name": "BatchSampler",
            "drop_last": False,
            "shuffle": True,
        },
Collaborator

Suggested change
-        "sampler": {
-            "name": "BatchSampler",
-            "drop_last": False,
-            "shuffle": True,
-        },

        "collate_fn": gat_lstmcollate_fn,
    }
    sup_validator = ppsci.validate.SupervisedValidator(
        eval_dataloader_cfg,
        loss=ppsci.loss.MSELoss("mean"),
        metric={"MSE": ppsci.metric.MSE()},
        name="Sup_Validator",
    )
    validator = {sup_validator.name: sup_validator}

    # initialize solver
    solver = ppsci.solver.Solver(
        model,
        validator=validator,
        cfg=cfg,
        pretrained_model_path=cfg.EVAL.pretrained_model_path,
        compute_metric_by_batch=cfg.EVAL.compute_metric_by_batch,
        eval_with_no_grad=cfg.EVAL.eval_with_no_grad,
    )

    # evaluate model
    solver.eval()


@hydra.main(version_base=None, config_path="./conf", config_name="stafnet.yaml")
def main(cfg: DictConfig):
    if cfg.mode == "train":
        train(cfg)
    elif cfg.mode == "eval":
        evaluate(cfg)
    else:
        raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'")

if __name__ == "__main__":
    # set random seed for reproducibility
    ppsci.utils.misc.set_random_seed(42)
    # set output directory
    OUTPUT_DIR = "./output_example"
    # initialize logger
    logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")
Collaborator

This can be removed; output_dir is created automatically by ppsci.utils.callbacks.InitCallback:

logger.init_logger(
    "ppsci",
    osp.join(full_cfg.output_dir, f"{full_cfg.mode}.log")
    if full_cfg.output_dir and full_cfg.mode not in ["export", "infer"]
    else None,
    full_cfg.log_level,
)
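
The path selection in the snippet quoted above can be read in isolation as the small function below. It is a standalone restatement for illustration only; `resolve_log_file` is a hypothetical name and is not how ppsci exposes this logic.

```python
import os.path as osp


def resolve_log_file(output_dir, mode):
    """Return the log file path, or None when no log file should be written."""
    # Mirrors the condition in the quoted snippet: a file is used only when an
    # output directory is set and the mode is neither "export" nor "infer".
    if output_dir and mode not in ["export", "infer"]:
        return osp.join(output_dir, f"{mode}.log")
    return None


print(resolve_log_file("outputs", "train"))
print(resolve_log_file("outputs", "infer"))
```

Because the log file name follows `cfg.mode`, the hard-coded `train.log` under `./output_example` in the PR would be redundant in both train and eval runs.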

Collaborator

Suggested change
-    # set random seed for reproducibility
-    ppsci.utils.misc.set_random_seed(42)
-    # set output directory
-    OUTPUT_DIR = "./output_example"
-    # initialize logger
-    logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")

    multiprocessing.set_start_method("spawn")
Collaborator

What does this line do? Paddle's multi-GPU training should not need it, should it?

Contributor Author

On my machine, if I do not add multiprocessing.set_start_method("spawn"), I get cuda error(3).

Collaborator

> On my machine, if I do not add multiprocessing.set_start_method("spawn"), I get cuda error(3).

Does this only happen during multi-GPU training? What is your training command? Does the error also occur if you use the data-parallel command given in our documentation?
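
The thread does not settle the cause of the error. One common explanation, stated here as an assumption rather than a finding from this PR, is that worker processes created with the "fork" start method inherit a CUDA context that was already initialized in the parent. If the call turns out to be needed, a guarded version avoids the RuntimeError that `set_start_method` raises when the start method has already been fixed; `ensure_spawn_start_method` is a hypothetical helper built only on the standard library.

```python
import multiprocessing


def ensure_spawn_start_method():
    """Select the "spawn" start method unless it is already active."""
    if multiprocessing.get_start_method(allow_none=True) != "spawn":
        # force=True replaces a start method that was already fixed earlier.
        multiprocessing.set_start_method("spawn", force=True)
    return multiprocessing.get_start_method()


if __name__ == "__main__":
    print(ensure_spawn_start_method())
```

Calling it once inside the `if __name__ == "__main__":` guard, before any dataloader is created, keeps the behavior the author relies on while making repeated calls harmless.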


    main()
2 changes: 2 additions & 0 deletions ppsci/arch/__init__.py
@@ -61,6 +61,7 @@
from ppsci.utils import logger # isort:skip
from ppsci.arch.regdgcnn import RegDGCNN # isort:skip
from ppsci.arch.ifm_mlp import IFMMLP # isort:skip
from ppsci.arch.stafnet import STAFNet # isort:skip

__all__ = [
    "MoFlowNet",
@@ -111,6 +112,7 @@
    "VelocityGenerator",
    "RegDGCNN",
    "IFMMLP",
    "STAFNet",
]

