@@ -38,7 +38,8 @@ PW_Basis:: ~PW_Basis()
38
38
delete[] ig2igg;
39
39
delete[] gg_uniq;
40
40
#if defined(__CUDA) || defined(__ROCM)
41
- if (this ->device == " gpu" ) {
41
+ if (this ->device == " gpu" )
42
+ {
42
43
delmem_int_op ()(this ->d_is2fftixy );
43
44
}
44
45
#endif
@@ -70,9 +71,14 @@ void PW_Basis::setuptransform()
70
71
71
72
void PW_Basis::getstartgr ()
72
73
{
73
- if (this ->gamma_only ) this ->nmaxgr = ( this ->npw > (this ->nrxx +1 )/2 ) ? this ->npw : (this ->nrxx +1 )/2 ;
74
- else this ->nmaxgr = ( this ->npw > this ->nrxx ) ? this ->npw : this ->nrxx ;
75
- this ->nmaxgr = (this ->nz * this ->nst > this ->nxy * nplane) ? this ->nz * this ->nst : this ->nxy * nplane;
74
+ if (this ->gamma_only )
75
+ {
76
+ this ->nmaxgr = ( this ->npw > (this ->nrxx +1 )/2 ) ? this ->npw : (this ->nrxx +1 )/2 ;
77
+ }
78
+ else
79
+ {
80
+ this ->nmaxgr = ( this ->npw > this ->nrxx ) ? this ->npw : this ->nrxx ;
81
+ }
76
82
77
83
// ---------------------------------------------
78
84
// sum : starting plane of FFT box.
@@ -85,23 +91,35 @@ void PW_Basis::getstartgr()
85
91
// Each processor has a set of full sticks,
86
92
// 'rank_use' processor send a piece(npps[ip]) of these sticks(nst_per[rank_use])
87
93
// to all the other processors in this pool
88
- for (int ip = 0 ;ip < poolnproc; ++ip) this ->numg [ip] = this ->nst_per [poolrank] * this ->numz [ip];
94
+ for (int ip = 0 ;ip < poolnproc; ++ip)
95
+ {
96
+ this ->numg [ip] = this ->nst_per [poolrank] * this ->numz [ip];
97
+ }
89
98
90
99
91
100
// Each processor in a pool send a piece of each stick(nst_per[ip]) to
92
101
// other processors in this pool
93
102
// rank_use processor receive datas in npps[rank_p] planes.
94
- for (int ip = 0 ;ip < poolnproc; ++ip) this ->numr [ip] = this ->nst_per [ip] * this ->numz [poolrank];
103
+ for (int ip = 0 ;ip < poolnproc; ++ip)
104
+ {
105
+ this ->numr [ip] = this ->nst_per [ip] * this ->numz [poolrank];
106
+ }
95
107
96
108
97
109
// startg record the starting 'numg' position in each processor.
98
110
this ->startg [0 ] = 0 ;
99
- for (int ip = 1 ;ip < poolnproc; ++ip) this ->startg [ip] = this ->startg [ip-1 ] + this ->numg [ip-1 ];
111
+ for (int ip = 1 ;ip < poolnproc; ++ip)
112
+ {
113
+ this ->startg [ip] = this ->startg [ip-1 ] + this ->numg [ip-1 ];
114
+ }
100
115
101
116
102
117
// startr record the starting 'numr' position
103
118
this ->startr [0 ] = 0 ;
104
- for (int ip = 1 ;ip < poolnproc; ++ip) this ->startr [ip] = this ->startr [ip-1 ] + this ->numr [ip-1 ];
119
+ for (int ip = 1 ;ip < poolnproc; ++ip)
120
+ {
121
+ this ->startr [ip] = this ->startr [ip-1 ] + this ->numr [ip-1 ];
122
+ }
105
123
return ;
106
124
}
107
125
@@ -112,7 +130,10 @@ void PW_Basis::getstartgr()
112
130
// /
113
131
void PW_Basis::collect_local_pw ()
114
132
{
115
- if (this ->npw <= 0 ) return ;
133
+ if (this ->npw <= 0 )
134
+ {
135
+ return ;
136
+ }
116
137
this ->ig_gge0 = -1 ;
117
138
delete[] this ->gg ; this ->gg = new double [this ->npw ];
118
139
delete[] this ->gdirect ; this ->gdirect = new ModuleBase::Vector3<double >[this ->npw ];
@@ -127,16 +148,28 @@ void PW_Basis::collect_local_pw()
127
148
int ixy = this ->is2fftixy [is];
128
149
int ix = ixy / this ->fftny ;
129
150
int iy = ixy % this ->fftny ;
130
- if (ix >= int (this ->nx /2 ) + 1 ) ix -= this ->nx ;
131
- if (iy >= int (this ->ny /2 ) + 1 ) iy -= this ->ny ;
132
- if (iz >= int (this ->nz /2 ) + 1 ) iz -= this ->nz ;
151
+ if (ix >= int (this ->nx /2 ) + 1 )
152
+ {
153
+ ix -= this ->nx ;
154
+ }
155
+ if (iy >= int (this ->ny /2 ) + 1 )
156
+ {
157
+ iy -= this ->ny ;
158
+ }
159
+ if (iz >= int (this ->nz /2 ) + 1 )
160
+ {
161
+ iz -= this ->nz ;
162
+ }
133
163
f.x = ix;
134
164
f.y = iy;
135
165
f.z = iz;
136
166
this ->gg [ig] = f * (this ->GGT * f);
137
167
this ->gdirect [ig] = f;
138
168
this ->gcar [ig] = f * this ->G ;
139
- if (this ->gg [ig] < 1e-8 ) this ->ig_gge0 = ig;
169
+ if (this ->gg [ig] < 1e-8 )
170
+ {
171
+ this ->ig_gge0 = ig;
172
+ }
140
173
}
141
174
return ;
142
175
}
@@ -148,10 +181,13 @@ void PW_Basis::collect_local_pw()
148
181
// /
149
182
void PW_Basis::collect_uniqgg ()
150
183
{
151
- if (this ->npw <= 0 ) return ;
184
+ if (this ->npw <= 0 )
185
+ {
186
+ return ;
187
+ }
152
188
this ->ig_gge0 = -1 ;
153
189
delete[] this ->ig2igg ; this ->ig2igg = new int [this ->npw ];
154
- // add by A.s 202406
190
+
155
191
int *sortindex = new int [this ->npw ];// Reconstruct the mapping of the plane wave index ig according to the energy size of the plane waves
156
192
double *tmpgg = new double [this ->npw ];// Ranking the plane waves by energy size while ensuring that the same energy is preserved for each wave to correspond
157
193
double *tmpgg2 = new double [this ->npw ];// ranking the plane waves by energy size and removing the duplicates
@@ -164,14 +200,26 @@ void PW_Basis::collect_uniqgg()
164
200
int ixy = this ->is2fftixy [is];
165
201
int ix = ixy / this ->fftny ;
166
202
int iy = ixy % this ->fftny ;
167
- if (ix >= int (this ->nx /2 ) + 1 ) ix -= this ->nx ;
168
- if (iy >= int (this ->ny /2 ) + 1 ) iy -= this ->ny ;
169
- if (iz >= int (this ->nz /2 ) + 1 ) iz -= this ->nz ;
203
+ if (ix >= int (this ->nx /2 ) + 1 )
204
+ {
205
+ ix -= this ->nx ;
206
+ }
207
+ if (iy >= int (this ->ny /2 ) + 1 )
208
+ {
209
+ iy -= this ->ny ;
210
+ }
211
+ if (iz >= int (this ->nz /2 ) + 1 )
212
+ {
213
+ iz -= this ->nz ;
214
+ }
170
215
f.x = ix;
171
216
f.y = iy;
172
217
f.z = iz;
173
218
tmpgg[ig] = f * (this ->GGT * f);
174
- if (tmpgg[ig] < 1e-8 ) this ->ig_gge0 = ig;
219
+ if (tmpgg[ig] < 1e-8 )
220
+ {
221
+ this ->ig_gge0 = ig;
222
+ }
175
223
}
176
224
177
225
ModuleBase::GlobalFunc::ZEROS (sortindex, this ->npw );
@@ -215,7 +263,10 @@ void PW_Basis::collect_uniqgg()
215
263
void PW_Basis::getfftixy2is (int * fftixy2is) const
216
264
{
217
265
// Note: please assert when is1 >= is2, fftixy2is[is1] >= fftixy2is[is2]!
218
- for (int ixy = 0 ; ixy < this ->fftnxy ; ++ixy) fftixy2is[ixy] = -1 ;
266
+ for (int ixy = 0 ; ixy < this ->fftnxy ; ++ixy)
267
+ {
268
+ fftixy2is[ixy] = -1 ;
269
+ }
219
270
int ixy = 0 ;
220
271
for (int is = 0 ; is < this ->nst ; ++is)
221
272
{
0 commit comments