function [F,Lav,KL_w,KL_alpha,KL_lambda,KL_a] = spm_vb_F (Y,slice)
% Evaluate the objective F and its component terms for one slice
% FORMAT [F,Lav,KL_w,KL_alpha,KL_lambda,KL_a] = spm_vb_F (Y,slice)
%
% Y        data matrix; column n is paired with the n-th coefficient block
%          of slice.w_mean (assumed to be [T x N] - confirm with caller)
% slice    struct holding the current estimates. Fields read here:
%            verbose, T, p, k, N, X, XTX, C2, D, Hw,
%            w_mean, w_cov, mean_alpha, b_alpha, c_alpha,
%            b_alpha_prior, c_alpha_prior,
%            mean_lambda, b_lambda, c_lambda,
%            b_lambda_prior, c_lambda_prior,
%          and, only when slice.p > 0:
%            Ha, a_mean, a_cov, mean_beta, b_beta, c_beta,
%            b_beta_prior, c_beta_prior
%          Optional field log_det_D is used if present, otherwise it is
%          computed locally from the eigenvalues of slice.D.
%
% F          Lav - (KL_w + KL_alpha + KL_lambda), with KL_a and KL_beta
%            additionally subtracted when slice.p > 0
% Lav        ((T-p)*C1 - sum_n mean_lambda(n)*G(n) - C2)/2
% KL_w       KL-type term built from w_mean, w_cov and B = Hw*kron(.)*Hw'
% KL_alpha   sum over j=1..k of spm_kl_gamma for the alpha parameters
% KL_lambda  sum over n=1..N of spm_kl_gamma for the lambda parameters
% KL_a       counterpart of KL_w for a_mean/a_cov; 0 when slice.p == 0
%
% NOTE(review): judging by the names this is presumably the negative free
% energy of a variational Bayes model with AR(p) errors - confirm against
% the accompanying documentation.

if slice.verbose
    disp('Updating F');
end

T  = slice.T;
p  = slice.p;
k  = slice.k;
N  = slice.N;
X  = slice.X;
C2 = slice.C2;

% Matrix B used in the quadratic and trace terms of KL_w
Bk = kron(diag(slice.mean_alpha),slice.D);
B  = slice.Hw*Bk*slice.Hw';

% KL_a is a declared output: give it a value even when p == 0 so that
% callers requesting all six outputs do not get an "output argument not
% assigned" error.
KL_a = 0;

if p > 0
    % Matrix J plays the same role for KL_a as B does for KL_w
    Jk = kron(diag(slice.mean_beta),slice.D);
    J  = slice.Ha*Jk*slice.Ha';
    tr_J_acov    = 0;
    log_det_acov = 0;
    % Preallocate instead of growing G inside the loop
    G = zeros(N,1);
end

tr_B_qcov    = 0;
log_det_qcov = 0;
KL_lambda    = 0;
C1           = 0;
Lav_term     = 0;

% Accumulate the per-column (n = 1..N) contributions
for n = 1:N
    block_n = (n-1)*k+1:n*k;          % rows of w_mean / B for column n

    if p > 0
        ablock_n = (n-1)*p+1:n*p;     % rows of a_mean / J for column n
        G(n,1)       = spm_vb_get_Gn(Y,slice,n);
        tr_J_acov    = tr_J_acov + trace(J(ablock_n,ablock_n)*slice.a_cov{n});
        % NOTE(review): log(det(.)) can under/overflow for badly scaled
        % matrices; kept as-is to preserve numerical results.
        log_det_acov = log_det_acov + log(det(slice.a_cov{n}));
    else
        % No AR coefficients: compute the data-fit term directly
        wc = slice.w_cov{n};
        en = Y(:,n) - X*slice.w_mean(block_n,1);
        Gn = trace(wc*slice.XTX) + en'*en;
        Lav_term = Lav_term + slice.mean_lambda(n)*Gn;
    end

    C1 = C1 + spm_digamma(slice.c_lambda(n)) + log(slice.b_lambda(n));
    KL_lambda = KL_lambda + spm_kl_gamma(slice.b_lambda(n),slice.c_lambda(n), ...
        slice.b_lambda_prior,slice.c_lambda_prior);

    tr_B_qcov    = tr_B_qcov + trace(B(block_n,block_n)*slice.w_cov{n});
    log_det_qcov = log_det_qcov + log(det(slice.w_cov{n}));
end

if p > 0
    Lav_term = slice.mean_lambda.'*G;
end
Lav = ((T-p)*C1 - Lav_term - C2)./2;

% Terms involving the alpha parameters
KL_alpha       = 0;
log_det_alphas = 0;
for j = 1:k
    KL_alpha = KL_alpha + spm_kl_gamma(slice.b_alpha(j),slice.c_alpha(j), ...
        slice.b_alpha_prior,slice.c_alpha_prior);
    log_det_alphas = log_det_alphas + log(slice.mean_alpha(j));
end
term1 = -0.5*N*log_det_alphas;

% Terms involving the beta parameters (only when p > 0)
if p > 0
    KL_beta       = 0;
    log_det_betas = 0;
    for j = 1:p
        KL_beta = KL_beta + spm_kl_gamma(slice.b_beta(j),slice.c_beta(j), ...
            slice.b_beta_prior,slice.c_beta_prior);
        log_det_betas = log_det_betas + log(slice.mean_beta(j));
    end
    beta_term1 = -0.5*N*log_det_betas;
end

% Use a precomputed log|D| when the caller supplies one; otherwise compute
% it here from the eigenvalues of D (the result is local to this call).
if isfield(slice,'log_det_D')
    log_det_D = slice.log_det_D;
else
    log_det_D = sum(log(eig(full(slice.D))));
end

term1 = term1 - 0.5*k*log_det_D;
KL_w  = term1 - 0.5*log_det_qcov + 0.5*tr_B_qcov ...
        + 0.5*slice.w_mean'*B*slice.w_mean - 0.5*N*k;

if p > 0
    beta_term1 = beta_term1 - 0.5*p*log_det_D;
    KL_a = beta_term1 - 0.5*log_det_acov + 0.5*tr_J_acov ...
           + 0.5*slice.a_mean'*J*slice.a_mean - 0.5*N*p;
    F = Lav - (KL_w+KL_alpha+KL_lambda+KL_a+KL_beta);
else
    F = Lav - (KL_w+KL_alpha+KL_lambda);
end