Contents
function [X,Y,Out] = nmfc(data,known,m,n,r,opts)
% NMFC  Nonnegative factorization of a partially observed m-by-n matrix.
%
%   [X,Y,Out] = nmfc(data,known,m,n,r,opts) looks for nonnegative factors
%   X (m-by-r) and Y (r-by-n) such that X*Y matches DATA on the entries
%   indexed by KNOWN. It alternates projected-gradient updates of X and Y
%   with extrapolation, and re-estimates the unobserved entries of the
%   working matrix M from X*Y after every sweep.
%
%   Inputs
%     data   values of the observed entries (same ordering as KNOWN)
%     known  linear indices of the observed entries in an m-by-n matrix
%     m, n   size of the matrix being completed
%     r      number of columns of X / rows of Y
%     opts   (optional) struct with any of the fields
%              tol    stopping tolerance                   (default 1e-4)
%              maxit  maximum number of iterations         (default 500)
%              maxT   maximum running time in seconds      (default 1e3)
%              rw     scale of the cap on the extrapolation weight (default 1)
%              X0,Y0  initial factors (default: max(0,randn(...)))
%
%   Outputs
%     X, Y          factors from the last executed iteration
%     Out.hist_obj  objective value per iteration
%     Out.hist_rel  row 1: relative objective change, row 2: relative residual
%     Out.iter      index of the last executed iteration

%% Parameters and defaults
% Allow the options argument to be omitted entirely.
if nargin < 6 || isempty(opts), opts = struct(); end

if isfield(opts,'tol'),   tol = opts.tol;     else tol = 1e-4;  end
if isfield(opts,'maxit'), maxit = opts.maxit; else maxit = 500; end
if isfield(opts,'maxT'),  maxT = opts.maxT;   else maxT = 1e3;  end
if isfield(opts,'rw'),    rw = opts.rw;       else rw = 1;      end

%% Data preprocessing and initialization
if isfield(opts,'X0'), X0 = opts.X0; else X0 = max(0,randn(m,r)); end
if isfield(opts,'Y0'), Y0 = opts.Y0; else Y0 = max(0,randn(r,n)); end

% Keep the observed indices sorted and the data aligned with them.
[known,Id] = sort(known); data = data(Id);

% Working matrix: observed entries hold the data, the rest start at zero.
M = zeros(m,n); M(known) = data;
nrmb = norm(data);
Mnrm = nrmb;

% Rescale the starting factors so that each has Frobenius norm sqrt(nrmb).
X0 = X0/norm(X0,'fro')*sqrt(nrmb);
Y0 = Y0/norm(Y0,'fro')*sqrt(nrmb);

Xm = X0; Ym = Y0;
Yt = Y0'; Ys = Y0*Yt; MYt = M*Yt;

obj0 = 0.5*Mnrm^2;

% nstall counts consecutive iterations with a small relative objective
% change. It must exist before the loop, otherwise the very first
% "nstall = nstall+1" fails with an undefined-variable error.
nstall = 0;
t0 = 1;
Lx = 1; Ly = 1;

% M0 is the working matrix of the last accepted iterate. It must exist
% before the loop in case the very first step is rejected.
M0 = M;

%% Iterations of block-coordinate update
%
%  iteratively updated variables:
%       M: estimated matrix. M(known) is fixed; M(~known) is iteratively updated
%       Gx, Gy: gradients with respect to X, Y
%       X, Y: new updates
%       Xm, Ym: extrapolations of X, Y
%       Lx, Lx0: current and previous Lipschitz bounds used in X-update
%       Ly, Ly0: current and previous Lipschitz bounds used in Y-update
%       obj, obj0: current and previous objective values

% Five trailing blanks: the first in-loop print backspaces over five
% characters and must not eat into the label itself.
fprintf('Iteration:      ');
start_time = tic;

for k = 1:maxit
    fprintf('\b\b\b\b\b%5i',k);

    % --- X-update: projected gradient step from the extrapolated point ---
    Lx0 = Lx; Lx = norm(Ys);          % matrix 2-norm of Y*Y'
    Gx = Xm*Ys - MYt;
    X = max(0, Xm - Gx/Lx);
    Xt = X'; Xs = Xt*X;

    % --- Y-update: projected gradient step from the extrapolated point ---
    Ly0 = Ly; Ly = norm(Xs);          % matrix 2-norm of X'*X
    Gy = Xs*Ym - Xt*M;
    Y = max(0, Ym - Gy/Ly);
    Yt = Y'; Ys = Y*Yt;

    % --- M-update: fill unobserved entries from X*Y, keep observed data ---
    M = X*Y; M(known) = data;
    Mnrm = norm(M,'fro');
    MYt = M*Yt;

    % --- diagnostics and stopping checks ---
    % obj = 0.5*||X*Y - M||_F^2, expanded so that X*Y - M is never formed.
    obj = 0.5*(sum(sum(Xs.*Ys)) - 2*sum(sum(X.*MYt)) + Mnrm^2);
    relerr1 = abs(obj-obj0)/(obj0+1);
    % Round-off in the expanded formula can push obj slightly below zero;
    % clamp it so the relative residual stays real.
    relerr2 = sqrt(2*max(obj,0))/nrmb;

    Out.hist_obj(k) = obj;
    Out.hist_rel(1,k) = relerr1;
    Out.hist_rel(2,k) = relerr2;

    crit = relerr1 < tol;
    if crit, nstall = nstall+1; else nstall = 0; end
    if nstall >= 3 || relerr2 < tol, break; end
    if toc(start_time) > maxT, break; end

    % --- correction and extrapolation ---
    t = (1+sqrt(1+4*t0^2))/2;
    if obj > obj0
        % Objective increased: fall back to the last accepted factors and
        % retry from there without extrapolation.
        % NOTE(review): M itself is left at the rejected iterate, so the next
        % Y-gradient uses it while MYt is built from M0 - confirm intended.
        Xm = X0; Ym = Y0; Yt = Y0'; Ys = Y0*Yt;
        MYt = M0*Yt;
    else
        % Accept the step and extrapolate.
        w = (t0-1)/t;
        wx = min([w, rw*sqrt(Lx0/Lx)]);   % cap the weight using the L ratio
        wy = min([w, rw*sqrt(Ly0/Ly)]);
        Xm = X + wx*(X-X0); Ym = Y + wy*(Y-Y0);
        X0 = X; Y0 = Y; M0 = M; t0 = t; obj0 = obj;
    end
end
fprintf('\n');
Out.iter = k;