Skip to content

Commit 053b5d8

Browse files
authored
Merge branch 'develop' into tensor-lapack-extend
2 parents 6815bd8 + e7b5c12 commit 053b5d8

File tree

5 files changed

+79
-28
lines changed

5 files changed

+79
-28
lines changed

source/module_basis/module_pw/pw_basis.cpp

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ PW_Basis:: ~PW_Basis()
3838
delete[] ig2igg;
3939
delete[] gg_uniq;
4040
#if defined(__CUDA) || defined(__ROCM)
41-
if (this->device == "gpu") {
41+
if (this->device == "gpu")
42+
{
4243
delmem_int_op()(this->d_is2fftixy);
4344
}
4445
#endif
@@ -70,9 +71,14 @@ void PW_Basis::setuptransform()
7071

7172
void PW_Basis::getstartgr()
7273
{
73-
if(this->gamma_only) this->nmaxgr = ( this->npw > (this->nrxx+1)/2 ) ? this->npw : (this->nrxx+1)/2;
74-
else this->nmaxgr = ( this->npw > this->nrxx ) ? this->npw : this->nrxx;
75-
this->nmaxgr = (this->nz * this->nst > this->nxy * nplane) ? this->nz * this->nst : this->nxy * nplane;
74+
if(this->gamma_only)
75+
{
76+
this->nmaxgr = ( this->npw > (this->nrxx+1)/2 ) ? this->npw : (this->nrxx+1)/2;
77+
}
78+
else
79+
{
80+
this->nmaxgr = ( this->npw > this->nrxx ) ? this->npw : this->nrxx;
81+
}
7682

7783
//---------------------------------------------
7884
// sum : starting plane of FFT box.
@@ -85,23 +91,35 @@ void PW_Basis::getstartgr()
8591
// Each processor has a set of full sticks,
8692
// 'rank_use' processor send a piece(npps[ip]) of these sticks(nst_per[rank_use])
8793
// to all the other processors in this pool
88-
for (int ip = 0;ip < poolnproc; ++ip) this->numg[ip] = this->nst_per[poolrank] * this->numz[ip];
94+
for (int ip = 0;ip < poolnproc; ++ip)
95+
{
96+
this->numg[ip] = this->nst_per[poolrank] * this->numz[ip];
97+
}
8998

9099

91100
// Each processor in a pool send a piece of each stick(nst_per[ip]) to
92101
// other processors in this pool
93102
// rank_use processor receive datas in npps[rank_p] planes.
94-
for (int ip = 0;ip < poolnproc; ++ip) this->numr[ip] = this->nst_per[ip] * this->numz[poolrank];
103+
for (int ip = 0;ip < poolnproc; ++ip)
104+
{
105+
this->numr[ip] = this->nst_per[ip] * this->numz[poolrank];
106+
}
95107

96108

97109
// startg record the starting 'numg' position in each processor.
98110
this->startg[0] = 0;
99-
for (int ip = 1;ip < poolnproc; ++ip) this->startg[ip] = this->startg[ip-1] + this->numg[ip-1];
111+
for (int ip = 1;ip < poolnproc; ++ip)
112+
{
113+
this->startg[ip] = this->startg[ip-1] + this->numg[ip-1];
114+
}
100115

101116

102117
// startr record the starting 'numr' position
103118
this->startr[0] = 0;
104-
for (int ip = 1;ip < poolnproc; ++ip) this->startr[ip] = this->startr[ip-1] + this->numr[ip-1];
119+
for (int ip = 1;ip < poolnproc; ++ip)
120+
{
121+
this->startr[ip] = this->startr[ip-1] + this->numr[ip-1];
122+
}
105123
return;
106124
}
107125

@@ -112,7 +130,10 @@ void PW_Basis::getstartgr()
112130
///
113131
void PW_Basis::collect_local_pw()
114132
{
115-
if(this->npw <= 0) return;
133+
if(this->npw <= 0)
134+
{
135+
return;
136+
}
116137
this->ig_gge0 = -1;
117138
delete[] this->gg; this->gg = new double[this->npw];
118139
delete[] this->gdirect; this->gdirect = new ModuleBase::Vector3<double>[this->npw];
@@ -127,16 +148,28 @@ void PW_Basis::collect_local_pw()
127148
int ixy = this->is2fftixy[is];
128149
int ix = ixy / this->fftny;
129150
int iy = ixy % this->fftny;
130-
if (ix >= int(this->nx/2) + 1) ix -= this->nx;
131-
if (iy >= int(this->ny/2) + 1) iy -= this->ny;
132-
if (iz >= int(this->nz/2) + 1) iz -= this->nz;
151+
if (ix >= int(this->nx/2) + 1)
152+
{
153+
ix -= this->nx;
154+
}
155+
if (iy >= int(this->ny/2) + 1)
156+
{
157+
iy -= this->ny;
158+
}
159+
if (iz >= int(this->nz/2) + 1)
160+
{
161+
iz -= this->nz;
162+
}
133163
f.x = ix;
134164
f.y = iy;
135165
f.z = iz;
136166
this->gg[ig] = f * (this->GGT * f);
137167
this->gdirect[ig] = f;
138168
this->gcar[ig] = f * this->G;
139-
if(this->gg[ig] < 1e-8) this->ig_gge0 = ig;
169+
if(this->gg[ig] < 1e-8)
170+
{
171+
this->ig_gge0 = ig;
172+
}
140173
}
141174
return;
142175
}
@@ -148,10 +181,13 @@ void PW_Basis::collect_local_pw()
148181
///
149182
void PW_Basis::collect_uniqgg()
150183
{
151-
if(this->npw <= 0) return;
184+
if(this->npw <= 0)
185+
{
186+
return;
187+
}
152188
this->ig_gge0 = -1;
153189
delete[] this->ig2igg; this->ig2igg = new int [this->npw];
154-
//add by A.s 202406
190+
155191
int *sortindex = new int [this->npw];//Reconstruct the mapping of the plane wave index ig according to the energy size of the plane waves
156192
double *tmpgg = new double [this->npw];//Ranking the plane waves by energy size while ensuring that the same energy is preserved for each wave to correspond
157193
double *tmpgg2 = new double [this->npw];//ranking the plane waves by energy size and removing the duplicates
@@ -164,14 +200,26 @@ void PW_Basis::collect_uniqgg()
164200
int ixy = this->is2fftixy[is];
165201
int ix = ixy / this->fftny;
166202
int iy = ixy % this->fftny;
167-
if (ix >= int(this->nx/2) + 1) ix -= this->nx;
168-
if (iy >= int(this->ny/2) + 1) iy -= this->ny;
169-
if (iz >= int(this->nz/2) + 1) iz -= this->nz;
203+
if (ix >= int(this->nx/2) + 1)
204+
{
205+
ix -= this->nx;
206+
}
207+
if (iy >= int(this->ny/2) + 1)
208+
{
209+
iy -= this->ny;
210+
}
211+
if (iz >= int(this->nz/2) + 1)
212+
{
213+
iz -= this->nz;
214+
}
170215
f.x = ix;
171216
f.y = iy;
172217
f.z = iz;
173218
tmpgg[ig] = f * (this->GGT * f);
174-
if(tmpgg[ig] < 1e-8) this->ig_gge0 = ig;
219+
if(tmpgg[ig] < 1e-8)
220+
{
221+
this->ig_gge0 = ig;
222+
}
175223
}
176224

177225
ModuleBase::GlobalFunc::ZEROS(sortindex, this->npw);
@@ -215,7 +263,10 @@ void PW_Basis::collect_uniqgg()
215263
void PW_Basis::getfftixy2is(int * fftixy2is) const
216264
{
217265
//Note: please assert when is1 >= is2, fftixy2is[is1] >= fftixy2is[is2]!
218-
for(int ixy = 0 ; ixy < this->fftnxy ; ++ixy) fftixy2is[ixy] = -1;
266+
for(int ixy = 0 ; ixy < this->fftnxy ; ++ixy)
267+
{
268+
fftixy2is[ixy] = -1;
269+
}
219270
int ixy = 0;
220271
for(int is = 0; is < this->nst; ++is)
221272
{

source/module_basis/module_pw/pw_basis.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,11 @@ class PW_Basis
237237
int nx=0, ny=0, nz=0, nxyz=0, nxy=0; // Gamma_only: fftny = int(ny/2)-1 , others: fftny = ny
238238
int liy=0, riy=0;// liy: the left edge of the pw ball; riy: the right edge of the pw ball in the y direction
239239
int lix=0, rix=0;// lix: the left edge of the pw ball; rix: the right edge of the pw ball in the x direction
240-
bool xprime = true; // true: when do recip2real, x-fft will be done last and when doing real2recip, x-fft will be done first; false: y-fft
241-
// For gamma_only, true: we use half x; false: we use half y
240+
bool xprime = true; // true: when do recip2real, x-fft will be done last and when doing real2recip, x-fft will be
241+
// done first; false: y-fft For gamma_only, true: we use half x; false: we use half y
242242
int ng_xeq0 = 0; //only used when xprime = true, number of g whose gx = 0
243-
int nmaxgr=0; // Gamma_only: max between npw and (nrxx+1)/2, others: max between npw and nrxx
244-
// Thus complex<double>[nmaxgr] is able to contain either reciprocal or real data
243+
int nmaxgr = 0; // Gamma_only: max between npw and (nrxx+1)/2, others: max between npw and nrxx
244+
// Thus complex<double>[nmaxgr] is able to contain either reciprocal or real data
245245
// FFT ft;
246246
FFT_Bundle fft_bundle;
247247
//The position of pointer in and out can be equal(in-place transform) or different(out-of-place transform).

source/module_basis/module_pw/test_serial/pw_basis_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ TEST_F(PWBasisTEST,GetStartGR)
278278
EXPECT_EQ(pwb.nrxx,8000);
279279
EXPECT_EQ(pwb.nxy,400);
280280
EXPECT_EQ(pwb.nplane,20);
281-
EXPECT_EQ(pwb.nmaxgr,8000);
281+
EXPECT_EQ(pwb.nmaxgr,4000);
282282
EXPECT_EQ(pwb.numg[0],3120);
283283
EXPECT_EQ(pwb.numr[0],3120);
284284
EXPECT_EQ(pwb.startg[0],0);

source/module_io/unk_overlap_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ std::complex<double> unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw,
7575
std::complex<double>* phase = new std::complex<double>[rhopw->nmaxgr];
7676

7777
// get the phase value in realspace
78-
for (int ig = 0; ig < rhopw->nmaxgr; ig++)
78+
for (int ig = 0; ig < rhopw->npw; ig++)
7979
{
8080
ModuleBase::Vector3<double> delta_G = rhopw->gdirect[ig] - G;
8181
if (delta_G.norm2() < 1e-10) // rhopw->gdirect[ig] == G
@@ -89,7 +89,7 @@ std::complex<double> unkOverlap_pw::unkdotp_G0(const ModulePW::PW_Basis* rhopw,
8989
rhopw->recip2real(phase, phase);
9090
wfcpw->recip2real(&evc[0](ik_L, iband_L, 0), psi_r, ik_L);
9191

92-
for (int ir = 0; ir < rhopw->nmaxgr; ir++)
92+
for (int ir = 0; ir < rhopw->nrxx; ir++)
9393
{
9494
psi_r[ir] = psi_r[ir] * phase[ir];
9595
}

tests/integrate/tools/catch_properties.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ if ! test -z "$out_dm" && [ $out_dm == 1 ]; then
401401
fi
402402

403403
if ! test -z "$out_mul" && [ $out_mul == 1 ]; then
404-
python3 ../tools/CompareFile.py mulliken.txt.ref OUT.autotest/mulliken.txt 6
404+
python3 ../tools/CompareFile.py mulliken.txt.ref OUT.autotest/mulliken.txt 4
405405
echo "Compare_mulliken_pass $?" >>$1
406406
fi
407407

0 commit comments

Comments
 (0)