Skip to content

Commit 588df88

Browse files
Apply Black code formatting
1 parent 3f0e139 commit 588df88

File tree

5 files changed

+32
-16
lines changed

5 files changed

+32
-16
lines changed

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ def build_extension(self, ext):
2323
self.spawn(["cmake", "--build", build_temp, "--target", "QR_cpp"])
2424

2525
# Dynamically find the compiled shared library
26-
matches = glob.glob(os.path.join(ext.sourcedir, "src", "pyclassify", "QR_cpp*.so"))
26+
matches = glob.glob(
27+
os.path.join(ext.sourcedir, "src", "pyclassify", "QR_cpp*.so")
28+
)
2729
if not matches:
2830
raise RuntimeError(
2931
"Could not find compiled QR_cpp shared library in expected location."

src/pyclassify/eigenvalues.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
max_iteration_warning,
1515
)
1616

17-
#from parallel_tridiag_eigen import parallel_eigen
17+
# from parallel_tridiag_eigen import parallel_eigen
18+
1819

1920
def eigenvalues_np(A, symmetric=True):
2021
"""
@@ -221,7 +222,7 @@ def __init__(self, A: np.ndarray, max_iter=5000, tol=1e-8, tol_deflation=1e-12):
221222
self.diag = None
222223
self.off_diag = None
223224
self.Q = None
224-
self.tol_deflation=tol_deflation
225+
self.tol_deflation = tol_deflation
225226

226227
# @jit(nopython=True, parallel=True) # removed because is it not compatible with C++ functions!
227228
def Lanczos_PRO(self, A=None, q=None, m=None, tol=np.sqrt(np.finfo(float).eps)):
@@ -383,7 +384,6 @@ def eig(self, diag=None, off_diag=None):
383384
Q_triangular = np.array(Q_triangular)
384385
return np.array(eig), Q_triangular @ self.Q.T
385386

386-
387387
# def parallel_tridiagMatrix_eig_solver(self, diag=None, off_diag=None):
388388
# if diag is None and off_diag is None:
389389
# if self.diag is None:
@@ -394,9 +394,10 @@ def eig(self, diag=None, off_diag=None):
394394
# off_diag = self.off_diag
395395
# if len(diag) != (len(off_diag) + 1):
396396
# raise ValueError("Mismatch between diagonal and off diagonal size")
397-
397+
398398
# return(parallel_eigen(self.diag, self.off_diag, self.tol_QR, self.max_iterQR, self.tol_deflation))
399399

400+
400401
# def power_method_cp(A, max_iter=500, tol=1e-4, x=None):
401402
# """
402403
# Compute the dominant eigenvalue of a square matrix using the power method.

src/pyclassify/parallel_tridiag_eigen.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,17 @@ def deflate_eigenpairs(D, v, beta, tol_factor=1e-12):
139139

140140

141141
@profile
142-
def parallel_tridiag_eigen(diag, off, comm=None, tol_factor=1e-16, min_size=1, depth=0, profiler=None, tol_QR=1e-8, max_iterQR=5000):
142+
def parallel_tridiag_eigen(
143+
diag,
144+
off,
145+
comm=None,
146+
tol_factor=1e-16,
147+
min_size=1,
148+
depth=0,
149+
profiler=None,
150+
tol_QR=1e-8,
151+
max_iterQR=5000,
152+
):
143153
"""
144154
Computes eigenvalues and eigenvectors of a symmetric tridiagonal matrix.
145155
Input:
@@ -183,10 +193,8 @@ def parallel_tridiag_eigen(diag, off, comm=None, tol_factor=1e-16, min_size=1,
183193

184194
if n <= min_size or size == 1:
185195
eigvals, eigvecs = QR_algorithm(diag, off, tol_QR, max_iterQR)
186-
eigvecs=np.array(eigvecs)
187-
eigvals=np.array(eigvals)
188-
189-
196+
eigvecs = np.array(eigvecs)
197+
eigvals = np.array(eigvals)
190198

191199
# profiler[-1].disable_by_count()
192200
# with open(prof_filename, "w") as f:
@@ -384,9 +392,16 @@ def parallel_eigen(main_diag, off_diag, tol_QR, max_iterQR, tol_deflation):
384392
comm = MPI.COMM_WORLD
385393
main_diag = comm.bcast(main_diag, root=0)
386394
off_diag = comm.bcast(off_diag, root=0)
387-
eigvals, eigvecs = parallel_tridiag_eigen(main_diag, off_diag, comm=comm, min_size=1, tol_factor=tol_deflation, tol_QR=tol_QR, max_iterQR=max_iterQR)
395+
eigvals, eigvecs = parallel_tridiag_eigen(
396+
main_diag,
397+
off_diag,
398+
comm=comm,
399+
min_size=1,
400+
tol_factor=tol_deflation,
401+
tol_QR=tol_QR,
402+
max_iterQR=max_iterQR,
403+
)
388404
return eigvals, eigvecs
389-
390405

391406

392407
if __name__ == "__main__":

test/test_eigensolvers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from pyclassify.utils import make_symmetric
1919

20+
2021
@pytest.fixture(autouse=True)
2122
def set_random_seed():
2223
seed = 1422
@@ -115,8 +116,6 @@ def test_EigenSolver(size):
115116
_ = eigensolver.compute_eigenval(diag=np.arange(2), off_diag=np.arange(49))
116117

117118

118-
119-
120119
# @pytest.mark.parametrize("size", sizes)
121120
# @pytest.mark.parametrize("density", densities)
122121
# def test_cupy(size, density):

test/test_zero_finder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def test_psi_s(rho, d, v, i):
3838
assert psi2(lambda_guess) * rho >= 0, "Error. Inconsistent with the theory."
3939

4040

41-
4241
@pytest.mark.parametrize("d", d_s)
4342
@pytest.mark.parametrize("rho", rho_s)
4443
@pytest.mark.parametrize("v", v_s)
@@ -61,6 +60,7 @@ def test_compute_eigenvalues(rho, d, v):
6160
np.abs(computed_eigs[i] - exact_eigs[i]) < 1e-8
6261
), "Error. The eigenvalues were not computed correctly."
6362

63+
6464
@pytest.mark.parametrize("d", d_s)
6565
@pytest.mark.parametrize("rho", rho_s)
6666
@pytest.mark.parametrize("v", v_s)
@@ -82,4 +82,3 @@ def test_compute_eigenvalues(rho, d, v):
8282
assert (
8383
np.abs(computed_eigs[i] - exact_eigs[i]) < 1e-8
8484
), "Error. The eigenvalues were not computed correctly."
85-

0 commit comments

Comments
 (0)