diff options
author | tlatorre <tlatorre@uchicago.edu> | 2021-01-03 12:08:01 -0600 |
---|---|---|
committer | tlatorre <tlatorre@uchicago.edu> | 2021-01-03 12:08:01 -0600 |
commit | c528bf6c0e166e49eadfafd15a9ccc9d384d818f (patch) | |
tree | ed54743d63f7a1bae811316a411ca8a1f4501a03 /utils | |
parent | 00a64e94c87fd9ccd1a0c3b01f49832990139dc1 (diff) | |
download | sddm-c528bf6c0e166e49eadfafd15a9ccc9d384d818f.tar.gz sddm-c528bf6c0e166e49eadfafd15a9ccc9d384d818f.tar.bz2 sddm-c528bf6c0e166e49eadfafd15a9ccc9d384d818f.zip |
add a numba optimized version of interp
Diffstat (limited to 'utils')
-rw-r--r-- | utils/sddm/utils.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/utils/sddm/utils.py b/utils/sddm/utils.py index 2b67e63..e2e2238 100644 --- a/utils/sddm/utils.py +++ b/utils/sddm/utils.py @@ -11,6 +11,39 @@ def fast_cdf(x,loc,scale): """ return np.interp((x-loc)/scale,_fast_cdf_x,_fast_cdf_y) +try: + from numba import vectorize, float64, njit + + print("numba found! Loading optimized version of fast_cdf()") + + @njit + def interp(x,xp,yp): + n = len(xp) + + if x <= xp[0]: + return yp[0] + + idx = 1.0/(xp[1]-xp[0]) + + i = int((x-xp[0])*idx) + + if i > n-2: + return yp[n-1] + + return yp[i] + (yp[i+1]-yp[i])*(x-xp[i])*idx + + @vectorize([float64(float64, float64, float64)]) + def fast_cdf(x,loc,scale): + # Here, we optimize for the case that the value is out of bounds since that + # is the most likely. We do this check here to avoid an expensive division. + if x - loc > _fast_cdf_x[-1]*scale: + return _fast_cdf_y[-1] + elif x - loc < _fast_cdf_x[0]*scale: + return _fast_cdf_y[0] + return interp((x-loc)/scale,_fast_cdf_x,_fast_cdf_y) +except ImportError: + print("No numba, loading slower version of fast_cdf()") + # Energy bias of reconstruction relative to Monte Carlo. # # Note: You can recreate this array using: |