// pytrigdx: pybind11 bindings that expose trigdx backend classes to Python.
//
// NOTE(review): this file appears to have been flattened onto a single line and
// stripped of everything between angle brackets. The lost text includes include
// paths, template parameter lists, and template arguments such as
// static_cast<...>, py::array_t<...> and py::class_<...>. It cannot compile in
// this form, because a preprocessing directive runs to the end of its line, so
// the leading `#include` swallows the whole file. Restore the original from
// version control rather than hand-reconstructing the missing template arguments.
//
// Visible structure (template arguments unknown):
//   compute_sin(backend, x)
//     Takes n = x.shape(0) and allocates an output array s of length n.
//     Calls backend.compute_sinf(n, x.data(), s.mutable_data()) and returns s.
//   compute_cos(backend, x)
//     Same pattern; calls backend.compute_cosf and returns c.
//   compute_sincos(backend, x)
//     Allocates s and c of length n and calls
//     backend.compute_sincosf(n, x_ptr, s_ptr, c_ptr).
//     Returns std::make_tuple(s, c).
//   bind_backend(m, name)
//     Registers a Python class called `name` with a default constructor
//     (py::init<>()). Its methods compute_sinf / compute_cosf / compute_sincosf
//     are bound to the three helpers above.
//   PYBIND11_MODULE(pytrigdx, m)
//     Registers "Backend" with an "init" method bound to &Backend::init.
//     Then binds Reference, Lookup16K, Lookup32K, LookupAVX16K and LookupAVX32K.
//     Also binds MKL, GPU, and LookupXSIMD16K / LookupXSIMD32K, but only when
//     TRIGDX_USE_MKL, TRIGDX_USE_GPU and TRIGDX_USE_XSIMD respectively are defined.
//
// NOTE(review): the helpers use only x.shape(0) as the element count and pass
// x.data() straight to the backend. They therefore presumably assume a 1-D,
// C-contiguous input. The stripped py::array_t<...> arguments are not visible,
// so it is unclear whether they enforce this (e.g. via
// py::array::c_style | py::array::forcecast) — TODO confirm. If they do not, a
// strided view or a multi-dimensional array would be read incorrectly.
#include #include #include #include namespace py = pybind11; template py::array_t compute_sin(const Backend &backend, py::array_t x) { ssize_t n = x.shape(0); const T *x_ptr = x.data(); py::array_t s(n); T *s_ptr = s.mutable_data(); backend.compute_sinf(static_cast(n), x_ptr, s_ptr); return s; } template py::array_t compute_cos(const Backend &backend, py::array_t x) { ssize_t n = x.shape(0); const T *x_ptr = x.data(); py::array_t c(n); T *c_ptr = c.mutable_data(); backend.compute_cosf(static_cast(n), x_ptr, c_ptr); return c; } template std::tuple, py::array_t> compute_sincos(const Backend &backend, py::array_t x) { ssize_t n = x.shape(0); const T *x_ptr = x.data(); py::array_t s(n); py::array_t c(n); backend.compute_sincosf(static_cast(n), x_ptr, s.mutable_data(), c.mutable_data()); return std::make_tuple(s, c); } template void bind_backend(py::module &m, const char *name) { py::class_>(m, name) .def(py::init<>()) .def("compute_sinf", &compute_sin) .def("compute_cosf", &compute_cos) .def("compute_sincosf", &compute_sincos); } PYBIND11_MODULE(pytrigdx, m) { py::class_>(m, "Backend") .def("init", &Backend::init); bind_backend(m, "Reference"); bind_backend>(m, "Lookup16K"); bind_backend>(m, "Lookup32K"); bind_backend>(m, "LookupAVX16K"); bind_backend>(m, "LookupAVX32K"); #if defined(TRIGDX_USE_MKL) bind_backend(m, "MKL"); #endif #if defined(TRIGDX_USE_GPU) bind_backend(m, "GPU"); #endif #if defined(TRIGDX_USE_XSIMD) bind_backend>(m, "LookupXSIMD16K"); bind_backend>(m, "LookupXSIMD32K"); #endif }