Skip to content

Commit bc2d8f4

Browse files
committed
chore: improve type hints in network_weights, .opt_probs
1 parent c88f9ab commit bc2d8f4

File tree

3 files changed

+19
-19
lines changed

3 files changed

+19
-19
lines changed

src/mlrose_ky/neural/fitness/network_weights.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969
if len(node_list) < 2:
7070
raise ValueError("node_list must contain at least 2 elements.")
7171
if not isinstance(bias, bool):
72-
raise TypeError(f"bias must be a bool (True or False).")
72+
raise TypeError("bias must be a bool (True or False).")
7373
if np.shape(X)[1] != (node_list[0] - bias):
7474
raise ValueError(f"The number of columns in X must equal {node_list[0] - bias}.")
7575
if np.shape(y)[1] != node_list[-1]:
@@ -79,27 +79,27 @@ def __init__(
7979
if learning_rate <= 0:
8080
raise ValueError("learning_rate must be greater than 0.")
8181

82-
self.X = X
83-
self.y_true = y
84-
self.node_list = node_list
85-
self.activation = activation
86-
self.bias = bias
87-
self.is_classifier = is_classifier
88-
self.learning_rate = learning_rate
82+
self.X: np.ndarray = X
83+
self.y_true: np.ndarray = y
84+
self.node_list: list[int] = node_list
85+
self.activation: Callable = activation
86+
self.bias: bool = bias
87+
self.is_classifier: bool = is_classifier
88+
self.learning_rate: float = learning_rate
8989

9090
if self.is_classifier:
91-
self.loss = skm.log_loss
92-
self.output_activation = act.sigmoid if np.shape(self.y_true)[1] == 1 else act.softmax
91+
self.loss: skm.log_loss = skm.log_loss
92+
self.output_activation: Callable = act.sigmoid if np.shape(self.y_true)[1] == 1 else act.softmax
9393
else:
94-
self.loss = skm.mean_squared_error
95-
self.output_activation = act.identity
94+
self.loss: skm.mean_squared_error = skm.mean_squared_error
95+
self.output_activation: Callable = act.identity
9696

97-
self.inputs_list = []
98-
self.y_pred = y
99-
self.weights = []
100-
self.prob_type = "continuous"
97+
self.inputs_list: list[np.ndarray] = []
98+
self.y_pred: np.ndarray = y
99+
self.weights: list[np.ndarray] = []
100+
self.prob_type: str = "continuous"
101101

102-
self.nodes = sum(node_list[i] * node_list[i + 1] for i in range(len(node_list) - 1))
102+
self.nodes: int = sum(node_list[i] * node_list[i + 1] for i in range(len(node_list) - 1))
103103

104104
def evaluate(self, state: np.ndarray) -> float:
105105
"""

src/mlrose_ky/opt_probs/continuous_opt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, length: int, fitness_fn: Any, maximize: bool = True, min_val:
6060
if step <= 0:
6161
raise ValueError("step size must be positive.")
6262
if (max_val - min_val) < step:
63-
raise ValueError(f"step size must be less than (max_val - min_val).")
63+
raise ValueError("step size must be less than (max_val - min_val).")
6464

6565
self.prob_type: str = "continuous"
6666
self.min_val: float = min_val

src/mlrose_ky/opt_probs/discrete_opt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def find_neighbors(self) -> None:
226226

227227
def find_sample_order(self) -> None:
228228
"""Determine order in which to generate sample vector elements."""
229-
sample_order = []
229+
sample_order: list[int] = []
230230
last = [0]
231231
parent = self.parent_nodes
232232

0 commit comments

Comments
 (0)