function [eigvector, eigvalue, elapse] = SDA(gnd,fea,LabelIdx,UnlabelIdx,options)
% SDA: Semi-supervised Discriminant Analysis
%
%       [eigvector, eigvalue, elapse] = SDA(gnd,fea,LabelIdx,UnlabelIdx,options)
%
%             Input:
%               gnd        - Label vector. It is indexed with LabelIdx, so it
%                            must hold an entry for every index in LabelIdx;
%                            only gnd(LabelIdx) is used.
%               fea        - Data matrix. Each row vector of fea is a data point.
%
%               LabelIdx   - fea(LabelIdx,:) is the labeled data matrix.
%               UnlabelIdx - fea(UnlabelIdx,:) is the unlabeled data matrix.
%                            length(LabelIdx)+length(UnlabelIdx) must equal
%                            size(fea,1), otherwise an error is raised.
%
%               options    - Struct value in Matlab. The fields in options
%                            that can be set:
%
%                     WOptions     Options passed to constructW to build the
%                                  affinity graph. Please see constructW.m
%                                  for detailed options. If
%                                  WOptions.bSemiSupervised is set, the
%                                  labeled-labeled part of W is replaced by
%                                  WOptions.SameCategoryWeight (default 1)
%                                  for same-class pairs and 0 otherwise.
%                       or
%                     W            You can construct the W outside. When
%                                  present, WOptions is ignored.
%
%                     beta         Parameter to tune the weight between
%                                  supervised info and local info.
%                                  Default 0.1 (also used when beta <= 0).
%                                     beta*L+\tilde{I}
%                                  where L = D - W is the graph Laplacian and
%                                  \tilde{I} is diagonal with ones at the
%                                  labeled samples.
%
%                     keepMean     If set, fea is not centered.
%
%                     ReguType     'Ridge' (default), 'Tensor' or 'Custom'.
%                                  'Tensor' and 'Custom' require
%                                  options.regularizerR.
%                     ReguAlpha    Regularization weight. Default 0.1.
%
%                     bEigs        Force (1) or forbid (0) the use of eigs.
%                                  If absent, a size-based heuristic decides.
%
%             Output:
%               eigvector - Each column is an embedding function, for a new
%                           data point (row vector) x,  y = x*eigvector
%                           will be the embedding result of x. Columns are
%                           scaled to unit Euclidean norm.
%               eigvalue  - The eigvalue of SDA eigen-problem, sorted from
%                           largest to smallest.
%               elapse    - Time spent on different steps (cputime), with
%                           fields timeW, timePCA, timeMethod and timeAll.
%
%    Examples:
%
%       fea = rand(50,10);
%       gnd = [ones(25,1); 2*ones(25,1)];
%       LabelIdx = [1:5, 26:30];
%       UnlabelIdx = [6:25, 31:50];
%       options = [];
%       options.W = ones(50) - eye(50);   % any affinity matrix built outside
%       options.beta = 0.1;
%       [eigvector, eigvalue] = SDA(gnd,fea,LabelIdx,UnlabelIdx,options);
%       Y = fea*eigvector;
%
% See also LPP, LGE
%
%Reference:
%
%   Deng Cai, Xiaofei He and Jiawei Han, "Semi-Supervised Discriminant
%   Analysis ", IEEE International Conference on Computer Vision (ICCV),
%   Rio de Janeiro, Brazil, Oct. 2007.
%
%   version 2.0 --July/2007
%   version 1.0 --May/2006
%
%   Written by Deng Cai (dengcai2 AT cs.uiuc.edu)

if ~isfield(options,'ReguType')
    options.ReguType = 'Ridge';
end
if ~isfield(options,'ReguAlpha')
    options.ReguAlpha = 0.1;
end

[nSmp,nFea] = size(fea);
nSmpLabel = length(LabelIdx);
nSmpUnlabel = length(UnlabelIdx);
if nSmpLabel+nSmpUnlabel ~= nSmp
    error('input error!');
end

% Only the labels of the labeled samples are used from here on.
gnd = gnd(LabelIdx);
classLabel = unique(gnd);
nClass = length(classLabel);
Dim = nClass;

%==========================
% Affinity graph W
%==========================
if ~isfield(options,'W')
    if ~isfield(options,'WOptions')
        % Fail with a readable message instead of a field-reference error.
        error('SDA: either options.W or options.WOptions must be provided.');
    end
    [W, timeW] = constructW(fea,options.WOptions);
    if isfield(options.WOptions,'bSemiSupervised') & options.WOptions.bSemiSupervised
        if ~isfield(options.WOptions,'SameCategoryWeight')
            options.WOptions.SameCategoryWeight = 1;
        end
        % Overwrite the labeled-labeled block of W: same-class pairs get
        % SameCategoryWeight, all other labeled pairs get 0.
        G2 = zeros(nSmpLabel,nSmpLabel);
        for idx = 1:nClass
            classIdx = find(gnd==classLabel(idx));
            G2(classIdx,classIdx) = options.WOptions.SameCategoryWeight;
        end
        W(LabelIdx,LabelIdx) = G2;
    end
else
    W = options.W;
    timeW = 0;
end

tmp_T = cputime;

%==========================
% Regularized graph matrix  beta*L + \tilde{I}
%==========================
% Graph Laplacian L = diag(sum(W,2)) - W, built without element-wise
% assignment into a (possibly sparse) matrix.
nW = size(W,1);
degree = full(sum(W,2));
L = spdiags(degree(:),0,nW,nW) - W;

beta = 0.1;
if isfield(options,'beta') & (options.beta > 0)
    beta = options.beta;
end

% \tilde{I}: ones on the diagonal at labeled samples. sparse() sums
% repeated indices, matching a loop that adds 1 per occurrence.
labPos = double(LabelIdx(:));
D = beta*L + sparse(labPos,labPos,1,nW,nW);

elapse.timeW = timeW + cputime - tmp_T;

tmp_T = cputime;

%==========================
% If data is too large, the following centering codes can be commented
%==========================
if isfield(options,'keepMean') & options.keepMean
    ;
else
    if issparse(fea)
        fea = full(fea);
    end
    sampleMean = mean(fea,1);
    fea = (fea - repmat(sampleMean,nSmp,1));
end
%==========================

DPrime = fea'*D*fea;

switch lower(options.ReguType)
    case {lower('Ridge')}
        % Add ReguAlpha to every diagonal entry.
        nD = size(DPrime,1);
        diagPos = 1:(nD+1):(nD*nD);
        DPrime(diagPos) = DPrime(diagPos) + options.ReguAlpha;
    case {lower('Tensor')}
        DPrime = DPrime + options.ReguAlpha*options.regularizerR;
    case {lower('Custom')}
        DPrime = DPrime + options.ReguAlpha*options.regularizerR;
    otherwise
        error('ReguType does not exist!');
end

% Remove round-off asymmetry so the eigen-solver sees a symmetric matrix.
DPrime = max(DPrime,DPrime');

% Between-class term: rows of Hb are class means scaled by sqrt(class size).
feaLabel = fea(LabelIdx,:);

Hb = zeros(nClass,nFea);
for i = 1:nClass
    index = find(gnd==classLabel(i));
    classMean = mean(feaLabel(index,:),1);
    Hb(i,:) = sqrt(length(index))*classMean;
end
WPrime = Hb'*Hb;
WPrime = max(WPrime,WPrime');

elapse.timePCA = cputime - tmp_T;

tmp_T = cputime;

%==========================
% Generalized eigen-problem  WPrime*v = lambda*DPrime*v
%==========================
dimMatrix = size(WPrime,2);
if Dim > dimMatrix
    Dim = dimMatrix;
end

if isfield(options,'bEigs')
    if options.bEigs
        bEigs = 1;
    else
        bEigs = 0;
    end
else
    % Use the iterative solver only when few eigenpairs of a large
    % problem are requested.
    if (dimMatrix > 1000 && Dim < dimMatrix/10) || (dimMatrix > 500 && Dim < dimMatrix/20) || (dimMatrix > 250 && Dim < dimMatrix/30)
        bEigs = 1;
    else
        bEigs = 0;
    end
end

if bEigs
    %disp('use eigs to speed up!');
    option = struct('disp',0);
    [eigvector, eigvalue] = eigs(WPrime,DPrime,Dim,'la',option);
    eigvalue = diag(eigvalue);
else
    % DPrime can be sparse when keepMean is set and fea is sparse; hand the
    % dense solver a full matrix.
    if issparse(DPrime)
        DPrime = full(DPrime);
    end
    [eigvector, eigvalue] = eig(WPrime,DPrime);
    eigvalue = diag(eigvalue);

    % Sort eigenpairs from largest to smallest eigenvalue.
    [junk, index] = sort(-eigvalue);
    eigvalue = eigvalue(index);
    eigvector = eigvector(:,index);

    if Dim < size(eigvector,2)
        eigvector = eigvector(:, 1:Dim);
        eigvalue = eigvalue(1:Dim);
    end
end

% Scale every projection vector to unit Euclidean norm.
for i = 1:size(eigvector,2)
    eigvector(:,i) = eigvector(:,i)./norm(eigvector(:,i));
end

elapse.timeMethod = cputime - tmp_T;
elapse.timeAll = elapse.timeW + elapse.timePCA + elapse.timeMethod;