Skip to content

Commit 0860a51

Browse files
[Fix] Remove padding data at the end of solver.predict (#669)
* fix(test=document_fix) * remove redundant padding data of input_data in solver.predict * remove padding data in solver.predict after prediction
1 parent 84b4f8d commit 0860a51

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

docs/zh/examples/heat_pinn.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
| 预训练模型 | 指标 |
1616
|:--| :--|
17-
| [heat_pinn_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/heat_pinn/heat_pinn_pretrained.pdparams) | norm MSE loss between the FDM and PINN is 0.0013415 |
17+
| [heat_pinn_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/heat_pinn/heat_pinn_pretrained.pdparams) | norm MSE loss between the FDM and PINN is 1.30174e-03 |
1818

1919
## 1. 背景简介
2020

examples/heat_pinn/heat_pinn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ def train(cfg: DictConfig):
139139
N_EVAL, N_EVAL
140140
)
141141
fdm_output = fdm.solve(N_EVAL, 1).T
142-
mes_loss = np.mean(np.square(pinn_output - (fdm_output / 75.0)))
143-
logger.info(f"The norm MSE loss between the FDM and PINN is {mes_loss}")
142+
mse_loss = np.mean(np.square(pinn_output - (fdm_output / 75.0)))
143+
logger.info(f"The norm MSE loss between the FDM and PINN is {mse_loss}")
144144

145145
x = input_data["x"].reshape(N_EVAL, N_EVAL)
146146
y = input_data["y"].reshape(N_EVAL, N_EVAL)
@@ -237,8 +237,8 @@ def evaluate(cfg: DictConfig):
237237
"u"
238238
].reshape(N_EVAL, N_EVAL)
239239
fdm_output = fdm.solve(N_EVAL, 1).T
240-
mes_loss = np.mean(np.square(pinn_output - (fdm_output / 75.0)))
241-
logger.info(f"The norm MSE loss between the FDM and PINN is {mes_loss}")
240+
mse_loss = np.mean(np.square(pinn_output - (fdm_output / 75.0)))
241+
logger.info(f"The norm MSE loss between the FDM and PINN is {mse_loss:.5e}")
242242

243243
x = input_data["x"].reshape(N_EVAL, N_EVAL)
244244
y = input_data["y"].reshape(N_EVAL, N_EVAL)

ppsci/solver/solver.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ def __init__(
234234
logger.warning(
235235
f"Detected 'world_size'({self.world_size}) > 1, it is recommended to "
236236
"scale up the learning rate and reduce the 'epochs' or "
237-
"'iters_per_epoch' according to the 'world_size' both linearly."
237+
"'iters_per_epoch' according to the 'world_size' both linearly if you "
238+
"are training model."
238239
)
239240

240241
# load pretrained model, usually used for transfer learning
@@ -511,8 +512,6 @@ def predict(
511512
# pad with last element if `num_samples` is not divisible by `world_size`
512513
# ensuring every device get same number of data.
513514
if num_pad > 0:
514-
# NOTE: This will modify input_dict inplace by appending padding data at the
515-
# end if num_pad > 0.
516515
for k, v in input_dict.items():
517516
repeat_times = (num_pad, *(1 for _ in range(v.ndim - 1)))
518517
if isinstance(v, np.ndarray):
@@ -596,6 +595,9 @@ def predict(
596595
pred_dict = {
597596
key: value[:num_samples] for key, value in pred_dict.items()
598597
}
598+
# NOTE: Discard padding data in input_dict for consistency
599+
for k in input_dict:
600+
input_dict[k] = input_dict[k][:num_samples]
599601

600602
# convert to numpy ndarray if specified
601603
if return_numpy:

0 commit comments

Comments
 (0)