Skip to content

Commit 89da792

Browse files
committed
improved QR algorithm with shift
1 parent 0d28eba commit 89da792

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

src/pyclassify/QR_alg_Wilkison.cpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#include <utility>
2+
#include <vector>
3+
#include <cmath>
4+
#include <iostream>
5+
#include <algorithm>
6+
#include <array>
7+
8+
9+
//std::pair<std::vector<double>, std::vector<std::vector<double>> >
10+
void QR_algorithm(std::vector<double> diag, std::vector<double> off_diag, const double toll=1e-8, const unsigned int max_iter=1000){
11+
12+
const unsigned int n = diag.size();
13+
14+
std::vector<std::vector<double>> Q(n, std::vector<double>(n, 0));
15+
16+
for(unsigned int i = 1; i < n-1; i++){
17+
Q[i][i] = 1;
18+
}
19+
Q[0][0] = 1;
20+
Q[n-1][n-1] = 1;
21+
22+
23+
std::vector<std::array<double, 2>> Matrix_trigonometric(n-1, {0, 0});
24+
25+
unsigned int iter = 0;
26+
std::vector<double> eigenvalues_old{diag};
27+
28+
double r=0, c=0, s=0;
29+
double d=0, mu; // mu: Wilkinson shift
30+
double a_m=0, b_m_1=0;
31+
double tmp=0;
32+
double x=0, y=0;
33+
int m=n-1;
34+
double toll_equivalence=1e-10;
35+
double w=0, z=0;
36+
37+
while (iter<max_iter && m>0)
38+
{
39+
// prefetching most used value to avoid call overhead
40+
a_m=diag[m];
41+
b_m_1=off_diag[m-1];
42+
d=(diag[m-1]-a_m)*0.5;
43+
44+
if(std::abs(d)<toll_equivalence){
45+
mu=diag[m]-std::abs(b_m_1);
46+
} else{
47+
mu= a_m - b_m_1*b_m_1/( d*( 1+sqrt(d*d+b_m_1*b_m_1)/std::abs(d) ) );
48+
}
49+
50+
x=diag[0]-mu;
51+
y=off_diag[0];
52+
53+
for(unsigned int i=0; i<m; i++){
54+
if (m>1)
55+
{
56+
r=std::sqrt(x*x+y*y);
57+
c=x/r;
58+
s=-y/r;
59+
}else{
60+
if(std::abs(d)<toll_equivalence){
61+
if (off_diag[0]*d>0)
62+
{
63+
c=std::sqrt(2)/2;
64+
s=-std::sqrt(2)/2;
65+
} else{
66+
c=s=std::sqrt(2)/2;
67+
}
68+
69+
} else{
70+
double x_0=0, x_new=0, b_2=off_diag[0];
71+
if(off_diag[0]*d>0){
72+
x_0=-3.14/4;
73+
} else{
74+
x_0=3.14/4;
75+
}
76+
double err_rel=1;
77+
unsigned int iter_newton=0;
78+
while (err_rel>1e-10 && iter_newton<1000)
79+
{
80+
x_new=x_0+std::cos(x_0)*std::cos(x_0)*(std::tan(x_0) + b_2/d);
81+
err_rel=std::abs((x_new-x_0)/x_new);
82+
x_0=x_new;
83+
++iter_newton;
84+
}
85+
c=std::cos(x_new/2);
86+
s=std::sin(x_new/2);
87+
x=x+mu;
88+
89+
}
90+
91+
}
92+
93+
Matrix_trigonometric[i][0] = c;
94+
Matrix_trigonometric[i][1] = s;
95+
96+
w=c*x-s*y;
97+
d=diag[i]-diag[i+1];
98+
z=(2*c*off_diag[i] +d*s)*s;
99+
diag[i] -= z;
100+
diag[i+1] += z;
101+
off_diag[i]= d*c*s + (c*c-s*s)*off_diag[i];
102+
x=off_diag[i];
103+
if (i>0)
104+
{
105+
off_diag[i-1]=w;
106+
}
107+
108+
if(i<m-1){
109+
y=-s*off_diag[i+1];
110+
off_diag[i+1]=c*off_diag[i+1];
111+
}
112+
113+
}
114+
iter++;
115+
if ( std::abs(off_diag[m-1]) < toll*( std::abs(diag[m]) + std::abs(diag[m-1]) ) )
116+
{
117+
--m;
118+
}
119+
120+
121+
}
122+
123+
124+
}
125+
126+
127+
int main(){
128+
129+
std::vector<double> diag{1, 2, 3, 4, 5}, offdiag(4, 2);
130+
QR_algorithm(diag, offdiag);
131+
return 0;
132+
}

0 commit comments

Comments
 (0)