aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xutils/sddm/dc.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/utils/sddm/dc.py b/utils/sddm/dc.py
index d681831..42b49b1 100755
--- a/utils/sddm/dc.py
+++ b/utils/sddm/dc.py
@@ -76,20 +76,20 @@ def get_proposal_func(stepsizes, low, high):
# log_p_x_given_x0 = truncnorm.logpdf(x,a,b,x0,stepsizes).sum(axis=1)
#
# but I think there is a bug in truncnorm.logpdf() which barfs when
- # passed 2D arrays, so instead we just loop over the first axis.
- log_p_x_given_x0 = np.empty(x0.shape[0])
- for i in range(x0.shape[0]):
- log_p_x_given_x0[i] = truncnorm.logpdf(x[i],a[i],b[i],x0[i],stepsizes).sum()
+ # passed 2D arrays, so instead we convert all the arrays to 1D, call
+ # logpdf(), and then reshape the final array.
+ log_p_x_given_x0 = truncnorm.logpdf(x.ravel(),a.ravel(),b.ravel(),x0.ravel(),np.tile(stepsizes,x0.shape[0]))
+ log_p_x_given_x0 = log_p_x_given_x0.reshape(x0.shape).sum(axis=1)
a, b = (low - x)/stepsizes, (high - x)/stepsizes
# Note: Should be able to do this:
#
# log_p_x0_given_x = truncnorm.logpdf(x0,a,b,x,stepsizes).sum(axis=1)
#
# but I think there is a bug in truncnorm.logpdf() which barfs when
- # passed 2D arrays, so instead we just loop over the first axis.
- log_p_x0_given_x = np.empty(x0.shape[0])
- for i in range(x0.shape[0]):
- log_p_x0_given_x[i] = truncnorm.logpdf(x0[i],a[i],b[i],x[i],stepsizes).sum()
+ # passed 2D arrays, so instead we convert all the arrays to 1D, call
+ # logpdf(), and then reshape the final array.
+ log_p_x0_given_x = truncnorm.logpdf(x0.ravel(),a.ravel(),b.ravel(),x.ravel(),np.tile(stepsizes,x0.shape[0]))
+ log_p_x0_given_x = log_p_x0_given_x.reshape(x0.shape).sum(axis=1)
return x, log_p_x0_given_x - log_p_x_given_x0
return proposal