[Example] Add STAFNet Model for Air Quality Prediction #1070
base: develop
Please format the whole project with pre-commit.
examples/demo/conf/stafnet.yaml
Outdated
@@ -0,0 +1,136 @@ | |||
hydra: |
Please add the following fields at the top of the config file:
PaddleScience/examples/ldc/conf/ldc_2d_Re3200_piratenet.yaml
Lines 1 to 9 in fad6927
defaults: | |
- ppsci_default | |
- TRAIN: train_default | |
- TRAIN/ema: ema_default | |
- TRAIN/swa: swa_default | |
- EVAL: eval_default | |
- INFER: infer_default | |
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | |
- _self_ |
examples/demo/conf/stafnet.yaml
Outdated
config: | ||
override_dirname: | ||
exclude_keys: | ||
- TRAIN.checkpoint_path | ||
- TRAIN.pretrained_model_path | ||
- EVAL.pretrained_model_path | ||
- mode | ||
- output_dir | ||
- log_freq |
This can be deleted.
examples/demo/conf/stafnet.yaml
Outdated
STAFNet_DATA_PATH: "/data6/home/yinhang2021/workspace/SATFNet/data/2020-2023_new/train_data.pkl" # | ||
DATASET: | ||
label_keys: ["label"] | ||
data_dir: "/data6/home/yinhang2021/workspace/SATFNet/data/2020-2023_new/train_data.pkl" | ||
STAFNet_DATA_args: { | ||
"data_dir": "/data6/home/yinhang2021/workspace/SATFNet/data/2020-2023_new/train_data.pkl", | ||
"batch_size": 1, | ||
"shuffle": True, | ||
"num_workers": 0, | ||
"training": True | ||
} | ||
|
||
|
||
|
||
# "data_dir": "data/2020-2023_new/train_data.pkl", | ||
# "batch_size": 32, | ||
# "shuffle": True, | ||
# "num_workers": 0, | ||
# "training": True | ||
# model settings | ||
# MODEL: # | ||
|
Suggest changing this to a relative path; starting it with ./data/... is enough.
examples/demo/conf/stafnet.yaml
Outdated
# "data_dir": "data/2020-2023_new/train_data.pkl", | ||
# "batch_size": 32, | ||
# "shuffle": True, | ||
# "num_workers": 0, | ||
# "training": True | ||
# model settings | ||
# MODEL: # |
If this commented-out block is not needed, it can be deleted.
examples/demo/conf/stafnet.yaml
Outdated
# configs: { | ||
# "task_name": "forecast", | ||
# "output_attention": False, | ||
# "seq_len": 72, | ||
# "label_len": 24, | ||
# "pred_len": 48, | ||
|
||
# "aq_gat_node_features" : 7, | ||
# "aq_gat_node_num": 35, | ||
|
||
# "mete_gat_node_features" : 7, | ||
# "mete_gat_node_num": 18, | ||
|
||
# "gat_hidden_dim": 32, | ||
# "gat_edge_dim": 3, | ||
# "gat_embed_dim": 32, | ||
|
||
# "e_layers": 1, | ||
# "enc_in": 7, | ||
# "dec_in": 7, | ||
# "c_out": 7, | ||
# "d_model": 16 , | ||
# "embed": "fixed", | ||
# "freq": "t", | ||
# "dropout": 0.05, | ||
# "factor": 3, | ||
# "n_heads": 4, | ||
|
||
# "d_ff": 32 , | ||
# "num_kernels": 6, | ||
# "top_k": 4 | ||
# } |
Same as above: delete it if it is not needed.
examples/demo/demo.py
Outdated
|
||
|
||
|
||
|
Avoid consecutive blank lines.
examples/demo/demo.py
Outdated
# set random seed for reproducibility | ||
ppsci.utils.misc.set_random_seed(42) | ||
# set output directory | ||
OUTPUT_DIR = "./output_example" | ||
# initialize logger | ||
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") |
This can be deleted; output_dir is created automatically by ppsci.utils.callbacks.InitCallback:
PaddleScience/ppsci/utils/callbacks.py
Lines 90 to 96 in fad6927
logger.init_logger( | |
"ppsci", | |
osp.join(full_cfg.output_dir, f"{full_cfg.mode}.log") | |
if full_cfg.output_dir and full_cfg.mode not in ["export", "infer"] | |
else None, | |
full_cfg.log_level, | |
) |
from typing import Tuple | ||
|
||
class Inception_Block_V1(paddle.nn.Layer): | ||
|
Please delete the redundant blank lines; the same applies below.
examples/demo/conf/stafnet.yaml
Outdated
output_dir: ${hydra:run.dir} | ||
log_freq: 20 | ||
# dataset setting | ||
STAFNet_DATA_PATH: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" # |
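- Can the path here be changed to a relative path, e.g. ./dataset/train_data.pkl? The same goes for the other path fields: please make them relative and drop the username.
- Should STAFNet_DATA_PATH be placed under the DATASET field?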
examples/demo/conf/stafnet.yaml
Outdated
|
||
|
||
MODEL: | ||
input_keys: ["aq_train_data","mete_train_data",] |
input_keys: ["aq_train_data","mete_train_data",] | |
input_keys: [aq_train_data, mete_train_data] |
examples/demo/conf/stafnet.yaml
Outdated
|
||
MODEL: | ||
input_keys: ["aq_train_data","mete_train_data",] | ||
output_keys: ["label"] |
output_keys: ["label"] | |
output_keys: [label] |
examples/demo/conf/stafnet.yaml
Outdated
checkpoint_path: null | ||
|
||
EVAL: | ||
eval_data_path: "/data6/home/yinhang2021/dataset/chongqing_1921/val_data.pkl" |
eval_data_path: "/data6/home/yinhang2021/dataset/chongqing_1921/val_data.pkl" | |
eval_data_path: ./dataset/val_data.pkl |
examples/demo/conf/stafnet.yaml
Outdated
STAFNet_DATA_PATH: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" # | ||
DATASET: | ||
label_keys: ["label"] | ||
data_dir: "/data6/home/yinhang2021/dataset/chongqing_1921/train_data.pkl" |
- Why is data_dir a path to a specific file rather than to a directory?
- Does this path duplicate STAFNet_DATA_PATH?
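One way to address both points, sketched under assumptions: treat `data_dir` as a folder and derive the file paths from it, so that each path is written only once. `dataset_files` is a hypothetical helper, not part of ppsci; the file names `train_data.pkl` and `val_data.pkl` are the ones that appear elsewhere in this PR.

```python
from pathlib import Path
from typing import Dict


def dataset_files(data_dir: str) -> Dict[str, Path]:
    """Derive the train/val file paths from a single dataset folder.

    Hypothetical helper for illustration only; the file names are assumed
    to match the ones used in this PR.
    """
    root = Path(data_dir)
    return {
        "train": root / "train_data.pkl",
        "val": root / "val_data.pkl",
    }


if __name__ == "__main__":
    # A relative folder keeps the config free of user-specific prefixes.
    for split, path in dataset_files("./dataset").items():
        print(split, path)
```

With this layout the config would only need one folder entry, and a separate STAFNet_DATA_PATH field would no longer be required.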
examples/demo/demo.py
Outdated
cfg.TRAIN.epochs, | ||
ITERS_PER_EPOCH, | ||
eval_during_train=cfg.TRAIN.eval_during_train, | ||
seed=cfg.seed, |
seed=cfg.seed, |
examples/demo/demo.py
Outdated
""" | ||
Validate after training an epoch | ||
|
||
:param epoch: Integer, current training epoch. | ||
:return: A log that contains information about validation | ||
""" |
""" | |
Validate after training an epoch | |
:param epoch: Integer, current training epoch. | |
:return: A log that contains information about validation | |
""" |
examples/demo/demo.py
Outdated
"sampler": { | ||
"name": "BatchSampler", | ||
"drop_last": False, | ||
"shuffle": True, | ||
}, |
"sampler": { | |
"name": "BatchSampler", | |
"drop_last": False, | |
"shuffle": True, | |
}, |
examples/demo/demo.py
Outdated
# set random seed for reproducibility | ||
ppsci.utils.misc.set_random_seed(42) | ||
# set output directory | ||
OUTPUT_DIR = "./output_example" | ||
# initialize logger | ||
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") |
# set random seed for reproducibility | |
ppsci.utils.misc.set_random_seed(42) | |
# set output directory | |
OUTPUT_DIR = "./output_example" | |
# initialize logger | |
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") |
examples/demo/demo.py
Outdated
OUTPUT_DIR = "./output_example" | ||
# initialize logger | ||
logger.init_logger("ppsci", f"{OUTPUT_DIR}/train.log", "info") | ||
multiprocessing.set_start_method("spawn") |
What does this line do? Paddle's multi-GPU training shouldn't need it, should it?
On my side, if I don't add multiprocessing.set_start_method("spawn"), I get cuda error(3).
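For reference, a minimal sketch of a guarded way to select the start method, assuming the call is kept. `ensure_spawn_start_method` is a hypothetical helper; it uses only the standard-library `multiprocessing` API and does not touch Paddle. `set_start_method` raises `RuntimeError` when the start method has already been fixed, so the sketch checks first.

```python
import multiprocessing


def ensure_spawn_start_method() -> str:
    """Select "spawn" unless a start method has already been fixed.

    Hypothetical helper for illustration; returns the method in effect.
    """
    # allow_none=True reports None instead of silently fixing the default.
    if multiprocessing.get_start_method(allow_none=True) is None:
        multiprocessing.set_start_method("spawn")
    return multiprocessing.get_start_method()


if __name__ == "__main__":
    # Calling it from the entry point avoids re-running it in child processes.
    print(ensure_spawn_start_method())
```

Whether "spawn" is actually required here depends on how the data loader workers interact with CUDA, which this sketch does not establish.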
docs/zh/examples/stafnet.md
Outdated
|
||
```` | ||
``` sh | ||
python stafnet.py TRAIN_DIR="Your train dataset path" eval_data_path="Your evaluate dataset path" |
python stafnet.py DATASET.data_dir="Your train dataset path" EVAL.eval_data_path="Your evaluate dataset path"
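To illustrate why the override has to use the full dotted key, here is a rough sketch of how a `GROUP.key=value` string maps onto a nested config. `apply_override` is a hypothetical stand-in written for this comment, not Hydra's implementation; the config keys are the ones from this PR, and the `/tmp/...` value is a placeholder.

```python
from typing import Any, Dict


def apply_override(cfg: Dict[str, Any], override: str) -> None:
    """Set an existing nested key in place from a "A.b=value" string.

    Hypothetical helper; it only mimics the idea of a dotted override.
    """
    dotted_key, value = override.split("=", 1)
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for name in parents:
        node = node[name]  # KeyError if the group does not exist
    if leaf not in node:
        raise KeyError(f"{dotted_key} is not a key of the config")
    node[leaf] = value


cfg = {
    "DATASET": {"data_dir": "./dataset/train_data.pkl"},
    "EVAL": {"eval_data_path": "./dataset/val_data.pkl"},
}
# The dotted key reaches the nested field.
apply_override(cfg, "DATASET.data_dir=/tmp/train_data.pkl")
```

In this sketch a bare `TRAIN_DIR=...` fails because no such top-level key exists, which is the problem the suggested command avoids.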
docs/zh/examples/stafnet.md
Outdated
|
||
```` | ||
``` sh | ||
python stafnet.py mode=eval EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/models/stafnet/stafnet.pdparams" EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/datasets/stafnet/val_data.pkl" |
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/stafnet/val_data.pkl -P ./dataset/
python stafnet.py mode=eval EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/models/stafnet/stafnet.pdparams"
docs/zh/examples/stafnet.md
Outdated
|
||
针对空气质量预测提出了许多研究。早期的方法侧重于学习单个观测站观测数据的时间模式,而放弃了观测站之间的空间关系。最近,由于图神经网络(GNN)在处理非欧几里得图结构方面的有效性,越来越多的方法采用 GNN 来模拟空间依赖关系。这些方法将车站位置作为上下文特征,隐含地建立空间依赖关系模型,没有充分利用车站位置和车站之间关系所包含的宝贵空间信息。此外,现有的时空 GNN 缺乏在错位图中融合多个特征的能力。因此,大多数方法都需要额外的插值算法,以便在早期阶段将气象特征与 AQ 特征进行对齐和连接。这种方法消除了空气质量站和气象站之间的空间和结构信息,还可能引入噪声导致误差累积。此外,在空气质量预测中利用多周期性的问题仍未得到探索。 | ||
|
||
该案例研究时空图网络网络在空气质量预测方向上的应用。 |
Remove one of the duplicated ‘网络’.
docs/zh/examples/stafnet.md
Outdated
``` | ||
--8<-- | ||
examples/stafnet/stafnet.py:11 | ||
--8<-- | ||
``` |
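The code is not displayed. Following the other example docs, add py linenums="11" title="examples/stafnet/stafnet.py" on the first line, and update all the code blocks below in the same way.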
docs/zh/examples/stafnet.md
Outdated
|
||
### 3.7 评估器构建 | ||
|
||
在训练过程中通常会按一定轮数间隔,用验证集(测试集)评估当前模型的训练情况,因此使用 `ppsci.validate.SupervisedValidator` 构建评估器,构建过程与 [约束构建](https://github.com/PaddlePaddle/PaddleScience/blob/develop/docs/zh/examples/cfdgcn.md#34) 类似,只需把数据目录改为测试集的目录,并在配置文件中设置 `EVAL.batch_size=1` 即可。 |
构建过程与 [3.6 约束构建](#36) 类似
docs/zh/examples/stafnet.md
Outdated
|
||
## 5. 参考资料 | ||
|
||
- [STAFNet: Spatiotemporal-Aware Fusion Network for Air Quality Prediction]([STAFNet: Spatiotemporal-Aware Fusion Network for Air Quality Prediction | SpringerLink](https://link.springer.com/chapter/10.1007/978-3-031-78186-5_22)) |
[STAFNet: Spatiotemporal-Aware Fusion Network for Air Quality Prediction](https://link.springer.com/chapter/10.1007/978-3-031-78186-5_22)
examples/stafnet/conf/stafnet.yaml
Outdated
output_attention: True | ||
seq_len: 72 | ||
pred_len: 48 | ||
aq_gat_node_features: 7 | ||
aq_gat_node_num: 35 | ||
mete_gat_node_features: 7 | ||
mete_gat_node_num: 18 | ||
gat_hidden_dim: 32 | ||
gat_edge_dim: 3 | ||
e_layers: 2 | ||
enc_in: 7 | ||
dec_in: 7 | ||
c_out: 7 | ||
d_model: 32 | ||
embed: "fixed" | ||
freq: "t" | ||
dropout: 0.05 | ||
factor: 3 | ||
n_heads: 4 | ||
d_ff: 64 | ||
num_kernels: 6 | ||
top_k: 4 |
output_attention: true
seq_len: 72
pred_len: 48
aq_gat_node_features: 7
aq_gat_node_num: 35
mete_gat_node_features: 7
mete_gat_node_num: 18
gat_hidden_dim: 32
gat_edge_dim: 3
e_layers: 1
enc_in: 7
dec_in: 7
c_out: 7
d_model: 16
embed: fixed
freq: t
dropout: 0.05
factor: 3
n_heads: 4
d_ff: 32
num_kernels: 6
top_k: 4
examples/stafnet/conf/stafnet.yaml
Outdated
checkpoint_path: null | ||
|
||
EVAL: | ||
eval_data_path: ./dataset/new_val_data.pkl |
./dataset/val_data.pkl
examples/stafnet/stafnet.py
Outdated
from omegaconf import DictConfig | ||
import hydra | ||
import paddle | ||
from ppsci.data.dataset.stafnet_dataset import gat_lstmcollate_fn |
Please remove `from ppsci.data.dataset.stafnet_dataset import gat_lstmcollate_fn`; it is not used.
DATASET: | ||
label_keys: [label] | ||
data_dir: ./dataset/train_data.pkl | ||
|
examples/stafnet/conf/stafnet.yaml
Outdated
pretrained_model_path: null | ||
compute_metric_by_batch: false | ||
eval_with_no_grad: true | ||
batch_size: 32 |
Add a newline at the end of the file.
docs/zh/examples/stafnet.md
Outdated
``` sh | ||
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/stafnet/val_data.pkl -P ./dataset/ | ||
python stafnet.py mode=eval EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/models/stafnet/stafnet.pdparams" | ||
python stafnet.py mode=eval EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/models/stafnet/stafnet.pdparams" EVAL.pretrained_model_path="https://paddle-org.bj.bcebos.com/paddlescience/datasets/stafnet/val_data.pkl" |
Remove line 21.
docs/zh/examples/stafnet.md
Outdated
--8<-- | ||
examples/stafnet/stafnet.py:11 | ||
--8<-- |
Please double-check the inserted code; it does not quite match the accompanying text.
docs/zh/examples/stafnet.md
Outdated
|
||
### 3.7 评估器构建 | ||
|
||
在训练过程中通常会按一定轮数间隔,用验证集(测试集)评估当前模型的训练情况,因此使用 `ppsci.validate.SupervisedValidator` 构建评估器,构建过程与 [约束构建](https://github.com/PaddlePaddle/PaddleScience/blob/develop/docs/zh/examples/stafnet.md#36) 类似,只需把数据目录改为测试集的目录,并在配置文件中设置 `EVAL.batch_size=1` 即可。 |
构建过程与 [3.6 约束构建](#36) 类似
docs/zh/examples/stafnet.md
Outdated
|
||
```py linenums="10" title="examples/stafnet/stafnet.py" | ||
--8<-- | ||
examples/stafnet/stafnet.py:10 |
examples/stafnet/stafnet.py:10:10
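The difference between `stafnet.py:10` and `stafnet.py:10:10` can be sketched as follows. This is a simplified reading of the `path:start:end` snippet spec, assuming that a missing end means "through the last line"; `parse_spec` and `select_lines` are hypothetical helpers written for this comment, not the snippet plugin's own code.

```python
from typing import List, Optional, Tuple


def parse_spec(spec: str) -> Tuple[str, Optional[int], Optional[int]]:
    """Split "path:start:end" into its parts; start and end are optional.

    Simplified: paths that themselves contain ":" are not handled.
    """
    parts = spec.split(":")
    path = parts[0]
    start = int(parts[1]) if len(parts) > 1 and parts[1] else None
    end = int(parts[2]) if len(parts) > 2 and parts[2] else None
    return path, start, end


def select_lines(
    lines: List[str], start: Optional[int], end: Optional[int]
) -> List[str]:
    """Return the 1-indexed, inclusive range [start, end] of lines."""
    first = (start or 1) - 1
    last = end if end is not None else len(lines)
    return lines[first:last]


if __name__ == "__main__":
    source = [f"line {i}" for i in range(1, 21)]  # stand-in for a 20-line file
    for spec in ("examples/stafnet/stafnet.py:10", "examples/stafnet/stafnet.py:10:10"):
        _, start, end = parse_spec(spec)
        print(spec, "->", len(select_lines(source, start, end)), "line(s)")
```

Under that assumption, the open-ended form pulls in everything from line 10 onward, while the explicit range embeds only the requested lines.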
docs/zh/examples/stafnet.md
Outdated
|
||
## 4. 完整代码 | ||
|
||
```python py linenums="1" title="examples/stafnet/stafnet.py" |
Remove `python`.
docs/zh/examples/stafnet.md
Outdated
|
||
``` py linenums="62" title="examples/stafnet/stafnet.py" | ||
--8<-- | ||
examples/stafnet/stafnet.py:62 |
examples/stafnet/stafnet.py:62:62
docs/zh/examples/stafnet.md
Outdated
|
||
``` py linenums="55" title="examples/stafnet/stafnet.py" | ||
--8<-- | ||
examples/stafnet/stafnet.py:55 |
examples/stafnet/stafnet.py:55:55
ppsci/arch/stafnet.py
Outdated
|
||
import numpy as np | ||
import paddle | ||
from pgl.nn.conv import GATv2Conv |
Please also add this class to docs/zh/api/data/dataset.md.
Please also add this class to arch.md.
PR types
PR changes
Describe