diff --git a/mkl_umath/src/_patch_numpy.pyx b/mkl_umath/src/_patch_numpy.pyx index 598a0167..27712bb4 100644 --- a/mkl_umath/src/_patch_numpy.pyx +++ b/mkl_umath/src/_patch_numpy.pyx @@ -40,6 +40,16 @@ from libc.stdlib cimport free, malloc cnp.import_umath() +cdef extern from *: + """ + #include "numpy/ufuncobject.h" + static inline char* _get_ufunc_types(PyObject *u) { + return (char *)((PyUFuncObject *)u)->types; + } + """ + char* _get_ufunc_types(object u) noexcept + + ctypedef struct function_info: cnp.PyUFuncGenericFunction original_function cnp.PyUFuncGenericFunction patch_function @@ -53,32 +63,39 @@ cdef class _patch_impl: functions_dict = dict() def __cinit__(self): - cdef int pi, oi + cdef int pi, oi, i, nargs + cdef int expected_count + cdef char* patch_types + cdef char* orig_types - umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)] + self.functions = NULL self.functions_count = 0 + + umaths = [x for x in dir(mu) if isinstance(getattr(mu, x), np.ufunc)] + expected_count = 0 for umath in umaths: mkl_umath_func = getattr(mu, umath) - self.functions_count += mkl_umath_func.ntypes + expected_count += mkl_umath_func.ntypes self.functions = malloc( - self.functions_count * sizeof(function_info) + expected_count * sizeof(function_info) ) - func_number = 0 for umath in umaths: patch_umath = getattr(mu, umath) c_patch_umath = patch_umath c_orig_umath = getattr(np, umath) nargs = c_patch_umath.nargs + patch_types = _get_ufunc_types(c_patch_umath) + orig_types = _get_ufunc_types(c_orig_umath) for pi in range(c_patch_umath.ntypes): oi = 0 while oi < c_orig_umath.ntypes: found = True - for i in range(c_patch_umath.nargs): + for i in range(nargs): if ( - c_patch_umath.types[pi * nargs + i] - != c_orig_umath.types[oi * nargs + i] + patch_types[pi * nargs + i] + != orig_types[oi * nargs + i] ): found = False break @@ -86,23 +103,23 @@ cdef class _patch_impl: break oi = oi + 1 if oi < c_orig_umath.ntypes: - self.functions[func_number].original_function = ( + self.functions[self.functions_count].original_function = ( c_orig_umath.functions[oi] ) - self.functions[func_number].patch_function = ( + self.functions[self.functions_count].patch_function = ( c_patch_umath.functions[pi] ) - self.functions[func_number].signature = ( + self.functions[self.functions_count].signature = ( malloc(nargs * sizeof(int)) ) for i in range(nargs): - self.functions[func_number].signature[i] = ( - c_patch_umath.types[pi * nargs + i] + self.functions[self.functions_count].signature[i] = ( + patch_types[pi * nargs + i] ) self.functions_dict[(umath, patch_umath.types[pi])] = ( - func_number + self.functions_count ) - func_number = func_number + 1 + self.functions_count += 1 else: raise RuntimeError( f"Unable to find original function for: {umath} " @@ -110,9 +127,10 @@ cdef class _patch_impl: ) def __dealloc__(self): - for i in range(self.functions_count): - free(self.functions[i].signature) - free(self.functions) + if self.functions is not NULL: + for i in range(self.functions_count): + free(self.functions[i].signature) + free(self.functions) cdef int _replace_loop( self,