Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 36 additions & 18 deletions mkl_umath/src/_patch_numpy.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ from libc.stdlib cimport free, malloc

cnp.import_umath()

cdef extern from *:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to populate the chnagelog

"""
#include "numpy/ufuncobject.h"
static inline char* _get_ufunc_types(PyObject *u) {
return (char *)((PyUFuncObject *)u)->types;
}
"""
char* _get_ufunc_types(object u) noexcept


ctypedef struct function_info:
cnp.PyUFuncGenericFunction original_function
cnp.PyUFuncGenericFunction patch_function
Expand All @@ -53,66 +63,74 @@ cdef class _patch_impl:
functions_dict = dict()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be an instance attribute instead of class-level one?

    cdef dict functions_dict

In that case each _patch_impl instance owns its own mapping and it's freed with the instance (more robust approach).

But we need to add initialization to the cinit below:

    def __cinit__(self):
        self.functions_dict = {}
        ...


def __cinit__(self):
cdef int pi, oi
cdef int pi, oi, i, nargs
cdef int expected_count
cdef char* patch_types
cdef char* orig_types

umaths = [i for i in dir(mu) if isinstance(getattr(mu, i), np.ufunc)]
self.functions = NULL
self.functions_count = 0

umaths = [x for x in dir(mu) if isinstance(getattr(mu, x), np.ufunc)]
expected_count = 0
for umath in umaths:
mkl_umath_func = getattr(mu, umath)
self.functions_count += mkl_umath_func.ntypes
expected_count += mkl_umath_func.ntypes

self.functions = <function_info *> malloc(
self.functions_count * sizeof(function_info)
expected_count * sizeof(function_info)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to call malloc only when expected_count > 0?

)

func_number = 0
for umath in umaths:
patch_umath = getattr(mu, umath)
c_patch_umath = <cnp.ufunc>patch_umath
c_orig_umath = <cnp.ufunc>getattr(np, umath)
nargs = c_patch_umath.nargs
patch_types = _get_ufunc_types(c_patch_umath)
orig_types = _get_ufunc_types(c_orig_umath)
for pi in range(c_patch_umath.ntypes):
oi = 0
while oi < c_orig_umath.ntypes:
found = True
for i in range(c_patch_umath.nargs):
for i in range(nargs):
if (
c_patch_umath.types[pi * nargs + i]
!= c_orig_umath.types[oi * nargs + i]
patch_types[pi * nargs + i]
!= orig_types[oi * nargs + i]
):
found = False
break
if found is True:
break
oi = oi + 1
if oi < c_orig_umath.ntypes:
self.functions[func_number].original_function = (
self.functions[self.functions_count].original_function = (
c_orig_umath.functions[oi]
)
self.functions[func_number].patch_function = (
self.functions[self.functions_count].patch_function = (
c_patch_umath.functions[pi]
)
self.functions[func_number].signature = (
self.functions[self.functions_count].signature = (
<int *> malloc(nargs * sizeof(int))
)
for i in range(nargs):
self.functions[func_number].signature[i] = (
c_patch_umath.types[pi * nargs + i]
self.functions[self.functions_count].signature[i] = (
patch_types[pi * nargs + i]
)
self.functions_dict[(umath, patch_umath.types[pi])] = (
func_number
self.functions_count
)
func_number = func_number + 1
self.functions_count += 1
else:
raise RuntimeError(
f"Unable to find original function for: {umath} "
f"{patch_umath.types[pi]}"
)

def __dealloc__(self):
for i in range(self.functions_count):
free(self.functions[i].signature)
free(self.functions)
if self.functions is not NULL:
for i in range(self.functions_count):
free(self.functions[i].signature)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be NULL when nargs == 0

free(self.functions)

cdef int _replace_loop(
self,
Expand Down
Loading