Skip to content

Commit 67643a1

Browse files
committed
WIP: check whether new workflow is ok
1 parent bf1780f commit 67643a1

File tree

6 files changed

+169
-50
lines changed

6 files changed

+169
-50
lines changed

.github/workflows/tests.yml

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,27 @@ on:
55
branches:
66
- main
77
- final_project
8+
- gpu_branch
89

910
pull_request:
1011
branches:
1112
- main
1213
- final_project
14+
- gpu_branch
1315

1416
jobs:
15-
docs:
17+
test:
1618
name: Testing using pytest
1719
runs-on: ubuntu-latest
20+
container:
21+
image: cupy/cupy:v13.4.0
22+
1823
steps:
1924
- uses: actions/checkout@v3
2025

21-
- name: Set up Python
22-
uses: actions/setup-python@v3
23-
with:
24-
python-version: 3.10.16
25-
26-
- name: Set up Python environment
27-
run: |
28-
python -m venv venv
29-
source venv/bin/activate
30-
31-
- name: Install dependencies
26+
- name: Install additional dependencies
3227
run: |
28+
python3 -m venv venv
3329
source venv/bin/activate
3430
python -m pip install .
3531

shell/submit.sbatch

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#SBATCH --gpus=1
77
#SBATCH --ntasks-per-node=1
88
#SBATCH --gpus-per-task=1
9-
#SBATCH --mem=32000
9+
#SBATCH --mem=2000
1010
#SBATCH --time=01:00:00
1111
#SBATCH --mail-user=ltomada@sissa.it
1212
#SBATCH --output=%x.o%j.%N
@@ -24,11 +24,13 @@ echo '------------------------------------------------------'
2424
#
2525

2626
cd $SLURM_SUBMIT_DIR
27-
export SLURM_NTASKS_PER_NODE=2 # due to Ulysses's bug
27+
export SLURM_NTASKS_PER_NODE=1 # due to Ulysses's bug
2828

2929
module load cuda/12.1
3030
conda init
31+
source ~/.bashrc
3132
conda activate ~/miniconda3/envs/devtools_scicomp
32-
33+
pip install cupy-cuda12x
34+
python -m pip freeze > requirements.txt
3335
# Run the script
34-
python scripts/run.py fit --config=experiments/config.yaml
36+
#python scripts/run.py fit --config=experiments/config.yaml

src/pyclassify/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
__all__ = [
22
"eigenvalues_np",
33
"eigenvalues_sp",
4+
"eigenvalues_cp",
45
"power_method",
56
"power_method_numba",
67
]
78

89
from .eigenvalues import (
910
eigenvalues_np,
1011
eigenvalues_sp,
12+
eigenvalues_cp,
1113
power_method,
1214
power_method_numba,
13-
)
15+
)

src/pyclassify/eigenvalues.py

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
import numpy as np
22
import scipy.sparse as sp
3+
import scipy.linalg as spla
34
from line_profiler import profile
45
from numpy.linalg import eig, eigh
5-
from pyclassify.utils import check_A_square_matrix, power_method_numba_helper
6+
from pyclassify.utils import check_square_matrix, check_symm_square, power_method_numba_helper
7+
import cupy as cp
8+
cp.cuda.Device(0)
9+
import cupy.linalg as cpla
10+
import cupyx.scipy.sparse as cpsp
611

712

813
@profile
914
def eigenvalues_np(A, symmetric=True):
1015
"""
1116
Compute the eigenvalues of a square matrix using NumPy's `eig` or `eigh` function.
1217
13-
This function checks if the input matrix is square (and is actually a matrix) using 'check_A_square_matrix', and then computes its eigenvalues.
18+
This function checks if the input matrix is square (and is actually a matrix) using 'check_square_matrix', and then computes its eigenvalues.
1419
If the matrix is symmetric, it uses `eigh` (which is more efficient for symmetric matrices).
1520
Otherwise, it uses `eig`.
1621
@@ -27,7 +32,7 @@ def eigenvalues_np(A, symmetric=True):
2732
TypeError: If the input is not a NumPy array or a SciPy sparse matrix.
2833
ValueError: If number of rows != number of columns.
2934
"""
30-
check_A_square_matrix(A)
35+
check_square_matrix(A)
3136
eigenvalues, _ = eigh(A) if symmetric else eig(A)
3237
return eigenvalues
3338

@@ -53,22 +58,46 @@ def eigenvalues_sp(A, symmetric=True):
5358
TypeError: If the input is not a NumPy array or a SciPy sparse matrix.
5459
ValueError: If number of rows != number of columns.
5560
"""
56-
check_A_square_matrix(A)
61+
check_square_matrix(A)
5762
eigenvalues, _ = (
58-
sp.linalg.eigsh(A, k=A.shape[0] - 1)
63+
spla.eigsh(A, k=A.shape[0] - 1)
5964
if symmetric
60-
else sp.linalg.eigs(A, k=A.shape[0] - 1)
65+
else spla.eigs(A, k=A.shape[0] - 1)
6166
)
6267
return eigenvalues
6368

6469

70+
@profile
71+
def eigenvalues_cp(A):
72+
"""
73+
Compute the eigenvalues of a sparse matrix using CuPy's `eigsh` function.
74+
75+
This function checks if the input matrix is square and symmetric, then computes its eigenvalues using
76+
CupY's sparse linear algebra solvers. For symmetric matrices, it uses `eigsh` for
77+
more efficient computation.
78+
79+
Args:
80+
A (cpsp.spmatrix): A square sparse matrix whose eigenvalues are to be computed.
81+
82+
Returns:
83+
np.ndarray: An array containing the eigenvalues of the sparse matrix `A`.
84+
85+
Raises:
86+
TypeError: If the input is not a CuPy sparse symmetric matrix.
87+
ValueError: If number of rows != number of columns.
88+
"""
89+
check_symm_square(A)
90+
eigenvalues, _ = cpla.eigsh(A, k=A.shape[0] - 1)
91+
return eigenvalues
92+
93+
6594
@profile
6695
def power_method(A, max_iter=500, tol=1e-4, x=None):
6796
"""
6897
Compute the dominant eigenvalue of a square matrix using the power method.
6998
7099
Args:
71-
A (np.ndarray or sp.spmatrix): A square matrix whose dominant eigenvalue is to be computed.
100+
A (np.ndarray or sp.spmatrix or cpsp.spmatrix): A square matrix whose dominant eigenvalue is to be computed.
72101
max_iter (int, optional): Maximum number of iterations to perform (default is 500).
73102
tol (float, optional): Tolerance for convergence based on the relative change between iterations
74103
(default is 1e-4).
@@ -81,19 +110,19 @@ def power_method(A, max_iter=500, tol=1e-4, x=None):
81110
TypeError: If the input is not a NumPy array or a SciPy sparse matrix.
82111
ValueError: If number of rows != number of columns.
83112
"""
84-
check_A_square_matrix(A)
113+
check_square_matrix(A)
85114
if x is None:
86115
x = np.random.rand(A.shape[0])
87-
x /= np.linalg.norm(x)
116+
x /= spla.norm(x)
88117
x_old = x
89118

90119
iteration = 0
91120
update_norm = tol + 1
92121

93122
while iteration < max_iter and update_norm > tol:
94123
x = A @ x
95-
x /= np.linalg.norm(x)
96-
update_norm = np.linalg.norm(x - x_old) / np.linalg.norm(x_old)
124+
x /= spla.norm(x)
125+
update_norm = spla.norm(x - x_old) / spla.norm(x_old)
97126
x_old = x.copy()
98127
iteration += 1
99128

@@ -121,6 +150,44 @@ def power_method_numba(A):
121150
return power_method_numba_helper(A)
122151

123152

153+
@profile
154+
def power_method_cp(A, max_iter=500, tol=1e-4, x=None):
155+
"""
156+
Compute the dominant eigenvalue of a square matrix using the power method.
157+
158+
Args:
159+
A (cp.spmatrix): A square matrix whose dominant eigenvalue is to be computed.
160+
max_iter (int, optional): Maximum number of iterations to perform (default is 500).
161+
tol (float, optional): Tolerance for convergence based on the relative change between iterations
162+
(default is 1e-4).
163+
x (cp.ndarray, optional): Initial guess for the eigenvector. If None, a random vector is generated.
164+
165+
Returns:
166+
float: The approximated dominant eigenvalue of the matrix `A`, computed as the Rayleigh quotient x @ A @ x.
167+
168+
Raises:
169+
TypeError: If the input is not a NumPy array or a SciPy sparse matrix.
170+
ValueError: If number of rows != number of columns.
171+
"""
172+
check_square_matrix(A)
173+
if x is None:
174+
x = cp.random.rand(A.shape[0])
175+
x /= cpla.norm(x)
176+
x_old = x
177+
178+
iteration = 0
179+
update_norm = tol + 1
180+
181+
while iteration < max_iter and update_norm > tol:
182+
x = A @ x
183+
x /= cpla.norm(x)
184+
update_norm = cpla.norm(x - x_old) / cpla.norm(x_old)
185+
x_old = x.copy()
186+
iteration += 1
187+
188+
return x @ A @ x
189+
190+
124191
def Lanczos_PRO(A, q, m=None, toll=np.sqrt(np.finfo(float).eps)):
125192
"""
126193
Perform the Lanczos algorithm for symmetric matrices.
@@ -223,6 +290,5 @@ def QR_method(A_copy, tol=1e-10, max_iter=100):
223290
Q = Q@R
224291
A=A@Q
225292
iter+=1
226-
227293

228294
return np.diag(A), Q

src/pyclassify/utils.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,32 @@
11
import numpy as np
22
import scipy.sparse as sp
3+
import cupy as cp
4+
import cupyx.scipy.sparse as cpsp
35
import numba
46
import os
57
import yaml
68
from line_profiler import profile
79

10+
811
# from numba.pycc import CC
912

1013

11-
def check_A_square_matrix(A):
14+
def check_square_matrix(A):
1215
"""
1316
Checks if the input matrix is a square matrix of type NumPy ndarray or SciPy sparse matrix.
1417
This is done to ensure that the input matrix `A` is both:
15-
1. Of type `np.ndarray` (NumPy array) or `scipy.sparse` (SciPy sparse matrix).
18+
1. Of type `np.ndarray` (NumPy array) or `scipy.sparse.spmatrix` (SciPy sparse matrix) or 'cupyx.scipy.sparse.spmatrix'.
1619
2. A square matrix.
1720
1821
Args:
19-
A (np.ndarray or sp.spmatrix): The matrix to be checked.
22+
A (np.ndarray or sp.spmatrix or cpsp.spmatrix): The matrix to be checked.
2023
2124
Raises:
2225
TypeError: If the input is not a NumPy array or a SciPy sparse matrix.
2326
ValueError: If number of rows != number of columns.
2427
"""
25-
if not isinstance(A, (np.ndarray, sp.spmatrix)):
26-
raise TypeError("Input matrix must be a NumPy array or a SciPy sparse matrix!")
28+
if not isinstance(A, (np.ndarray, sp.spmatrix, cpsp.spmatrix)):
29+
raise TypeError("Input matrix must be a NumPy array or a SciPy/CuPy sparse matrix!")
2730
if A.shape[0] != A.shape[1]:
2831
raise ValueError("Matrix must be square!")
2932

@@ -33,24 +36,46 @@ def make_symmetric(A):
3336
"""
3437
Ensures the input matrix is symmetric by averaging it with its transpose.
3538
36-
This function first checks if the matrix is square using the `check_A_square_matrix` function.
39+
This function first checks if the matrix is square using the `check_square_matrix` function.
3740
Then, it makes the matrix symmetric by averaging it with its transpose.
3841
3942
Args:
40-
A (np.ndarray or sp.spmatrix): The input square matrix to be made symmetric.
43+
A (np.ndarray or sp.spmatrix or cpsp.spmatrix): The input square matrix to be made symmetric.
4144
4245
Returns:
43-
np.ndarray or sp.spmatrix: The symmetric version of the input matrix.
46+
np.ndarray or sp.spmatrix or cpsp.spmatrix: The symmetric version of the input matrix.
4447
4548
Raises:
46-
TypeError: If the input matrix is not a NumPy array or SciPy sparse matrix.
49+
TypeError: If the input matrix is not a NumPy array or SciPy or CuPy sparse matrix.
4750
ValueError: If the input matrix is not square.
4851
"""
49-
check_A_square_matrix(A)
52+
check_square_matrix(A)
5053
A_sym = (A + A.T) / 2
5154
return A_sym
5255

5356

57+
def check_symm_square(A):
58+
"""
59+
Checks if the input matrix is a square symmetric matrix of type SciPy/CuPy sparse matrix.
60+
This is done to ensure that the input matrix `A` is all of the following:
61+
1. A scipy sparse matrix or CuPy sparse matrix.
62+
2. A square matrix.
63+
3. Symmetric.
64+
65+
Args:
66+
A (sp.spmatrix or cpsp.spmatrix): The matrix to be checked.
67+
68+
Raises:
69+
TypeError: If the input is not a SciPy or CuPy sparse matrix.
70+
ValueError: If number of rows != number of columns or the matrix is not symmetric.
71+
"""
72+
check_square_matrix(A)
73+
if isinstance(A, sp.spmatrix) and not np.allclose(A.toarray(), A.toarray().T):
74+
raise ValueError("Matrix must be symmetric!")
75+
elif isinstance(A, cpsp.spmatrix) and not cp.allclose(A.get(), A.get().T):
76+
raise ValueError("Matrix must be symmetric!")
77+
78+
5479
@numba.njit(nogil=True, parallel=True)
5580
def power_method_numba_helper(A, max_iter=500, tol=1e-4, x=None):
5681
"""
@@ -109,4 +134,4 @@ def read_config(file: str) -> dict:
109134
filepath = os.path.abspath(f"{file}.yaml")
110135
with open(filepath, "r") as stream:
111136
kwargs = yaml.safe_load(stream)
112-
return kwargs
137+
return kwargs

0 commit comments

Comments
 (0)