Skip to content

[Example] Add STAFNet Model for Air Quality Prediction #1070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 57 commits into
base: develop
Choose a base branch
from

Conversation

dylan-yin
Copy link
Contributor

PR types

PR changes

Describe

@CLAassistant
Copy link

CLAassistant commented Feb 7, 2025

CLA assistant check
All committers have signed the CLA.

Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

整体项目请使用pre-commit格式化一边

@@ -0,0 +1,136 @@
hydra:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

配置文件开头请加上以下字段:

defaults:
- ppsci_default
- TRAIN: train_default
- TRAIN/ema: ema_default
- TRAIN/swa: swa_default
- EVAL: eval_default
- INFER: infer_default
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default
- _self_

Comment on lines 8 to 16
config:
override_dirname:
exclude_keys:
- TRAIN.checkpoint_path
- TRAIN.pretrained_model_path
- EVAL.pretrained_model_path
- mode
- output_dir
- log_freq
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以删了

Comment on lines 31 to 52
STAFNet_DATA_PATH: "/data6/home/yinhang2021/workspace/SATFNet/data/2020-2023_new/train_data.pkl" #
DATASET:
label_keys: ["label"]
data_dir: "/data6/home/yinhang2021/workspace/SATFNet/data/2020-2023_new/train_data.pkl"
STAFNet_DATA_args: {
"data_dir": "/data6/home/yinhang2021/workspace/SATFNet/data/2020-2023_new/train_data.pkl",
"batch_size": 1,
"shuffle": True,
"num_workers": 0,
"training": True
}



# "data_dir": "data/2020-2023_new/train_data.pkl",
# "batch_size": 32,
# "shuffle": True,
# "num_workers": 0,
# "training": True
# model settings
# MODEL: #

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议改为相对路径,以./data/...开头即可

Comment on lines 45 to 51
# "data_dir": "data/2020-2023_new/train_data.pkl",
# "batch_size": 32,
# "shuffle": True,
# "num_workers": 0,
# "training": True
# model settings
# MODEL: #
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个注释如果没用的可以删除

Comment on lines 82 to 113
# configs: {
# "task_name": "forecast",
# "output_attention": False,
# "seq_len": 72,
# "label_len": 24,
# "pred_len": 48,

# "aq_gat_node_features" : 7,
# "aq_gat_node_num": 35,

# "mete_gat_node_features" : 7,
# "mete_gat_node_num": 18,

# "gat_hidden_dim": 32,
# "gat_edge_dim": 3,
# "gat_embed_dim": 32,

# "e_layers": 1,
# "enc_in": 7,
# "dec_in": 7,
# "c_out": 7,
# "d_model": 16 ,
# "embed": "fixed",
# "freq": "t",
# "dropout": 0.05,
# "factor": 3,
# "n_heads": 4,

# "d_ff": 32 ,
# "num_kernels": 6,
# "top_k": 4
# }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,如果没用可以删除

Comment on lines 117 to 120




Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

避免连续空行

Comment on lines 122 to 127
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(42)
# set output directory
OUTPUT_DIR = "./output_example"
# initialize logger
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个可以删除,output_dir会由ppsci.utils.callbacks.InitCallback自动创建:

logger.init_logger(
"ppsci",
osp.join(full_cfg.output_dir, f"{full_cfg.mode}.log")
if full_cfg.output_dir and full_cfg.mode not in ["export", "infer"]
else None,
full_cfg.log_level,
)

from typing import Tuple

class Inception_Block_V1(paddle.nn.Layer):

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

冗余的空行请删除,下同

@HydrogenSulfate HydrogenSulfate changed the title Add STAFNet Model for Air Quality Prediction [Example] Add STAFNet Model for Air Quality Prediction Feb 12, 2025
output_dir: ${hydra:run.dir}
log_freq: 20
# dataset setting
STAFNet_DATA_PATH: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" #
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 这里的路径是否能改成相对路径?比如 ./dataset/train_data.pkl,其余的路径字段也是,建议改为相对路径,并去掉用户名
  2. STAFNet_DATA_PATH是否应该放到DATASET字段下?



MODEL:
input_keys: ["aq_train_data","mete_train_data",]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
input_keys: ["aq_train_data","mete_train_data",]
input_keys: [aq_train_data, mete_train_data]


MODEL:
input_keys: ["aq_train_data","mete_train_data",]
output_keys: ["label"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
output_keys: ["label"]
output_keys: [label]

checkpoint_path: null

EVAL:
eval_data_path: "/data6/home/yinhang2021/dataset/chongqing_1921/val_data.pkl"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
eval_data_path: "/data6/home/yinhang2021/dataset/chongqing_1921/val_data.pkl"
eval_data_path: ./dataset/val_data.pkl

STAFNet_DATA_PATH: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" #
DATASET:
label_keys: ["label"]
data_dir: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. data_dir为什么是具体文件路径而不是某个文件夹路径?
  2. 此处的路径是否跟STAFNet_DATA_PATH重复了?

cfg.TRAIN.epochs,
ITERS_PER_EPOCH,
eval_during_train=cfg.TRAIN.eval_during_train,
seed=cfg.seed,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
seed=cfg.seed,

Comment on lines 89 to 94
"""
Validate after training an epoch

:param epoch: Integer, current training epoch.
:return: A log that contains information about validation
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
Validate after training an epoch
:param epoch: Integer, current training epoch.
:return: A log that contains information about validation
"""

Comment on lines 106 to 110
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": True,
},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"sampler": {
"name": "BatchSampler",
"drop_last": False,
"shuffle": True,
},

Comment on lines 145 to 150
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(42)
# set output directory
OUTPUT_DIR = "./output_example"
# initialize logger
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(42)
# set output directory
OUTPUT_DIR = "./output_example"
# initialize logger
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")

OUTPUT_DIR = "./output_example"
# initialize logger
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info")
multiprocessing.set_start_method("spawn")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这句代码是什么作用?paddle的多卡训练不需要这样吧?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我这边如果不加 multiprocessing.set_start_method("spawn"),会出现cuda error(3)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我这边如果不加 multiprocessing.set_start_method("spawn"),会出现cuda error(3)

这是多卡训练时才会出现的吗?你的训练命令是什么呢?如果按照我们文档里给的数据并行命令,也会报错吗?
image


````
``` sh
python stafnet.py TRAIN_DIR="Your train dataset path" eval_data_path="Your evaluate dataset path"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python stafnet.py DATASET.data_dir="Your train dataset path" EVAL.eval_data_path="Your evaluate dataset path"


````
``` sh
python stafnet.py mode=eval EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/models/stafnet/stafnet.pdparams" EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/datasets/stafnet/val_data.pkl"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


针对空气质量预测提出了许多研究。早期的方法侧重于学习单个观测站观测数据的时间模式,而放弃了观测站之间的空间关系。最近,由于图神经网络(GNN)在处理非欧几里得图结构方面的有效性,越来越多的方法采用 GNN 来模拟空间依赖关系。这些方法将车站位置作为上下文特征,隐含地建立空间依赖关系模型,没有充分利用车站位置和车站之间关系所包含的宝贵空间信息。此外,现有的时空 GNN 缺乏在错位图中融合多个特征的能力。因此,大多数方法都需要额外的插值算法,以便在早期阶段将气象特征与 AQ 特征进行对齐和连接。这种方法消除了空气质量站和气象站之间的空间和结构信息,还可能引入噪声导致误差累积。此外,在空气质量预测中利用多周期性的问题仍未得到探索。

该案例研究时空图网络网络在空气质量预测方向上的应用。
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

去掉一个‘网络’

Comment on lines 77 to 81
```
--8<--
examples/stafnet/stafnet.py:11
--8<--
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码不显示,参考其他案例文档,在第一行添加py linenums="11" title="examples/stafnet/stafnet.py",下面的代码块均修改一下


### 3.7 评估器构建

在训练过程中通常会按一定轮数间隔,用验证集(测试集)评估当前模型的训练情况,因此使用 `ppsci.validate.SupervisedValidator` 构建评估器,构建过程与 [约束构建](https://github.com/PaddlePaddle/PaddleScience/blob/develop/docs/zh/examples/cfdgcn.md#34) 类似,只需把数据目录改为测试集的目录,并在配置文件中设置 `EVAL.batch_size=1` 即可。
Copy link
Contributor

@liaoxin2 liaoxin2 Jun 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

构建过程与 [3.6 约束构建](#36) 类似


## 5. 参考资料

- [STAFNet: Spatiotemporal-Aware Fusion Network for Air Quality Prediction]([STAFNet: Spatiotemporal-Aware Fusion Network for Air Quality Prediction | SpringerLink](https://link.springer.com/chapter/10.1007/978-3-031-78186-5_22))
Copy link
Contributor

@liaoxin2 liaoxin2 Jun 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[STAFNet: Spatiotemporal-Aware Fusion Network for Air Quality Prediction](https://link.springer.com/chapter/10.1007/978-3-031-78186-5_22)

Comment on lines 39 to 60
output_attention: True
seq_len: 72
pred_len: 48
aq_gat_node_features: 7
aq_gat_node_num: 35
mete_gat_node_features: 7
mete_gat_node_num: 18
gat_hidden_dim: 32
gat_edge_dim: 3
e_layers: 2
enc_in: 7
dec_in: 7
c_out: 7
d_model: 32
embed: "fixed"
freq: "t"
dropout: 0.05
factor: 3
n_heads: 4
d_ff: 64
num_kernels: 6
top_k: 4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

output_attention: true
seq_len: 72
pred_len: 48
aq_gat_node_features: 7
aq_gat_node_num: 35
mete_gat_node_features: 7
mete_gat_node_num: 18
gat_hidden_dim: 32
gat_edge_dim: 3
e_layers: 1
enc_in: 7
dec_in: 7
c_out: 7
d_model: 16
embed: fixed
freq: t
dropout: 0.05
factor: 3
n_heads: 4
d_ff: 32
num_kernels: 6
top_k: 4

checkpoint_path: null

EVAL:
eval_data_path: ./dataset/new_val_data.pkl
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

./dataset/val_data.pkl

from omegaconf import DictConfig
import hydra
import paddle
from ppsci.data.dataset.stafnet_dataset import gat_lstmcollate_fn
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from ppsci.data.dataset.stafnet_dataset import gat_lstmcollate_fn去掉吧,并没有用到

Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

所有提交的代码请使用pre-commit格式化一下

image

DATASET:
label_keys: [label]
data_dir: ./dataset/train_data.pkl

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

pretrained_model_path: null
compute_metric_by_batch: false
eval_with_no_grad: true
batch_size: 32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

文件末尾换行

``` sh
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/stafnet/val_data.pkl -P ./dataset/
python stafnet.py mode=eval EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/models/stafnet/stafnet.pdparams"
python stafnet.py mode=eval EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/models/stafnet/stafnet.pdparams" EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/datasets/stafnet/val_data.pkl"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

去掉第21行

Comment on lines 80 to 82
--8<--
examples/stafnet/stafnet.py:11
--8<--
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

插入的代码还需再检查一下,插入的代码和文字说明不太对应


### 3.7 评估器构建

在训练过程中通常会按一定轮数间隔,用验证集(测试集)评估当前模型的训练情况,因此使用 `ppsci.validate.SupervisedValidator` 构建评估器,构建过程与 [约束构建](https://github.com/PaddlePaddle/PaddleScience/blob/develop/docs/zh/examples/stafnet.md#36) 类似,只需把数据目录改为测试集的目录,并在配置文件中设置 `EVAL.batch_size=1` 即可。
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

构建过程与 [3.6 约束构建](#36) 类似


```py linenums="10" title="examples/stafnet/stafnet.py"
--8<--
examples/stafnet/stafnet.py:10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

examples/stafnet/stafnet.py:10:10


## 4. 完整代码

```python py linenums="1" title="examples/stafnet/stafnet.py"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

去掉python


``` py linenums="62" title="examples/stafnet/stafnet.py"
--8<--
examples/stafnet/stafnet.py:62
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

examples/stafnet/stafnet.py:62:62


``` py linenums="55" title="examples/stafnet/stafnet.py"
--8<--
examples/stafnet/stafnet.py:55
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

examples/stafnet/stafnet.py:55:55


import numpy as np
import paddle
from pgl.nn.conv import GATv2Conv
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为pgl不是pdsci的依赖库,所以还麻烦按照以下方式修改导入方式,避免影响ppsci其它代码
image
image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还麻烦在docs/zh/api/data/dataset.md加上这个类

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还麻烦在arch.md加上这个类

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants