Skip to content

Commit f0d37d5

Browse files
committed
fixing formatting
1 parent c3b61c4 commit f0d37d5

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

src/pyclassify/parallel_tridiag_eigen.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -316,22 +316,20 @@ def parallel_tridiag_eigen(
316316
# D, v_vec, beta, tol_factor
317317
# )
318318

319-
320-
321-
D_keep=np.array(D_keep)
319+
D_keep = np.array(D_keep)
322320

323321
reduced_dim = len(D_keep)
324322

325323
if D_keep.size > 0:
326324
idx = np.argsort(D_keep)
327325
idx_inv = np.arange(0, reduced_dim)
328326
idx_inv = idx_inv[idx]
329-
327+
330328
# T= np.diag(D_keep) + beta * np.outer(v_keep, v_keep)
331329
# lam , _ = np.linalg.eigh(T)
332330

333331
lam, changing_position, delta = secular_solver_cxx(
334-
beta, D_keep[idx], v_keep[idx] , np.arange(reduced_dim)
332+
beta, D_keep[idx], v_keep[idx], np.arange(reduced_dim)
335333
)
336334
lam = np.array(lam)
337335
delta = np.array(delta)
@@ -366,18 +364,17 @@ def parallel_tridiag_eigen(
366364
D_keep = comm.bcast(D_keep, root=0)
367365
v_keep = comm.bcast(v_keep, root=0)
368366
my_count = counts[rank]
369-
type_lam=comm.bcast(lam.dtype, root=0)
370-
371-
lam_buffer=np.empty(my_count, dtype=type_lam)
367+
type_lam = comm.bcast(lam.dtype, root=0)
372368

373-
P=comm.bcast(P, root=0)
374-
D_size=comm.bcast(D_size)
375-
changing_position=comm.bcast(changing_position, root=0)
376-
delta=comm.bcast(delta, root=0)
377-
idx_inv=comm.bcast(idx_inv, root=0)
378-
n1=comm.bcast(n1, root=0)
379-
reduced_dim=comm.bcast(reduced_dim, root=0)
369+
lam_buffer = np.empty(my_count, dtype=type_lam)
380370

371+
P = comm.bcast(P, root=0)
372+
D_size = comm.bcast(D_size)
373+
changing_position = comm.bcast(changing_position, root=0)
374+
delta = comm.bcast(delta, root=0)
375+
idx_inv = comm.bcast(idx_inv, root=0)
376+
n1 = comm.bcast(n1, root=0)
377+
reduced_dim = comm.bcast(reduced_dim, root=0)
381378

382379
# map numpy dtype → MPI datatype
383380
if lam.dtype == np.float64:
@@ -398,7 +395,6 @@ def parallel_tridiag_eigen(
398395
root=0,
399396
)
400397

401-
402398
initial_point = displs[rank]
403399

404400
for k_rel in range(lam_buffer.size):
@@ -543,8 +539,6 @@ def parallel_tridiag_eigen(
543539
return final_eig_val, final_eig_vecs
544540

545541

546-
547-
548542
def parallel_eigen(
549543
main_diag, off_diag, tol_QR=1e-15, max_iterQR=5000, tol_deflation=1e-15
550544
):
@@ -578,7 +572,7 @@ def parallel_eigen(
578572
# main_diag = np.ones(n, dtype=np.float64) * 2.0
579573
# off_diag = np.ones(n - 1, dtype=np.float64) *1.0
580574
main_diag = (np.random.rand(n) * 2).astype(np.float64)
581-
off_diag = (np.random.rand(n - 1) *1).astype(np.float64)
575+
off_diag = (np.random.rand(n - 1) * 1).astype(np.float64)
582576
# eig = np.arange(1, n + 1)
583577
# A = np.diag(eig)
584578
# U = scipy.stats.ortho_group.rvs(n)

0 commit comments

Comments
 (0)