Skip to content

Commit 29ef962

Browse files
expand cluster_builder and hmatrix interface.
1 parent 1dde4ed commit 29ef962

File tree

2 files changed

+52
-9
lines changed

2 files changed

+52
-9
lines changed

src/htool/clustering/cluster_builder.hpp

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,55 @@ void declare_cluster_builder(py::module &m, const std::string &className) {
1919
py::class_<Class> py_class(m, className.c_str());
2020

2121
py_class.def(py::init<>());
22-
py_class.def("create_cluster_tree", [](Class &self, py::array_t<CoordinatePrecision, py::array::f_style | py::array::forcecast> coordinates, int number_of_children, int size_of_partition) {
23-
return self.create_cluster_tree(coordinates.shape()[1], coordinates.shape()[0], coordinates.data(), number_of_children, size_of_partition);
24-
});
25-
py_class.def("create_cluster_tree", [](Class &self, py::array_t<CoordinatePrecision, py::array::f_style | py::array::forcecast> coordinates, int number_of_children, int size_of_partition, py::array_t<int, py::array::f_style | py::array::forcecast> partition) {
26-
if (partition.ndim() != 2 && partition.shape()[0] != 2) {
27-
throw std::runtime_error("Wrong format for partition"); // LCOV_EXCL_LINE
28-
}
29-
return self.create_cluster_tree(coordinates.shape()[1], coordinates.shape()[0], coordinates.data(), number_of_children, size_of_partition, partition.data());
30-
});
22+
py_class.def(
23+
"create_cluster_tree", [](Class &self, py::array_t<CoordinatePrecision, py::array::f_style | py::array::forcecast> coordinates, int number_of_children, int size_of_partition, py::array_t<int, py::array::f_style | py::array::forcecast> partition, py::array_t<CoordinatePrecision, py::array::f_style | py::array::forcecast> radii, py::array_t<CoordinatePrecision, py::array::f_style | py::array::forcecast> weigths) {
24+
if (partition.ndim() != 2 && partition.shape()[0] != 2) {
25+
throw std::runtime_error("Wrong format for partition"); // LCOV_EXCL_LINE
26+
}
27+
return self.create_cluster_tree(coordinates.shape()[1], coordinates.shape()[0], coordinates.data(), radii.data(), weigths.data(), number_of_children, size_of_partition, partition.data());
28+
},
29+
py::arg("coordinates"),
30+
py::arg("number_of_children"),
31+
py::arg("size_of_partition"),
32+
py::kw_only(),
33+
py::arg("partition"), // make them optional with C++17 and None
34+
py::arg("radii"), // make them optional with C++17 and None
35+
py::arg("weights") // make them optional with C++17 and None
36+
);
37+
py_class.def(
38+
"create_cluster_tree", [](Class &self, py::array_t<CoordinatePrecision, py::array::f_style | py::array::forcecast> coordinates, int number_of_children, int size_of_partition, py::array_t<int, py::array::f_style | py::array::forcecast> partition, py::array_t<CoordinatePrecision, py::array::f_style | py::array::forcecast> radii) {
39+
if (partition.ndim() != 2 && partition.shape()[0] != 2) {
40+
throw std::runtime_error("Wrong format for partition"); // LCOV_EXCL_LINE
41+
}
42+
return self.create_cluster_tree(coordinates.shape()[1], coordinates.shape()[0], coordinates.data(), radii.data(), nullptr, number_of_children, size_of_partition, partition.data());
43+
},
44+
py::arg("coordinates"),
45+
py::arg("number_of_children"),
46+
py::arg("size_of_partition"),
47+
py::kw_only(),
48+
py::arg("partition"), // make them optional with C++17 and None
49+
py::arg("radii") // make them optional with C++17 and None
50+
);
51+
py_class.def(
52+
"create_cluster_tree", [](Class &self, py::array_t<CoordinatePrecision, py::array::f_style | py::array::forcecast> coordinates, int number_of_children, int size_of_partition, py::array_t<int, py::array::f_style | py::array::forcecast> partition) {
53+
if (partition.ndim() != 2 && partition.shape()[0] != 2) {
54+
throw std::runtime_error("Wrong format for partition"); // LCOV_EXCL_LINE
55+
}
56+
return self.create_cluster_tree(coordinates.shape()[1], coordinates.shape()[0], coordinates.data(), number_of_children, size_of_partition, partition.data());
57+
},
58+
py::arg("coordinates"),
59+
py::arg("number_of_children"),
60+
py::arg("size_of_partition"),
61+
py::kw_only(),
62+
py::arg("partition"));
63+
64+
py_class.def(
65+
"create_cluster_tree", [](Class &self, py::array_t<CoordinatePrecision, py::array::f_style | py::array::forcecast> coordinates, int number_of_children, int size_of_partition) {
66+
return self.create_cluster_tree(coordinates.shape()[1], coordinates.shape()[0], coordinates.data(), number_of_children, size_of_partition);
67+
},
68+
py::arg("coordinates"),
69+
py::arg("number_of_children"),
70+
py::arg("size_of_partition"));
3171
py_class.def("set_minclustersize", &Class::set_minclustersize);
3272
py_class.def("set_direction_computation_strategy", &Class::set_direction_computation_strategy);
3373
py_class.def("set_splitting_strategy", &Class::set_splitting_strategy);

src/htool/hmatrix/hmatrix.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ void declare_HMatrix(py::module &m, const std::string &className) {
4545
py_class.def("get_tree_parameters", [](const HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix) { return htool::get_tree_parameters(hmatrix); });
4646
py_class.def("get_local_information", [](const HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix) { return htool::get_hmatrix_information(hmatrix); });
4747
py_class.def("get_distributed_information", [](const HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix, MPI_Comm_wrapper comm) { return htool::get_distributed_hmatrix_information(hmatrix, comm); });
48+
py_class.def("get_target_cluster",&HMatrix<CoefficientPrecision, CoordinatePrecision>::get_target_cluster,py::return_value_policy::reference_internal);
49+
py_class.def("get_source_cluster",&HMatrix<CoefficientPrecision, CoordinatePrecision>::get_source_cluster,py::return_value_policy::reference_internal);
50+
4851

4952
m.def("recompression", &htool::recompression<CoefficientPrecision, CoordinatePrecision, std::function<void(LowRankMatrix<CoefficientPrecision> &)>>);
5053
m.def("recompression", [](HMatrix<CoefficientPrecision, CoordinatePrecision> &hmatrix) { recompression(hmatrix); });

0 commit comments

Comments
 (0)