aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortlatorre <tlatorre@uchicago.edu>2021-01-03 12:08:01 -0600
committertlatorre <tlatorre@uchicago.edu>2021-01-03 12:08:01 -0600
commitc528bf6c0e166e49eadfafd15a9ccc9d384d818f (patch)
treeed54743d63f7a1bae811316a411ca8a1f4501a03
parent00a64e94c87fd9ccd1a0c3b01f49832990139dc1 (diff)
downloadsddm-c528bf6c0e166e49eadfafd15a9ccc9d384d818f.tar.gz
sddm-c528bf6c0e166e49eadfafd15a9ccc9d384d818f.tar.bz2
sddm-c528bf6c0e166e49eadfafd15a9ccc9d384d818f.zip
add a numba optimized version of interp
-rw-r--r--utils/sddm/utils.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/utils/sddm/utils.py b/utils/sddm/utils.py
index 2b67e63..e2e2238 100644
--- a/utils/sddm/utils.py
+++ b/utils/sddm/utils.py
@@ -11,6 +11,39 @@ def fast_cdf(x,loc,scale):
"""
return np.interp((x-loc)/scale,_fast_cdf_x,_fast_cdf_y)
+try:
+ from numba import vectorize, float64, njit
+
+ print("numba found! Loading optimized version of fast_cdf()")
+
+ @njit
+ def interp(x,xp,yp):
+ n = len(xp)
+
+ if x <= xp[0]:
+ return yp[0]
+
+ idx = 1.0/(xp[1]-xp[0])
+
+ i = int((x-xp[0])*idx)
+
+ if i > n-2:
+ return yp[n-1]
+
+ return yp[i] + (yp[i+1]-yp[i])*(x-xp[i])*idx
+
+ @vectorize([float64(float64, float64, float64)])
+ def fast_cdf(x,loc,scale):
+ # Here, we optimize for the case that the value is out of bounds since that
+ # is the most likely. We do this check here to avoid an expensive division.
+ if x - loc > _fast_cdf_x[-1]*scale:
+ return _fast_cdf_y[-1]
+ elif x - loc < _fast_cdf_x[0]*scale:
+ return _fast_cdf_y[0]
+ return interp((x-loc)/scale,_fast_cdf_x,_fast_cdf_y)
+except ImportError:
+ print("No numba, loading slower version of fast_cdf()")
+
# Energy bias of reconstruction relative to Monte Carlo.
#
# Note: You can recreate this array using: