Use Type template for helper functions

This commit is contained in:
Bram Veenboer
2025-08-14 11:02:30 +02:00
parent f40c44d5dd
commit 97692cface

View File

@@ -6,42 +6,45 @@
namespace py = pybind11; namespace py = pybind11;
py::array_t<float> compute_sinf_py( template <typename T>
const Backend &backend, py::array_t<T>
py::array_t<float, py::array::c_style | py::array::forcecast> x) { compute_sin(const Backend &backend,
py::array_t<T, py::array::c_style | py::array::forcecast> x) {
ssize_t n = x.shape(0); ssize_t n = x.shape(0);
auto x_ptr = x.data(); const T *x_ptr = x.data();
py::array_t<float> s(n); py::array_t<float> s(n);
auto s_ptr = s.mutable_data(); T *s_ptr = s.mutable_data();
backend.compute_sinf(static_cast<size_t>(n), x_ptr, s_ptr); backend.compute_sinf(static_cast<size_t>(n), x_ptr, s_ptr);
return s; return s;
} }
py::array_t<float> compute_cosf_py( template <typename T>
const Backend &backend, py::array_t<T>
py::array_t<float, py::array::c_style | py::array::forcecast> x) { compute_cos(const Backend &backend,
py::array_t<T, py::array::c_style | py::array::forcecast> x) {
ssize_t n = x.shape(0); ssize_t n = x.shape(0);
auto x_ptr = x.data(); const T *x_ptr = x.data();
py::array_t<float> c(n); py::array_t<T> c(n);
auto c_ptr = c.mutable_data(); T *c_ptr = c.mutable_data();
backend.compute_cosf(static_cast<size_t>(n), x_ptr, c_ptr); backend.compute_cosf(static_cast<size_t>(n), x_ptr, c_ptr);
return c; return c;
} }
std::tuple<py::array_t<float>, py::array_t<float>> compute_sincosf_py( template <typename T>
const Backend &backend, std::tuple<py::array_t<T>, py::array_t<T>>
py::array_t<float, py::array::c_style | py::array::forcecast> x) { compute_sincos(const Backend &backend,
py::array_t<T, py::array::c_style | py::array::forcecast> x) {
ssize_t n = x.shape(0); ssize_t n = x.shape(0);
auto x_ptr = x.data(); const T *x_ptr = x.data();
py::array_t<float> s(n); py::array_t<T> s(n);
py::array_t<float> c(n); py::array_t<T> c(n);
backend.compute_sincosf(static_cast<size_t>(n), x_ptr, s.mutable_data(), backend.compute_sincosf(static_cast<size_t>(n), x_ptr, s.mutable_data(),
c.mutable_data()); c.mutable_data());
@@ -53,9 +56,9 @@ template <typename BackendType>
void bind_backend(py::module &m, const char *name) { void bind_backend(py::module &m, const char *name) {
py::class_<BackendType, Backend, std::shared_ptr<BackendType>>(m, name) py::class_<BackendType, Backend, std::shared_ptr<BackendType>>(m, name)
.def(py::init<>()) .def(py::init<>())
.def("compute_sinf", &compute_sinf_py) .def("compute_sinf", &compute_sin<float>)
.def("compute_cosf", &compute_cosf_py) .def("compute_cosf", &compute_cos<float>)
.def("compute_sincosf", &compute_sincosf_py); .def("compute_sincosf", &compute_sincos<float>);
} }
PYBIND11_MODULE(pytrigdx, m) { PYBIND11_MODULE(pytrigdx, m) {