Skip to content

Commit 23a26ff

Browse files
committed
merged action of parallel_tridiag
2 parents b291022 + b259889 commit 23a26ff

File tree

1 file changed

+98
-103
lines changed

1 file changed

+98
-103
lines changed

src/pyclassify/parallel_tridiag_eigen.py

Lines changed: 98 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from mpi4py import MPI
22
import numpy as np
33
from time import time
4-
from pyclassify.cxx_utils import QR_algorithm, secular_solver_cxx, deflate_eigenpairs_cxx
4+
from pyclassify.cxx_utils import (
5+
QR_algorithm,
6+
secular_solver_cxx,
7+
deflate_eigenpairs_cxx,
8+
)
59
from pyclassify.zero_finder import secular_solver_python as secular_solver
610
from line_profiler import profile, LineProfiler
711
import scipy.sparse as sp
@@ -158,19 +162,22 @@ def deflate_eigenpairs(D, v, beta, tol_factor=1e-12):
158162

159163
return deflated_eigvals, np.array(deflated_eigvecs), D_keep, v_keep, P_3 @ P_2 @ P
160164

165+
161166
def find_interval_extreme(total_dimension, n_processor):
    """
    Compute per-process chunk sizes and offsets for scattering a vector.

    The vector of length ``total_dimension`` is split as evenly as possible
    across ``n_processor`` processes: the first
    ``total_dimension % n_processor`` ranks receive one extra element so the
    whole vector is covered.

    Input:
    -total_dimension: the length of the vector that has to be split
    -n_processor: the number of processes the scattered vector is sent to

    Returns:
    counts: int array of length n_processor, elements per rank
    displs: int array of length n_processor, starting offset of each rank

    Raises:
    ValueError: if n_processor is not a positive integer.
    """
    if n_processor < 1:
        raise ValueError("n_processor must be a positive integer")

    base, rest = divmod(total_dimension, n_processor)

    # Every rank gets `base` entries; the first `rest` ranks get one extra.
    counts = np.full(n_processor, base, dtype=int)
    counts[:rest] += 1

    # Offsets are the running totals shifted right by one (first offset is 0).
    displs = np.concatenate(([0], np.cumsum(counts)[:-1]))

    return counts, displs
@@ -210,7 +217,6 @@ def parallel_tridiag_eigen(
210217
n = len(diag)
211218
prof_filename = f"Profile_folder/profile.rank{current_rank}.depth{depth}.lprof"
212219

213-
214220
if n <= min_size or size == 1:
215221
eigvals, eigvecs = QR_algorithm(diag, off, 1e-16, max_iterQR)
216222
eigvecs = np.array(eigvecs)
@@ -245,8 +251,8 @@ def parallel_tridiag_eigen(
245251
depth=depth + 1,
246252
profiler=profiler,
247253
)
248-
eigvals_right=None
249-
eigvecs_right=None
254+
eigvals_right = None
255+
eigvecs_right = None
250256
else:
251257
eigvals_right, eigvecs_right = parallel_tridiag_eigen(
252258
diag2,
@@ -257,50 +263,49 @@ def parallel_tridiag_eigen(
257263
depth=depth + 1,
258264
profiler=profiler,
259265
)
260-
eigvals_left=None
261-
eigvecs_left=None
262-
266+
eigvals_left = None
267+
eigvecs_left = None
263268

264269
# 1) Identify the two “root” ranks in MPI.COMM_WORLD
265-
left_size = size // 2 if size>1 else 1
266-
root_left = 0
270+
left_size = size // 2 if size > 1 else 1
271+
root_left = 0
267272
root_right = left_size
268-
other_root = root_right if color==0 else root_left
273+
other_root = root_right if color == 0 else root_left
269274

270-
# now exchange between the two roots
275+
# now exchange between the two roots
271276
if subcomm.Get_rank() == 0:
272-
send_data = (eigvals_left, eigvecs_left) \
273-
if color == 0 else (eigvals_right, eigvecs_right)
277+
send_data = (
278+
(eigvals_left, eigvecs_left)
279+
if color == 0
280+
else (eigvals_right, eigvecs_right)
281+
)
274282
recv_data = comm.sendrecv(
275-
send_data, dest=other_root, sendtag=depth,
276-
source=other_root, recvtag=depth
283+
send_data, dest=other_root, sendtag=depth, source=other_root, recvtag=depth
277284
)
278285
# unpack
279286
if color == 0:
280287
eigvals_right, eigvecs_right = recv_data
281288
else:
282-
eigvals_left, eigvecs_left = recv_data
289+
eigvals_left, eigvecs_left = recv_data
283290

284-
eigvals_left = subcomm.bcast(eigvals_left, root=0)
285-
eigvecs_left = subcomm.bcast(eigvecs_left, root=0)
291+
eigvals_left = subcomm.bcast(eigvals_left, root=0)
292+
eigvecs_left = subcomm.bcast(eigvecs_left, root=0)
286293
eigvals_right = subcomm.bcast(eigvals_right, root=0)
287294
eigvecs_right = subcomm.bcast(eigvecs_right, root=0)
288295

289-
290296
# if rank == 0:
291297
# eigvals_right = comm.recv(source=left_size, tag=77)
292298
# eigvecs_right = comm.recv(source=left_size, tag=78)
293299
# elif rank == left_size:
294300
# comm.send(eigvals_right, dest=0, tag=77)
295301
# comm.send(eigvecs_right, dest=0, tag=78)
296-
297302

298303
if rank == 0:
299304

300305
# Merge Step
301306
n1 = len(eigvals_left)
302307
D = np.concatenate((eigvals_left, eigvals_right))
303-
D_size=D.size
308+
D_size = D.size
304309
v_vec = np.concatenate((eigvecs_left[-1, :], eigvecs_right[0, :]))
305310

306311
deflated_eigvals, deflated_eigvecs, D_keep, v_keep, P = deflate_eigenpairs_cxx(
@@ -329,37 +334,37 @@ def parallel_tridiag_eigen(
329334
beta, D_keep[idx], v_keep[idx] , np.arange(reduced_dim)
330335
)
331336
lam = np.array(lam)
332-
delta=np.array(delta)
333-
changing_position=np.array(changing_position)
334-
337+
delta = np.array(delta)
338+
changing_position = np.array(changing_position)
339+
# #diff=lam_s-lam
335340
else:
336341
lam = np.array([])
337-
342+
338343
counts, displs = find_interval_extreme(reduced_dim, size)
339344

340345
else:
341346
counts = None
342347
displs = None
343-
lam=None
344-
D_keep=None
345-
v_keep=None
346-
delta=None
347-
reduced_dim=None
348-
D_size=None
349-
changing_position=None
350-
type_lam=None
351-
type_D= None
352-
P=None
353-
idx_inv=None
354-
n1=None
348+
lam = None
349+
D_keep = None
350+
v_keep = None
351+
delta = None
352+
reduced_dim = None
353+
D_size = None
354+
changing_position = None
355+
type_lam = None
356+
type_D = None
357+
P = None
358+
idx_inv = None
359+
n1 = None
355360
deflated_eigvals = None
356361
deflated_eigvecs = None
357-
362+
358363
counts = comm.bcast(counts, root=0)
359364
displs = comm.bcast(displs, root=0)
360-
lam=comm.bcast(lam, root=0)
361-
D_keep=comm.bcast(D_keep, root=0)
362-
v_keep=comm.bcast(v_keep, root=0)
365+
lam = comm.bcast(lam, root=0)
366+
D_keep = comm.bcast(D_keep, root=0)
367+
v_keep = comm.bcast(v_keep, root=0)
363368
my_count = counts[rank]
364369
type_lam=comm.bcast(lam.dtype, root=0)
365370

@@ -389,28 +394,30 @@ def parallel_tridiag_eigen(
389394
# now do the scatterv
390395
comm.Scatterv(
391396
[lam, counts, displs, mpi_type], # send tuple, only root’s lam is used here
392-
lam_buffer, # recvbuf on every rank
393-
root=0
397+
lam_buffer, # recvbuf on every rank
398+
root=0,
394399
)
395400

396401

397-
initial_point=displs[rank]
402+
initial_point = displs[rank]
398403

399404
for k_rel in range(lam_buffer.size):
400-
k=k_rel+initial_point
405+
k = k_rel + initial_point
401406
numerator = lam - D_keep[k]
402407
denominator = np.concatenate((D_keep[:k], D_keep[k + 1 :])) - D_keep[k]
403408
numerator[:-1] = numerator[:-1] / denominator
404409
v_keep[k] = np.sqrt(np.abs(np.prod(numerator) / beta)) * np.sign(v_keep[k])
405410

406411
# eigenpairs = []
407412

408-
eig_vecs=np.empty((D_size, my_count),dtype=type_lam )
409-
eig_val=np.empty(my_count, dtype=type_lam )
413+
eig_vecs = np.empty((D_size, my_count), dtype=type_lam)
414+
eig_val = np.empty(my_count, dtype=type_lam)
410415

411416
for j_rel in range(lam_buffer.size):
412417
y = np.zeros(D_size)
413-
j=j_rel + initial_point
418+
# y[:reduced_dim]=v_keep/(lam[j]-D_keep)
419+
# y /= np.linalg.norm(y)
420+
j = j_rel + initial_point
414421
diff = lam[j] - D_keep
415422
diff[idx_inv[changing_position[j]]] = delta[j]
416423
y[:reduced_dim] = v_keep / (diff)
@@ -421,75 +428,72 @@ def parallel_tridiag_eigen(
421428
y = P.T @ y
422429
vec = np.concatenate((eigvecs_left @ y[:n1], eigvecs_right @ y[n1:]))
423430
vec /= np.linalg.norm(vec)
424-
eig_vecs[:, j_rel]=vec
425-
eig_val[j_rel]=lam[j]
426-
#eigenpairs.append((lam[j], vec))
431+
eig_vecs[:, j_rel] = vec
432+
eig_val[j_rel] = lam[j]
433+
# eigenpairs.append((lam[j], vec))
427434

428-
if reduced_dim<D_size:
435+
if reduced_dim < D_size:
429436

430-
if rank==0:
431-
le_deflation=len(deflated_eigvals)
437+
if rank == 0:
438+
le_deflation = len(deflated_eigvals)
432439
counts, displs = find_interval_extreme(le_deflation, size)
433440

434441
counts = comm.bcast(counts, root=0)
435442
displs = comm.bcast(displs, root=0)
436443
my_count = counts[rank]
437-
438-
deflated_eigvals_buffer=np.empty(my_count, dtype=type_lam)
444+
445+
deflated_eigvals_buffer = np.empty(my_count, dtype=type_lam)
439446
if rank == 0:
440447
char = deflated_eigvals.dtype.char
441-
type_eig=deflated_eigvals.dtype
448+
type_eig = deflated_eigvals.dtype
442449
else:
443450
char = None
444-
type_eig= None
451+
type_eig = None
445452

446453
# now everyone learns the character code:
447454
char = comm.bcast(char, root=0)
448455
type_eig = comm.bcast(type_eig, root=0)
449456
comm.Scatterv(
450-
[deflated_eigvals, counts, displs, MPI._typedict[char]],
451-
deflated_eigvals_buffer,
452-
root=0,
457+
[deflated_eigvals, counts, displs, MPI._typedict[char]],
458+
deflated_eigvals_buffer,
459+
root=0,
453460
)
454-
if rank==0:
455-
_, k=deflated_eigvecs.shape
461+
if rank == 0:
462+
_, k = deflated_eigvecs.shape
456463
else:
457464
mat = None
458-
k=None
459-
k=comm.bcast(k, root=0)
465+
k = None
466+
k = comm.bcast(k, root=0)
460467
# each row of `mat` is one deflated vec
461468
sendcounts = [c * k for c in counts]
462-
senddispls = [d * k for d in displs]
463-
deflated_eigvecs_buffer = np.empty( (my_count, k), dtype=type_eig)
469+
senddispls = [d * k for d in displs]
470+
deflated_eigvecs_buffer = np.empty((my_count, k), dtype=type_eig)
464471
if rank == 0:
465-
flat_send = deflated_eigvecs.copy().flatten() # shape (M*k,)
472+
flat_send = deflated_eigvecs.copy().flatten() # shape (M*k,)
466473
sendbuf = [flat_send, sendcounts, senddispls, MPI._typedict[char]]
467474
else:
468475
sendbuf = None
469-
470476

471477
# now scatter to everyone
472478
comm.Scatterv(
473-
sendbuf, # only meaningful on rank 0
474-
deflated_eigvecs_buffer, # each rank’s recv‐buffer of length k × my_count
475-
root=0
479+
sendbuf, # only meaningful on rank 0
480+
deflated_eigvecs_buffer, # each rank’s recv‐buffer of length k × my_count
481+
root=0,
476482
)
477-
478483

479-
#local_final_vecs = np.empty((k, my_count), dtype=deflated_eigvecs.dtype)
484+
# local_final_vecs = np.empty((k, my_count), dtype=deflated_eigvecs.dtype)
480485
for i in range(my_count):
481486
small_vec = deflated_eigvecs_buffer[i]
482487
# apply the two block Q’s:
483-
left_part = eigvecs_left @ small_vec[:n1]
488+
left_part = eigvecs_left @ small_vec[:n1]
484489
right_part = eigvecs_right @ small_vec[n1:]
485490
local_final_vecs = np.concatenate((left_part, right_part))
486491
local_final_vecs = local_final_vecs.reshape(k, 1)
487-
eig_val=np.append(eig_val, deflated_eigvals_buffer[i])
488-
eig_vecs=np.concatenate([eig_vecs, local_final_vecs], axis=1)
492+
eig_val = np.append(eig_val, deflated_eigvals_buffer[i])
493+
eig_vecs = np.concatenate([eig_vecs, local_final_vecs], axis=1)
489494

490-
491495
# 1) Each rank computes its local length:
492-
local_count = eig_val.size # or however many elements you’ll send
496+
local_count = eig_val.size # or however many elements you’ll send
493497

494498
# 2) Everyone exchanges counts via allgather:
495499
# this returns a Python list of length `size` on every rank
@@ -500,48 +504,39 @@ def parallel_tridiag_eigen(
500504

501505
# # 2) Broadcast that list from rank 0 back to everyone
502506
# recvcounts = comm.bcast(counts, root=0)
503-
504507

505508
final_eig_val = np.empty(D_size, dtype=eig_val.dtype)
506509

507-
508-
509-
displs=np.append([0], np.cumulative_sum(recvcounts[:-1]).astype(int))
510+
displs = np.append([0], np.cumulative_sum(recvcounts[:-1]).astype(int))
510511

511512
mpi_t = MPI._typedict[eig_val.dtype.char]
512-
comm.Allgatherv(
513-
[eig_val, mpi_t],
514-
[final_eig_val, recvcounts, displs, mpi_t]
515-
)
516-
513+
comm.Allgatherv([eig_val, mpi_t], [final_eig_val, recvcounts, displs, mpi_t])
514+
517515
# 1) Flatten local eigenvector block
518-
# eig_vecs has shape (D_size, local_count)
519-
local_flat = eig_vecs.T.flatten()
516+
# eig_vecs has shape (D_size, local_count)
517+
local_flat = eig_vecs.T.flatten()
520518

521519
# 2) Build sendcounts & displacements for the flattened arrays
522520
sendcounts_vecs = [c * D_size for c in recvcounts]
523-
senddispls_vecs = [d * D_size for d in displs]
521+
senddispls_vecs = [d * D_size for d in displs]
524522

525523
# 3) Allocate full receive buffer on every rank
526524
flat_all = np.empty(sum(sendcounts_vecs), dtype=eig_vecs.dtype)
527525

528526
# 4) Perform the all-gather-variable-counts
529527
mpi_tvec = MPI._typedict[eig_vecs.dtype.char]
530528
comm.Allgatherv(
531-
[local_flat, mpi_tvec], # sendbuf
532-
[flat_all,
533-
sendcounts_vecs,
534-
senddispls_vecs,
535-
mpi_tvec] # recvbuf spec
529+
[local_flat, mpi_tvec], # sendbuf
530+
[flat_all, sendcounts_vecs, senddispls_vecs, mpi_tvec], # recvbuf spec
536531
)
537532

538533
# 5) Reshape on every rank (or just on rank 0 if you prefer)
539534
# total_pairs == sum(recvcounts)
540535
final_eig_vecs = flat_all.reshape(D_size, D_size)
541-
final_eig_vecs=final_eig_vecs.T
542-
index_sort=np.argsort(final_eig_val)
543-
final_eig_vecs=final_eig_vecs[:, index_sort]
544-
final_eig_val=final_eig_val[index_sort]
536+
final_eig_vecs = final_eig_vecs.T
537+
index_sort = np.argsort(final_eig_val)
538+
final_eig_vecs = final_eig_vecs[:, index_sort]
539+
final_eig_val = final_eig_val[index_sort]
545540
# if rank==0:
546541
# print(final_eig_val)
547542
# print(final_eig_vecs)

0 commit comments

Comments
 (0)