Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
"build_cache_size": 2,
"default_benchmark_timeout": 500,
"regressions_thresholds": {
".*": 0.3
".*": 0.2
}
}
5 changes: 5 additions & 0 deletions benchmarks/benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import psutil

from ._patch_setup import _apply_patches

_MIN_THREADS = 4 # minimum physical cores required for multi-threaded mode


Expand All @@ -19,3 +21,6 @@ def _thread_count():

_THREADS = os.environ.get("MKL_NUM_THREADS", _thread_count())
os.environ["MKL_NUM_THREADS"] = _THREADS

_apply_patches()
del _apply_patches
67 changes: 67 additions & 0 deletions benchmarks/benchmarks/_patch_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""MKL patch setup — executed once per ASV worker process at import time.

Patches NumPy FFT with the Intel MKL FFT implementation.
Hard-fails with a descriptive RuntimeError if mkl_fft is missing or the
patch does not take effect, so benchmarks never silently run on stock NumPy.
"""
Comment thread
vchamarthi marked this conversation as resolved.

_PATCH_MAP = [
("mkl_fft", "patch_numpy_fft"),
]
Comment thread
vchamarthi marked this conversation as resolved.
Outdated


def _apply_patches():
import importlib

import numpy as np

patched = {}

for mod_name, patch_fn_name in _PATCH_MAP:
try:
mod = importlib.import_module(mod_name)
except ImportError as exc:
raise RuntimeError(
f"[mkl-patch] Cannot import {mod_name}: {exc}\n"
f" Ensure the conda env contains {mod_name} "
f"from the Intel channel.\n"
" Required channels: "
"https://software.repos.intel.com/python/conda"
) from exc

patch_fn = getattr(mod, patch_fn_name, None)
if patch_fn is None:
raise RuntimeError(
f"[mkl-patch] {mod_name} has no {patch_fn_name}(). "
Comment thread
vchamarthi marked this conversation as resolved.
Outdated
f"Upgrade {mod_name} to a version that exposes "
"the stock-numpy patch API."
)

try:
patch_fn()
except Exception as exc:
raise RuntimeError(
f"[mkl-patch] {mod_name}.{patch_fn_name}() raised: {exc!r}"
) from exc

is_patched_fn = getattr(mod, "is_patched", None)
if callable(is_patched_fn) and not is_patched_fn():
raise RuntimeError(
f"[mkl-patch] {mod_name}.is_patched() returned False "
"after patching. NumPy may have been imported before "
"patching in a conflicting state."
)

patched[mod_name] = mod

_attr_checks = {
"mkl_fft": lambda: np.fft.fft.__module__,
}
for mod_name in patched:
try:
attr = _attr_checks[mod_name]()
except Exception:
attr = "unknown"
print(f"[mkl-patch] {mod_name}: numpy.fft dispatch -> {attr}")
Comment thread
vchamarthi marked this conversation as resolved.
Outdated

print("[mkl-patch] ALL OK -- mkl_fft active")
Comment thread
vchamarthi marked this conversation as resolved.
Outdated
Comment thread
vchamarthi marked this conversation as resolved.
Outdated
2 changes: 2 additions & 0 deletions benchmarks/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
psutil
scipy
Comment thread
vchamarthi marked this conversation as resolved.
Loading