Skip to content

Commit 30662f0

Browse files
committed
Restructuring the parallel_tridiag_eigen.py script
1 parent 49e5b0b commit 30662f0

File tree

1 file changed

+138
-157
lines changed

1 file changed

+138
-157
lines changed

src/pyclassify/parallel_tridiag_eigen.py

Lines changed: 138 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -29,139 +29,6 @@ def check_column_directions(A, B):
2929
B[:, i] = -B[:, i]
3030

3131

32-
@profile
33-
def deflate_eigenpairs(D, v, beta, tol_factor=1e-12):
34-
"""
35-
Applying the deflation step to the divide and conquer algorithm to reduce the size of
36-
the system to be solved and find the trivial eigenvalues and eigenvectors.
37-
Notice that we cannot use jit since we work with scipy sparse matrices.
38-
Input:
39-
-D: element on the diagonal.
40-
-v: rank one update vector
41-
-beta: scalar value that multiplies the rank one update
42-
43-
Output:
44-
-deflated_eigvals: Vector containing the trivial eigenvalue
45-
-deflated_eigvecs: Matrix containing the trivial eigenvectors
46-
-D_keep: element on the diagonal after the deflation.
47-
-v_keep: rank one update vector after the deflation
48-
-P_3 @ P_2 @ P: Permutation matrix
49-
"""
50-
51-
norm_T = np.linalg.norm(
52-
np.diag(D) + beta * np.outer(v, v)
53-
) # Normalize tolerance to matrix size
54-
# abs_v=np.abs(v)
55-
# norm_T=0.5*(np.max(np.abs(D)) + beta * np.max(abs_v)*np.sum(abs_v))
56-
tol = tol_factor * norm_T
57-
keep_indices = []
58-
deflated_eigvals = np.zeros_like(D)
59-
deflated_eigvecs = []
60-
deflated_indices = []
61-
reduced_dimension = len(D)
62-
63-
# Zero component deflation
64-
e_vec = np.zeros(len(D))
65-
j = 0
66-
for i in range(len(D)):
67-
if abs(v[i]) < tol:
68-
deflated_eigvals[j] = D[i]
69-
e_vec[i] = 1.0 # Standard basis vector
70-
deflated_eigvecs.append(e_vec.copy())
71-
deflated_indices.append(i)
72-
e_vec[i] = 0.0
73-
j += 1
74-
else:
75-
keep_indices.append(i)
76-
77-
new_order = keep_indices + deflated_indices
78-
reduced_dimension = len(keep_indices)
79-
80-
# Create permutation matrix P (use sparse)
81-
n = len(D)
82-
P = sp.lil_array((n, n)) # Use sparse format
83-
P_2 = sp.eye(n, format="csr") # Efficient multiplication format
84-
P_3 = sp.lil_array((n, n)) # Permutation matrix
85-
86-
for new_pos, old_pos in enumerate(new_order):
87-
P[new_pos, old_pos] = 1 # Assign 1s to move elements accordingly
88-
89-
P = P.tocsr() # Convert to CSR format for fast multiplication
90-
91-
D_keep = D[keep_indices]
92-
v_keep = v[keep_indices]
93-
94-
to_check = set(
95-
np.arange(reduced_dimension, dtype=np.int64)
96-
) # Use set for fast lookup
97-
rotation_matrix = []
98-
vec_idx_list = [] # Use list instead of np.append()
99-
100-
to_check_copy = list(to_check) # Convert to list for iteration
101-
102-
for i in to_check_copy[:-1]:
103-
if i not in to_check:
104-
continue # Skip if the index was removed
105-
106-
# Find duplicates in a vectorized way
107-
idx_duplicate_vec = np.where(np.abs(D_keep[i + 1 :] - D_keep[i]) < tol)[0]
108-
if len(idx_duplicate_vec):
109-
idx_duplicate_vec += i + 1 # Adjust indices
110-
111-
for idx_duplicate in idx_duplicate_vec:
112-
to_check.discard(idx_duplicate) # O(1) removal instead of np.delete()
113-
114-
# Compute Givens rotation parameters
115-
r = np.hypot(
116-
v_keep[i], v_keep[idx_duplicate]
117-
) # More stable than sqrt(x^2 + y^2)
118-
c = v_keep[i] / r
119-
s = -v_keep[idx_duplicate] / r
120-
121-
v_keep[i] = r
122-
v_keep[idx_duplicate] = 0
123-
124-
# Store transformation
125-
rotation_matrix.append((i, idx_duplicate, c, s))
126-
deflated_eigvals[j] = D_keep[i]
127-
j += 1
128-
129-
# Efficient eigenvector computation
130-
eig_vec_local = np.zeros(n)
131-
eig_vec_local[idx_duplicate] = c
132-
eig_vec_local[i] = s
133-
deflated_eigvecs.append(P.T @ eig_vec_local)
134-
135-
vec_idx_list.append(
136-
idx_duplicate
137-
) # Use list instead of slow np.append()
138-
139-
new_order = np.concatenate(
140-
(list(to_check), vec_idx_list)
141-
) # Efficient concatenation
142-
new_order = new_order.astype(int)
143-
deflated_eigvals = deflated_eigvals[:j]
144-
145-
# Apply Givens rotations
146-
for i, j, c, s in rotation_matrix:
147-
G = sp.eye(n, n, format="csr") # Sparse identity matrix
148-
G[i, i] = G[j, j] = c
149-
G[i, j] = -s
150-
G[j, i] = s
151-
P_2 = P_2 @ G # Sparse multiplication
152-
153-
for new_pos, old_pos in enumerate(new_order):
154-
P_3[new_pos, old_pos] = 1 # Assign 1s
155-
156-
P_3 = P_3.tocsr()
157-
158-
to_check = [i for i in to_check]
159-
reduced_dimension = len(to_check)
160-
D_keep = D_keep[to_check]
161-
v_keep = v_keep[to_check]
162-
163-
return deflated_eigvals, np.array(deflated_eigvecs), D_keep, v_keep, P_3 @ P_2 @ P
164-
16532

16633
def find_interval_extreme(total_dimension, n_processor):
16734
"""
@@ -293,12 +160,6 @@ def parallel_tridiag_eigen(
293160
eigvals_right = subcomm.bcast(eigvals_right, root=0)
294161
eigvecs_right = subcomm.bcast(eigvecs_right, root=0)
295162

296-
# if rank == 0:
297-
# eigvals_right = comm.recv(source=left_size, tag=77)
298-
# eigvecs_right = comm.recv(source=left_size, tag=78)
299-
# elif rank == left_size:
300-
# comm.send(eigvals_right, dest=0, tag=77)
301-
# comm.send(eigvecs_right, dest=0, tag=78)
302163

303164
if rank == 0:
304165

@@ -312,10 +173,6 @@ def parallel_tridiag_eigen(
312173
D, v_vec, beta, tol_factor
313174
)
314175

315-
# Pdeflated_eigvals, Pdeflated_eigvecs, PD_keep, Pv_keep, PP = deflate_eigenpairs(
316-
# D, v_vec, beta, tol_factor
317-
# )
318-
319176
D_keep = np.array(D_keep)
320177

321178
reduced_dim = len(D_keep)
@@ -325,9 +182,6 @@ def parallel_tridiag_eigen(
325182
idx_inv = np.arange(0, reduced_dim)
326183
idx_inv = idx_inv[idx]
327184

328-
# T= np.diag(D_keep) + beta * np.outer(v_keep, v_keep)
329-
# lam , _ = np.linalg.eigh(T)
330-
331185
lam, changing_position, delta = secular_solver_cxx(
332186
beta, D_keep[idx], v_keep[idx], np.arange(reduced_dim)
333187
)
@@ -492,14 +346,8 @@ def parallel_tridiag_eigen(
492346
local_count = eig_val.size # or however many elements you’ll send
493347

494348
# 2) Everyone exchanges counts via allgather:
495-
# this returns a Python list of length `size` on every rank
496349
recvcounts = comm.allgather(local_count)
497350

498-
# # 1) Gather all the counts to rank 0
499-
# counts = comm.gather(local_count, root=0)
500-
501-
# # 2) Broadcast that list from rank 0 back to everyone
502-
# recvcounts = comm.bcast(counts, root=0)
503351

504352
final_eig_val = np.empty(D_size, dtype=eig_val.dtype)
505353

@@ -527,15 +375,12 @@ def parallel_tridiag_eigen(
527375
)
528376

529377
# 5) Reshape on every rank (or just on rank 0 if you prefer)
530-
# total_pairs == sum(recvcounts)
531378
final_eig_vecs = flat_all.reshape(D_size, D_size)
532379
final_eig_vecs = final_eig_vecs.T
533380
index_sort = np.argsort(final_eig_val)
534381
final_eig_vecs = final_eig_vecs[:, index_sort]
535382
final_eig_val = final_eig_val[index_sort]
536-
# if rank==0:
537-
# print(final_eig_val)
538-
# print(final_eig_vecs)
383+
539384
return final_eig_val, final_eig_vecs
540385

541386

@@ -556,7 +401,143 @@ def parallel_eigen(
556401
)
557402
return eigvals, eigvecs
558403

559-
404+
405+
406+
# ORIGINAL DEFLATION IMPLEMENTATION
407+
# @profile
408+
# def deflate_eigenpairs(D, v, beta, tol_factor=1e-12):
409+
# """
410+
# Applying the deflation step to the divide and conquer algorithm to reduce the size of
411+
# the system to be solved and find the trivial eigenvalues and eigenvectors.
412+
# Notice that we cannot use jit since we work with scipy sparse matrices.
413+
# Input:
414+
# -D: element on the diagonal.
415+
# -v: rank one update vector
416+
# -beta: scalar value that multiplies the rank one update
417+
418+
# Output:
419+
# -deflated_eigvals: Vector containing the trivial eigenvalue
420+
# -deflated_eigvecs: Matrix containing the trivial eigenvectors
421+
# -D_keep: element on the diagonal after the deflation.
422+
# -v_keep: rank one update vector after the deflation
423+
# -P_3 @ P_2 @ P: Permutation matrix
424+
# """
425+
426+
# norm_T = np.linalg.norm(
427+
# np.diag(D) + beta * np.outer(v, v)
428+
# ) # Normalize tolerance to matrix size
429+
# # abs_v=np.abs(v)
430+
# # norm_T=0.5*(np.max(np.abs(D)) + beta * np.max(abs_v)*np.sum(abs_v))
431+
# tol = tol_factor * norm_T
432+
# keep_indices = []
433+
# deflated_eigvals = np.zeros_like(D)
434+
# deflated_eigvecs = []
435+
# deflated_indices = []
436+
# reduced_dimension = len(D)
437+
438+
# # Zero component deflation
439+
# e_vec = np.zeros(len(D))
440+
# j = 0
441+
# for i in range(len(D)):
442+
# if abs(v[i]) < tol:
443+
# deflated_eigvals[j] = D[i]
444+
# e_vec[i] = 1.0 # Standard basis vector
445+
# deflated_eigvecs.append(e_vec.copy())
446+
# deflated_indices.append(i)
447+
# e_vec[i] = 0.0
448+
# j += 1
449+
# else:
450+
# keep_indices.append(i)
451+
452+
# new_order = keep_indices + deflated_indices
453+
# reduced_dimension = len(keep_indices)
454+
455+
# # Create permutation matrix P (use sparse)
456+
# n = len(D)
457+
# P = sp.lil_array((n, n)) # Use sparse format
458+
# P_2 = sp.eye(n, format="csr") # Efficient multiplication format
459+
# P_3 = sp.lil_array((n, n)) # Permutation matrix
460+
461+
# for new_pos, old_pos in enumerate(new_order):
462+
# P[new_pos, old_pos] = 1 # Assign 1s to move elements accordingly
463+
464+
# P = P.tocsr() # Convert to CSR format for fast multiplication
465+
466+
# D_keep = D[keep_indices]
467+
# v_keep = v[keep_indices]
468+
469+
# to_check = set(
470+
# np.arange(reduced_dimension, dtype=np.int64)
471+
# ) # Use set for fast lookup
472+
# rotation_matrix = []
473+
# vec_idx_list = [] # Use list instead of np.append()
474+
475+
# to_check_copy = list(to_check) # Convert to list for iteration
476+
477+
# for i in to_check_copy[:-1]:
478+
# if i not in to_check:
479+
# continue # Skip if the index was removed
480+
481+
# # Find duplicates in a vectorized way
482+
# idx_duplicate_vec = np.where(np.abs(D_keep[i + 1 :] - D_keep[i]) < tol)[0]
483+
# if len(idx_duplicate_vec):
484+
# idx_duplicate_vec += i + 1 # Adjust indices
485+
486+
# for idx_duplicate in idx_duplicate_vec:
487+
# to_check.discard(idx_duplicate) # O(1) removal instead of np.delete()
488+
489+
# # Compute Givens rotation parameters
490+
# r = np.hypot(
491+
# v_keep[i], v_keep[idx_duplicate]
492+
# ) # More stable than sqrt(x^2 + y^2)
493+
# c = v_keep[i] / r
494+
# s = -v_keep[idx_duplicate] / r
495+
496+
# v_keep[i] = r
497+
# v_keep[idx_duplicate] = 0
498+
499+
# # Store transformation
500+
# rotation_matrix.append((i, idx_duplicate, c, s))
501+
# deflated_eigvals[j] = D_keep[i]
502+
# j += 1
503+
504+
# # Efficient eigenvector computation
505+
# eig_vec_local = np.zeros(n)
506+
# eig_vec_local[idx_duplicate] = c
507+
# eig_vec_local[i] = s
508+
# deflated_eigvecs.append(P.T @ eig_vec_local)
509+
510+
# vec_idx_list.append(
511+
# idx_duplicate
512+
# ) # Use list instead of slow np.append()
513+
514+
# new_order = np.concatenate(
515+
# (list(to_check), vec_idx_list)
516+
# ) # Efficient concatenation
517+
# new_order = new_order.astype(int)
518+
# deflated_eigvals = deflated_eigvals[:j]
519+
520+
# # Apply Givens rotations
521+
# for i, j, c, s in rotation_matrix:
522+
# G = sp.eye(n, n, format="csr") # Sparse identity matrix
523+
# G[i, i] = G[j, j] = c
524+
# G[i, j] = -s
525+
# G[j, i] = s
526+
# P_2 = P_2 @ G # Sparse multiplication
527+
528+
# for new_pos, old_pos in enumerate(new_order):
529+
# P_3[new_pos, old_pos] = 1 # Assign 1s
530+
531+
# P_3 = P_3.tocsr()
532+
533+
# to_check = [i for i in to_check]
534+
# reduced_dimension = len(to_check)
535+
# D_keep = D_keep[to_check]
536+
# v_keep = v_keep[to_check]
537+
538+
# return deflated_eigvals, np.array(deflated_eigvecs), D_keep, v_keep, P_3 @ P_2 @ P
539+
540+
# This portion of the script was used during the intense debugging phase.
560541
if __name__ == "__main__":
561542
comm = MPI.COMM_WORLD
562543
rank = comm.Get_rank()

0 commit comments

Comments
 (0)