from mpi4py import MPI
import numpy as np
from time import time
- from pyclassify.cxx_utils import QR_algorithm, secular_solver_cxx, deflate_eigenpairs_cxx
+ from pyclassify.cxx_utils import (
+     QR_algorithm,
+     secular_solver_cxx,
+     deflate_eigenpairs_cxx,
+ )
from pyclassify.zero_finder import secular_solver_python as secular_solver
from line_profiler import profile, LineProfiler
import scipy.sparse as sp
@@ -158,19 +162,22 @@ def deflate_eigenpairs(D, v, beta, tol_factor=1e-12):
    return deflated_eigvals, np.array(deflated_eigvecs), D_keep, v_keep, P_3 @ P_2 @ P

+
def find_interval_extreme(total_dimension, n_processor):
    """
    Computes the intervals for the vector to be scattered.
    Input:
    -total_dimension: the dimension of the vector that has to be split
    -n_processor: the number of processors to which the scattered vector has to be sent
-
+
    """

-    base = total_dimension // n_processor
+    base = total_dimension // n_processor
    rest = total_dimension % n_processor

-    counts = np.array([base + 1 if i < rest else base for i in range(n_processor)], dtype=int)
+    counts = np.array(
+        [base + 1 if i < rest else base for i in range(n_processor)], dtype=int
+    )
    displs = np.insert(np.cumsum(counts), 0, 0)[:-1]

    return counts, displs
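
A quick worked example of the splitting rule above (a standalone sketch, not part of the diff): for a vector of length 10 over 3 processors, the remainder goes to the leading ranks, and the resulting counts/displacements pair is exactly the layout MPI's Scatterv expects.

    import numpy as np

    # Same layout rule as find_interval_extreme: the first
    # (total_dimension % n_processor) ranks receive one extra element.
    total_dimension, n_processor = 10, 3
    base, rest = divmod(total_dimension, n_processor)
    counts = np.array([base + 1 if i < rest else base for i in range(n_processor)])
    displs = np.insert(np.cumsum(counts), 0, 0)[:-1]
    print(counts, displs)  # [4 3 3] [0 4 7]
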
@@ -210,7 +217,6 @@ def parallel_tridiag_eigen(
    n = len(diag)
    prof_filename = f"Profile_folder/profile.rank{current_rank}.depth{depth}.lprof"

-
    if n <= min_size or size == 1:
        eigvals, eigvecs = QR_algorithm(diag, off, 1e-16, max_iterQR)
        eigvecs = np.array(eigvecs)
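
The base case falls back to a serial QR iteration once a subproblem is small enough or only one rank remains. A minimal sketch of how such a base case can be cross-checked, assuming SciPy as an external reference solver (not part of this codebase):

    import numpy as np
    from scipy.linalg import eigh_tridiagonal

    # Reference eigendecomposition of a symmetric tridiagonal matrix,
    # given its diagonal and off-diagonal entries.
    diag = np.array([2.0, 2.0, 2.0, 2.0])
    off = np.array([-1.0, -1.0, -1.0])
    ref_vals, ref_vecs = eigh_tridiagonal(diag, off)
    # QR_algorithm(diag, off, 1e-16, max_iterQR) should agree with ref_vals
    # up to ordering, and with ref_vecs up to column signs.
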
@@ -245,8 +251,8 @@ def parallel_tridiag_eigen(
            depth=depth + 1,
            profiler=profiler,
        )
-        eigvals_right = None
-        eigvecs_right = None
+        eigvals_right = None
+        eigvecs_right = None
    else:
        eigvals_right, eigvecs_right = parallel_tridiag_eigen(
            diag2,
@@ -257,50 +263,49 @@ def parallel_tridiag_eigen(
            depth=depth + 1,
            profiler=profiler,
        )
-        eigvals_left = None
-        eigvecs_left = None
-
+        eigvals_left = None
+        eigvecs_left = None

    # 1) Identify the two “root” ranks in MPI.COMM_WORLD
-    left_size = size // 2 if size > 1 else 1
-    root_left = 0
+    left_size = size // 2 if size > 1 else 1
+    root_left = 0
    root_right = left_size
-    other_root = root_right if color == 0 else root_left
+    other_root = root_right if color == 0 else root_left

-    # now exchange between the two roots
+    # now exchange between the two roots
    if subcomm.Get_rank() == 0:
-        send_data = (eigvals_left, eigvecs_left) \
-            if color == 0 else (eigvals_right, eigvecs_right)
+        send_data = (
+            (eigvals_left, eigvecs_left)
+            if color == 0
+            else (eigvals_right, eigvecs_right)
+        )
        recv_data = comm.sendrecv(
-            send_data, dest=other_root, sendtag=depth,
-            source=other_root, recvtag=depth
+            send_data, dest=other_root, sendtag=depth, source=other_root, recvtag=depth
        )
        # unpack
        if color == 0:
            eigvals_right, eigvecs_right = recv_data
        else:
-            eigvals_left, eigvecs_left = recv_data
+            eigvals_left, eigvecs_left = recv_data

-    eigvals_left = subcomm.bcast(eigvals_left, root=0)
-    eigvecs_left = subcomm.bcast(eigvecs_left, root=0)
+    eigvals_left = subcomm.bcast(eigvals_left, root=0)
+    eigvecs_left = subcomm.bcast(eigvecs_left, root=0)
    eigvals_right = subcomm.bcast(eigvals_right, root=0)
    eigvecs_right = subcomm.bcast(eigvecs_right, root=0)

-
    # if rank == 0:
    #     eigvals_right = comm.recv(source=left_size, tag=77)
    #     eigvecs_right = comm.recv(source=left_size, tag=78)
    # elif rank == left_size:
    #     comm.send(eigvals_right, dest=0, tag=77)
    #     comm.send(eigvecs_right, dest=0, tag=78)
-
    if rank == 0:

        # Merge Step
        n1 = len(eigvals_left)
        D = np.concatenate((eigvals_left, eigvals_right))
-        D_size = D.size
+        D_size = D.size
        v_vec = np.concatenate((eigvecs_left[-1, :], eigvecs_right[0, :]))

        deflated_eigvals, deflated_eigvecs, D_keep, v_keep, P = deflate_eigenpairs_cxx(
@@ -329,37 +334,37 @@ def parallel_tridiag_eigen(
                beta, D_keep[idx], v_keep[idx], np.arange(reduced_dim)
            )
            lam = np.array(lam)
-            delta = np.array(delta)
-            changing_position = np.array(changing_position)
-
+            delta = np.array(delta)
+            changing_position = np.array(changing_position)
+            # #diff=lam_s-lam
        else:
            lam = np.array([])
-
+
        counts, displs = find_interval_extreme(reduced_dim, size)

    else:
        counts = None
        displs = None
-        lam = None
-        D_keep = None
-        v_keep = None
-        delta = None
-        reduced_dim = None
-        D_size = None
-        changing_position = None
-        type_lam = None
-        type_D = None
-        P = None
-        idx_inv = None
-        n1 = None
+        lam = None
+        D_keep = None
+        v_keep = None
+        delta = None
+        reduced_dim = None
+        D_size = None
+        changing_position = None
+        type_lam = None
+        type_D = None
+        P = None
+        idx_inv = None
+        n1 = None
        deflated_eigvals = None
        deflated_eigvecs = None
-
+
    counts = comm.bcast(counts, root=0)
    displs = comm.bcast(displs, root=0)
-    lam = comm.bcast(lam, root=0)
-    D_keep = comm.bcast(D_keep, root=0)
-    v_keep = comm.bcast(v_keep, root=0)
+    lam = comm.bcast(lam, root=0)
+    D_keep = comm.bcast(D_keep, root=0)
+    v_keep = comm.bcast(v_keep, root=0)
    my_count = counts[rank]
    type_lam = comm.bcast(lam.dtype, root=0)

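Broadcasting `lam.dtype` lets every rank build the matching MPI datatype for the Scatterv that follows. A minimal sketch of the dtype-to-MPI mapping this code relies on (`MPI._typedict` is mpi4py's internal lookup table keyed by NumPy type characters):

    import numpy as np
    from mpi4py import MPI

    # NumPy dtype -> MPI datatype, via the dtype's single-character code.
    dtype = np.dtype(np.float64)
    mpi_type = MPI._typedict[dtype.char]  # 'd' -> MPI.DOUBLE
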
@@ -389,28 +394,30 @@ def parallel_tridiag_eigen(
    # now do the scatterv
    comm.Scatterv(
        [lam, counts, displs, mpi_type],  # send tuple, only root’s lam is used here
-        lam_buffer,  # recvbuf on every rank
-        root=0
+        lam_buffer,  # recvbuf on every rank
+        root=0,
    )

-    initial_point = displs[rank]
+    initial_point = displs[rank]

    for k_rel in range(lam_buffer.size):
-        k = k_rel + initial_point
+        k = k_rel + initial_point
        numerator = lam - D_keep[k]
        denominator = np.concatenate((D_keep[:k], D_keep[k + 1:])) - D_keep[k]
        numerator[:-1] = numerator[:-1] / denominator
        v_keep[k] = np.sqrt(np.abs(np.prod(numerator) / beta)) * np.sign(v_keep[k])

    # eigenpairs = []

-    eig_vecs = np.empty((D_size, my_count),dtype=type_lam)
-    eig_val = np.empty(my_count, dtype=type_lam)
+    eig_vecs = np.empty((D_size, my_count), dtype=type_lam)
+    eig_val = np.empty(my_count, dtype=type_lam)

    for j_rel in range(lam_buffer.size):
        y = np.zeros(D_size)
-        j = j_rel + initial_point
+        # y[:reduced_dim]=v_keep/(lam[j]-D_keep)
+        # y /= np.linalg.norm(y)
+        j = j_rel + initial_point
        diff = lam[j] - D_keep
        diff[idx_inv[changing_position[j]]] = delta[j]
        y[:reduced_dim] = v_keep / (diff)
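
This loop assembles each eigenvector of the rank-one-updated system from the secular-equation roots: for the update D + beta*v*vᵀ, the eigenvector belonging to a root lam is proportional to (D − lam·I)⁻¹v, with one entry patched via `delta` to avoid cancellation near the closest pole. A toy illustration of the formula (the values below are made up):

    import numpy as np

    d = np.array([1.0, 2.0, 4.0])  # diagonal of D
    v = np.array([0.3, 0.5, 0.2])  # rank-one update vector
    lam = 2.1                      # a secular root between d[1] and d[2]
    y = v / (lam - d)              # y_i = v_i / (lam - d_i)
    y /= np.linalg.norm(y)         # normalized merged-system eigenvector
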
@@ -421,75 +428,72 @@ def parallel_tridiag_eigen(
        y = P.T @ y
        vec = np.concatenate((eigvecs_left @ y[:n1], eigvecs_right @ y[n1:]))
        vec /= np.linalg.norm(vec)
-        eig_vecs[:, j_rel]= vec
-        eig_val[j_rel]= lam[j]
-        #eigenpairs.append((lam[j], vec))
+        eig_vecs[:, j_rel] = vec
+        eig_val[j_rel] = lam[j]
+        # eigenpairs.append((lam[j], vec))

-    if reduced_dim < D_size:
+    if reduced_dim < D_size:

-        if rank == 0:
-            le_deflation = len(deflated_eigvals)
+        if rank == 0:
+            le_deflation = len(deflated_eigvals)
            counts, displs = find_interval_extreme(le_deflation, size)

        counts = comm.bcast(counts, root=0)
        displs = comm.bcast(displs, root=0)
        my_count = counts[rank]
-
-        deflated_eigvals_buffer = np.empty(my_count, dtype=type_lam)
+
+        deflated_eigvals_buffer = np.empty(my_count, dtype=type_lam)
        if rank == 0:
            char = deflated_eigvals.dtype.char
-            type_eig = deflated_eigvals.dtype
+            type_eig = deflated_eigvals.dtype
        else:
            char = None
-            type_eig = None
+            type_eig = None

        # now everyone learns the character code:
        char = comm.bcast(char, root=0)
        type_eig = comm.bcast(type_eig, root=0)
        comm.Scatterv(
-            [deflated_eigvals, counts, displs, MPI._typedict[char]],
-            deflated_eigvals_buffer,
-            root=0,
+            [deflated_eigvals, counts, displs, MPI._typedict[char]],
+            deflated_eigvals_buffer,
+            root=0,
        )
-        if rank == 0:
-            _, k = deflated_eigvecs.shape
+        if rank == 0:
+            _, k = deflated_eigvecs.shape
        else:
            mat = None
-            k = None
-        k = comm.bcast(k, root=0)
+            k = None
+        k = comm.bcast(k, root=0)
        # each row of `mat` is one deflated vec
        sendcounts = [c * k for c in counts]
-        senddispls = [d * k for d in displs]
-        deflated_eigvecs_buffer = np.empty( (my_count, k), dtype=type_eig)
+        senddispls = [d * k for d in displs]
+        deflated_eigvecs_buffer = np.empty((my_count, k), dtype=type_eig)
        if rank == 0:
-            flat_send = deflated_eigvecs.copy().flatten()  # shape (M*k,)
+            flat_send = deflated_eigvecs.copy().flatten()  # shape (M*k,)
            sendbuf = [flat_send, sendcounts, senddispls, MPI._typedict[char]]
        else:
            sendbuf = None
-
        # now scatter to everyone
        comm.Scatterv(
-            sendbuf,  # only meaningful on rank 0
-            deflated_eigvecs_buffer,  # each rank’s recv-buffer of length k × my_count
-            root=0
+            sendbuf,  # only meaningful on rank 0
+            deflated_eigvecs_buffer,  # each rank’s recv-buffer of length k × my_count
+            root=0,
        )
-
-        #local_final_vecs = np.empty((k, my_count), dtype=deflated_eigvecs.dtype)
+        # local_final_vecs = np.empty((k, my_count), dtype=deflated_eigvecs.dtype)
        for i in range(my_count):
            small_vec = deflated_eigvecs_buffer[i]
            # apply the two block Q’s:
-            left_part = eigvecs_left @ small_vec[:n1]
+            left_part = eigvecs_left @ small_vec[:n1]
            right_part = eigvecs_right @ small_vec[n1:]
            local_final_vecs = np.concatenate((left_part, right_part))
            local_final_vecs = local_final_vecs.reshape(k, 1)
-            eig_val = np.append(eig_val, deflated_eigvals_buffer[i])
-            eig_vecs = np.concatenate([eig_vecs, local_final_vecs], axis=1)
+            eig_val = np.append(eig_val, deflated_eigvals_buffer[i])
+            eig_vecs = np.concatenate([eig_vecs, local_final_vecs], axis=1)

-
    # 1) Each rank computes its local length:
-    local_count = eig_val.size  # or however many elements you’ll send
+    local_count = eig_val.size  # or however many elements you’ll send

    # 2) Everyone exchanges counts via allgather:
    # this returns a Python list of length `size` on every rank
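
The exchange referred to here is mpi4py's pickle-based `allgather` of plain Python ints; a standalone sketch:

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    # Every rank contributes its local element count and receives the
    # full list, e.g. [3, 3, 2, 2] on four ranks.
    local_count = 3 if comm.Get_rank() < 2 else 2
    recvcounts = comm.allgather(local_count)
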
@@ -500,48 +504,39 @@ def parallel_tridiag_eigen(
    # # 2) Broadcast that list from rank 0 back to everyone
    # recvcounts = comm.bcast(counts, root=0)
-
    final_eig_val = np.empty(D_size, dtype=eig_val.dtype)

-
-
-    displs = np.append([0], np.cumulative_sum(recvcounts[:-1]).astype(int))
+    displs = np.append([0], np.cumulative_sum(recvcounts[:-1]).astype(int))

    mpi_t = MPI._typedict[eig_val.dtype.char]
-    comm.Allgatherv(
-        [eig_val, mpi_t],
-        [final_eig_val, recvcounts, displs, mpi_t]
-    )
-
+    comm.Allgatherv([eig_val, mpi_t], [final_eig_val, recvcounts, displs, mpi_t])
+
    # 1) Flatten local eigenvector block
-    # eig_vecs has shape (D_size, local_count)
-    local_flat = eig_vecs.T.flatten()
+    # eig_vecs has shape (D_size, local_count)
+    local_flat = eig_vecs.T.flatten()

    # 2) Build sendcounts & displacements for the flattened arrays
    sendcounts_vecs = [c * D_size for c in recvcounts]
-    senddispls_vecs = [d * D_size for d in displs]
+    senddispls_vecs = [d * D_size for d in displs]

    # 3) Allocate full receive buffer on every rank
    flat_all = np.empty(sum(sendcounts_vecs), dtype=eig_vecs.dtype)

    # 4) Perform the all-gather-variable-counts
    mpi_tvec = MPI._typedict[eig_vecs.dtype.char]
    comm.Allgatherv(
-        [local_flat, mpi_tvec],  # sendbuf
-        [flat_all,
-         sendcounts_vecs,
-         senddispls_vecs,
-         mpi_tvec]  # recvbuf spec
+        [local_flat, mpi_tvec],  # sendbuf
+        [flat_all, sendcounts_vecs, senddispls_vecs, mpi_tvec],  # recvbuf spec
    )

    # 5) Reshape on every rank (or just on rank 0 if you prefer)
    # total_pairs == sum(recvcounts)
    final_eig_vecs = flat_all.reshape(D_size, D_size)
-    final_eig_vecs = final_eig_vecs.T
-    index_sort = np.argsort(final_eig_val)
-    final_eig_vecs = final_eig_vecs[:, index_sort]
-    final_eig_val = final_eig_val[index_sort]
+    final_eig_vecs = final_eig_vecs.T
+    index_sort = np.argsort(final_eig_val)
+    final_eig_vecs = final_eig_vecs[:, index_sort]
+    final_eig_val = final_eig_val[index_sort]
    # if rank == 0:
    #     print(final_eig_val)
    #     print(final_eig_vecs)
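
As a closing note, a hypothetical post-processing check for the gathered results (the helper name `check_eigen` and the inputs are assumptions, not part of this commit): it measures the residual T·Q − Q·diag(λ) and the orthogonality of Q for the tridiagonal T defined by (diag, off).

    import numpy as np
    import scipy.sparse as sp

    def check_eigen(diag, off, final_eig_val, final_eig_vecs):
        # Rebuild the symmetric tridiagonal matrix and measure the
        # max-norm residual and the loss of orthogonality.
        T = sp.diags([off, diag, off], offsets=[-1, 0, 1]).toarray()
        residual = T @ final_eig_vecs - final_eig_vecs * final_eig_val
        ortho = final_eig_vecs.T @ final_eig_vecs - np.eye(len(diag))
        return np.max(np.abs(residual)), np.max(np.abs(ortho))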