Skip to content

Commit 0161bdb

Browse files
committed
Fixing bugs in QR_bindings and parallel tridiag. Added a new script to solve the eigenvalue problem
1 parent 97df06b commit 0161bdb

File tree

3 files changed

+425
-335
lines changed

3 files changed

+425
-335
lines changed

src/pyclassify/QR_bindings.cpp

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,11 @@
44
#include <array>
55
#include <cmath>
66
#include <stdexcept>
7-
8-
#include "pybind11/pybind11.h"
9-
#include "pybind11/stl.h"
10-
namespace py=pybind11;
117

128

9+
1310
std::pair<std::vector<double>, std::vector<std::vector<double>> >
14-
QR_algorithm(std::vector<double> diag, std::vector<double> off_diag, const double tol=1e-8, const unsigned int max_iter=5000){
11+
QR_algorithm(std::vector<double> diag, std::vector<double> off_diag, const double toll=1e-8, const unsigned int max_iter=5000){
1512

1613
if(diag.size() != (off_diag.size()+1)){
1714
std::invalid_argument("The dimension of the diagonal and off-diagonal vector are not compatible");
@@ -38,7 +35,7 @@ std::pair<std::vector<double>, std::vector<std::vector<double>> >
3835
double tmp=0;
3936
double x=0, y=0;
4037
unsigned int m=n-1;
41-
double tol_equivalence=1e-10;
38+
double toll_equivalence=1e-10;
4239
double w=0, z=0;
4340

4441

@@ -50,7 +47,7 @@ std::pair<std::vector<double>, std::vector<std::vector<double>> >
5047
b_m_1=off_diag[m-1];
5148
d=(diag[m-1]-a_m)*0.5;
5249

53-
if(std::abs(d)<tol_equivalence){
50+
if(std::abs(d)<toll_equivalence){
5451
mu=diag[m]-std::abs(b_m_1);
5552
} else{
5653
mu= a_m - b_m_1*b_m_1/( d*( 1+sqrt(d*d+b_m_1*b_m_1)/std::abs(d) ) );
@@ -86,7 +83,7 @@ std::pair<std::vector<double>, std::vector<std::vector<double>> >
8683
}
8784

8885
}else{
89-
if(std::abs(d)<tol_equivalence){
86+
if(std::abs(d)<toll_equivalence){
9087
if (off_diag[0]*d>0)
9188
{
9289
c=std::sqrt(2)/2;
@@ -134,11 +131,11 @@ std::pair<std::vector<double>, std::vector<std::vector<double>> >
134131

135132

136133
unsigned j, k;
137-
for(unsigned int i=0; i<n-1; i++){
134+
for(unsigned int i=0; i<m; i++){
138135
c=Matrix_trigonometric[i][0];
139136
s=Matrix_trigonometric[i][1];
140137

141-
for(j=0; j<n;j=j+5){
138+
for(j=0; (j+5)<n;j=j+5){
142139
k=n*i+j;
143140
tmp=Q[k];
144141
Q[k]=tmp*c-Q[k+n]*s;
@@ -168,8 +165,7 @@ std::pair<std::vector<double>, std::vector<std::vector<double>> >
168165
}
169166

170167
iter++;
171-
if (iter >= max_iter) {std::cout << "Max iteration has been reached. Maybe tol was too low?" << std::endl;}
172-
if ( std::abs(off_diag[m-1]) < tol*( std::abs(diag[m]) + std::abs(diag[m-1]) ) )
168+
if ( std::abs(off_diag[m-1]) < toll*( std::abs(diag[m]) + std::abs(diag[m-1]) ) )
173169
{
174170
--m;
175171
}
@@ -181,8 +177,9 @@ std::pair<std::vector<double>, std::vector<std::vector<double>> >
181177

182178
std::vector<std::vector<double>> eig_vec(n,std::vector<double> (n, 0));
183179
//std::cout<<"Iteration: "<<iter<<std::endl;
184-
for(unsigned int i=0; i<n; i++){
185-
for(unsigned j=0; j<n; j++){
180+
181+
for(unsigned j=0; j<n; j++){
182+
for(unsigned int i=0; i<n; i++){
186183
eig_vec[i][j]=Q[i+j*n];
187184
}
188185
}
@@ -193,7 +190,7 @@ std::pair<std::vector<double>, std::vector<std::vector<double>> >
193190
}
194191

195192
std::vector<double>
196-
Eigen_value_calculator(std::vector<double> diag, std::vector<double> off_diag, const double tol=1e-8, const unsigned int max_iter=5000){
193+
Eigen_value_calculator(std::vector<double> diag, std::vector<double> off_diag, const double toll=1e-8, const unsigned int max_iter=5000){
197194

198195
if(diag.size() != (off_diag.size()+1)){
199196
std::invalid_argument("The dimension of the diagonal and off-diagonal vector are not compatible");
@@ -214,7 +211,7 @@ std::vector<double>
214211
double a_m=0, b_m_1=0;
215212
double x=0, y=0;
216213
unsigned int m=n-1;
217-
double tol_equivalence=1e-10;
214+
double toll_equivalence=1e-10;
218215
double w=0, z=0;
219216

220217

@@ -226,7 +223,7 @@ std::vector<double>
226223
b_m_1=off_diag[m-1];
227224
d=(diag[m-1]-a_m)*0.5;
228225

229-
if(std::abs(d)<tol_equivalence){
226+
if(std::abs(d)<toll_equivalence){
230227
mu=diag[m]-std::abs(b_m_1);
231228
} else{
232229
mu= a_m - b_m_1*b_m_1/( d*( 1+sqrt(d*d+b_m_1*b_m_1)/std::abs(d) ) );
@@ -262,7 +259,7 @@ std::vector<double>
262259
}
263260

264261
}else{
265-
if(std::abs(d)<tol_equivalence){
262+
if(std::abs(d)<toll_equivalence){
266263
if (off_diag[0]*d>0)
267264
{
268265
c=std::sqrt(2)/2;
@@ -309,8 +306,7 @@ std::vector<double>
309306

310307

311308
iter++;
312-
if (iter >= max_iter) {std::cout << "Max iteration has been reached. Maybe tol was too low?" << std::endl;}
313-
if ( std::abs(off_diag[m-1]) < tol*( std::abs(diag[m]) + std::abs(diag[m-1]) ) )
309+
if ( std::abs(off_diag[m-1]) < toll*( std::abs(diag[m]) + std::abs(diag[m-1]) ) )
314310
{
315311
--m;
316312
}
@@ -322,14 +318,20 @@ std::vector<double>
322318

323319

324320
return diag;
321+
325322

326323
}
327324

325+
#include "pybind11/pybind11.h"
326+
#include "pybind11/stl.h"
327+
328328

329329

330+
namespace py=pybind11;
331+
330332
PYBIND11_MODULE(QR_cpp, m) {
331333
m.doc() = "Function that computes the eigenvalue and eigenvector"; // Optional module docstring.
332334

333-
m.def("QR_algorithm", &QR_algorithm, py::arg("diag"), py::arg("off_diag"), py::arg("tol")=1e-8, py::arg("max_iter")=5000);
334-
m.def("Eigen_value_calculator", &Eigen_value_calculator, py::arg("diag"), py::arg("off_diag"), py::arg("tol")=1e-8, py::arg("max_iter")=5000);
335-
}
335+
m.def("QR_algorithm", &QR_algorithm, py::arg("diag"), py::arg("off_diag"), py::arg("toll")=1e-8, py::arg("max_iter")=5000);
336+
m.def("Eigen_value_calculator", &Eigen_value_calculator, py::arg("diag"), py::arg("off_diag"), py::arg("toll")=1e-8, py::arg("max_iter")=5000);
337+
}

0 commit comments

Comments
 (0)