Home > spm_vbglmar_slice > spm_vb_init_slice.m

spm_vb_init_slice

PURPOSE ^

Initialise Variational Bayes for GLM-AR models

SYNOPSIS ^

function [slice] = spm_vb_init_slice (Y,slice)

DESCRIPTION ^

 Initialise Variational Bayes for GLM-AR models
 FORMAT [slice] = spm_vb_init_slice (Y,slice)

 Y             [T x N] time series with T time points, N voxels
 slice         GLM-AR data structure

 %W% Will Penny and Nelson Trujillo-Barreto %E%

CROSS-REFERENCE INFORMATION ^

This function calls: (none listed)
This function is called by: (none listed)

SOURCE CODE ^

function [slice] = spm_vb_init_slice (Y,slice)
% Initialise Variational Bayes for GLM-AR models
% FORMAT [slice] = spm_vb_init_slice (Y,slice)
%
% Y             [T x N] time series with T time points, N voxels
% slice         GLM-AR data structure
%
% Fields of slice that are always read:
%   .k   number of regressors (columns of .X)
%   .p   AR model order (p = 0 skips all AR-related initialisation)
%   .N   number of voxels
%   .T   number of time points
%   .X   [T x k] design matrix
% Fields read only in some cases:
%   .dX  [p x k x (T-p)] lagged design matrix, needed when p > 0
%   .D   spatial precision matrix, needed when .update_F is true
% Optional fields (defaults are filled in when absent):
%   .tol (0.0001), .maxits (4), .verbose (1), .update_F (treated as false)
%   .b_alpha_prior, .b_beta_prior, .b_lambda_prior (10)
%   .c_alpha_prior, .c_beta_prior, .c_lambda_prior (0.1)
%   .mean_alpha [k x 1], .mean_beta [p x 1], .mean_lambda [N x 1]
%   .Xp, .X2    pseudo-inverse of X and of X'*X (computed by SVD if absent)
%   .XT, .XTX   X' and X'*X
%
% On return slice additionally holds the initial posterior moments
% (b_*/c_*/mean_* for alpha, beta, lambda; w_mean, w_cov, ap_mean, a_cov,
% a2), the lagged cross-covariances in slice.I, the permutation matrices
% Hw and Ha, the data projections XTY and y2, and the constant C2.
%
% If a supplied mean_alpha/mean_beta/mean_lambda has the wrong dimension an
% error message is displayed and the function returns early with slice only
% partially initialised.
%
% %W% Will Penny and Nelson Trujillo-Barreto %E%

k=slice.k;
p=slice.p;
N=slice.N;
T=slice.T;
X=slice.X;

%% Default optimisation parameters
if ~isfield(slice,'tol')
    slice.tol=0.0001;
end
if ~isfield(slice,'maxits')
    slice.maxits=4;
end
if ~isfield(slice,'verbose')
    slice.verbose=1;
end
verbose=slice.verbose;

if verbose
    disp('Initialising slice');
    disp(' ');
end

%% Default priors
if ~isfield(slice,'b_alpha_prior')
    slice.b_alpha_prior=10;
end
if ~isfield(slice,'c_alpha_prior')
    slice.c_alpha_prior=0.1;
end
if ~isfield(slice,'b_beta_prior')
    slice.b_beta_prior=10;
end
if ~isfield(slice,'c_beta_prior')
    slice.c_beta_prior=0.1;
end

if ~isfield(slice,'b_lambda_prior')
    slice.b_lambda_prior=10;
end
if ~isfield(slice,'c_lambda_prior')
    slice.c_lambda_prior=0.1;
end


% Initialise approximate alpha posterior
if isfield(slice,'mean_alpha')
    % Compare the full size vector: testing size(..)==k element-wise inside
    % an 'if' silently accepted row vectors and other wrong shapes.
    if ~isequal(size(slice.mean_alpha),[k 1])
        disp('Error in spm_vbglmar_slice: mean_alpha incorrect dimension');
        return
    end
    slice.b_alpha=slice.mean_alpha/slice.c_alpha_prior;
else
    slice.b_alpha=slice.b_alpha_prior*ones(k,1);
    slice.mean_alpha=slice.b_alpha.*(slice.c_alpha_prior*ones(k,1));
end
% Shape parameter is fixed by the number of voxels sharing each precision
slice.c_alpha  = (N./2 + slice.c_alpha_prior)*ones(k,1);

% Initialise approximate beta posterior
if isfield(slice,'mean_beta')
    if ~isequal(size(slice.mean_beta),[p 1])
        disp('Error in spm_vbglmar_slice: mean_beta incorrect dimension');
        return
    end
    slice.b_beta=slice.mean_beta/slice.c_beta_prior;
else
    slice.b_beta=slice.b_beta_prior*ones(p,1);
    slice.mean_beta=slice.b_beta.*(slice.c_beta_prior*ones(p,1));
end
% Each AR precision beta_p is shared by the N voxels of the slice, so the
% shape parameter is N/2 + prior (as for alpha), not p/2 + prior.
slice.c_beta  = (N./2 + slice.c_beta_prior)*ones(p,1);

% Initialise approximate lambda posterior
if isfield(slice,'mean_lambda')
    if size(slice.mean_lambda,1)~=N
        disp('Error in spm_vbglmar_slice: mean_lambda incorrect dimension');
        return
    end
    slice.b_lambda=slice.mean_lambda/slice.c_lambda_prior;
else
    slice.b_lambda=slice.b_lambda_prior*ones(N,1);
    slice.mean_lambda=slice.b_lambda.*(slice.c_lambda_prior*ones(N,1));
end
slice.c_lambda = ((T-p)./2 + slice.c_lambda_prior)*ones(N,1);

if verbose
    disp('Initialising regression coefficient posterior');
end
% Initialise approximate w posterior
if isfield(slice,'Xp') && isfield(slice,'X2')
    Xp=slice.Xp;
    X2=slice.X2;
else
    % Rank-truncated pseudo-inverses of X and X'*X via the SVD
    [ux,dx,vx]=svd(X);
    ddx=diag(dx);
    svd_tol=max(ddx)*eps*k;
    rank_X=sum(ddx > svd_tol);
    ddxm=diag(ones(rank_X,1)./ddx(1:rank_X));
    ddxm2=diag(ones(rank_X,1)./(ddx(1:rank_X).^2));
    Xp=vx(:,1:rank_X)*ddxm*ux(:,1:rank_X)';
    X2=vx(:,1:rank_X)*ddxm2*vx(:,1:rank_X)';
end

% Ordinary least squares fit, one column of coefficients per voxel
w_mean = Xp*Y;
Y_pred = X*w_mean;
% Residual variance per voxel (explicitly along the time dimension)
v=mean((Y-Y_pred).^2,1);
for n=1:N,
    slice.w_cov{n} = v(n)*X2;
end
slice.w_mean=w_mean(:);
slice.w_ols=slice.w_mean;
slice.wk_mean      = reshape(slice.w_mean,k,N);
slice.wk_ols      = reshape(slice.w_ols,k,N);

% Initialise AR coefficient posterior
if verbose
    disp('Initialising AR coefficients');
end
% Embed data: dy(pp,:,n) is the time series of voxel n delayed by pp samples
dy=zeros(p,T-p,N);
for pp=1:p,
    dy(pp,1:T-p,:)=Y(p-pp+1:T-pp,:);
end
if p>0
    e = Y(p+1:T,:) - Y_pred(p+1:T,:);
    dyhat=zeros(p,T-p);
    for n=1:N,
        for pp=1:p,
            % reshape (rather than squeeze) keeps the [k x (T-p)]
            % orientation even when k = 1
            dyhat(pp,:)=w_mean(:,n)'*reshape(slice.dX(pp,:,:),k,[]);
        end
        % Lagged residuals; regress the current residual on them
        E_tilde=dy(:,:,n)-dyhat;
        iterm       = inv(E_tilde * E_tilde');
        slice.ap_mean(:,n) = (iterm * E_tilde*e(:,n));
        e_pred      = E_tilde' * slice.ap_mean(:,n);
        v2          = mean((e(:,n) - e_pred).^2);
        slice.a_cov{n} = v2 * iterm;
        slice.a2{n}=slice.ap_mean(:,n)*slice.ap_mean(:,n)'+slice.a_cov{n};
    end
    slice.ap_ols=slice.ap_mean;
    slice.a_mean=slice.ap_mean(:);
end

if p>0
    if verbose
        disp('Setting up cross-covariance matrices');
    end
    % Get input-output lagged covariances (I.rxy, I.gxy, I.Gxy and I.D)
    % and (I.Gy, I.gy)
    slice.I.gxy=slice.X(p+1:T,:)'*Y(p+1:T,:);
    for n=1:N,
        slice.I.rxy(:,:,n)=dy(:,:,n)*X(p+1:T,:);
        for ki=1:k,
            % [p x (T-p)] lagged copies of regressor ki; the explicit
            % reshape avoids squeeze transposing the singleton case p = 1
            dXk=reshape(slice.dX(:,ki,:),p,[]);
            Dtmp=dy(:,:,n)*dXk';
            slice.I.D(ki,:,n)=Dtmp(:)';
            slice.I.Gxy(:,ki,n)=dXk*Y(p+1:T,n);
        end
        slice.I.Gy(:,:,n)=dy(:,:,n)*dy(:,:,n)';
        slice.I.gy(:,n)=dy(:,:,n)*Y(p+1:T,n);
    end
end

if verbose
    disp('Setting up spatial permutation matrices');
end
% Set up permutation matrix for regression coefficients:
% column nk has its single 1 in row ii(nk), ii = [1:k:Nk, 2:k:Nk, ...]
Nk = N*k;
ii=reshape(reshape(1:Nk,k,N)',1,[]);
slice.Hw=sparse(ii,1:Nk,1,Nk,Nk);
% Set up permutation matrix for AR coefficients
if p > 0
    Np = N*p;
    ii=reshape(reshape(1:Np,p,N)',1,[]);
    slice.Ha=sparse(ii,1:Np,1,Np,Np);
end

% update_F is treated as false when the caller did not set it
if isfield(slice,'update_F') && slice.update_F
    if verbose
        disp('Computing log determinant of spatial precision matrix for evidence');
    end
    % Get log determinant of D as the sum of its log eigenvalues
    slice.log_det_D=sum(log(eig(full(slice.D))));
end

if verbose
    disp('Computing data projections');
end
% Set up design and data projections
if ~(isfield(slice,'XT') && isfield(slice,'XTX'))
    slice.XT=X';
    slice.XTX=slice.XT*X;
end

for n=1:N,
    slice.XTY(:,n)=slice.XT*Y(:,n);
    % Sum of squares over the samples that have a full AR history
    slice.y2(n)=Y(p+1:T,n)'*Y(p+1:T,n);
end

% Precompute quantities for the Negative Free Energy
slice.C2  = N*T*log(2*pi);

Generated on Mon 23-Aug-2004 14:59:38 by m2html © 2003