function mednet()
%MEDNET Fine-tune a pretrained GoogLeNet on the images in medData8.zip.
%   MEDNET() unzips 'medData8.zip', builds an image datastore labelled by
%   folder name, splits it 70/30 per label into training and validation
%   sets, replaces the final learnable layer and the classification layer
%   of GoogLeNet with new ones sized for the number of classes, freezes
%   the first ten layers, trains the network, and then reports validation
%   accuracy, a confusion matrix and a ROC curve.
%
%   Side effects: extracts the archive into the current folder, opens
%   several figures / the network analyzer, and writes net.mat.
%
%   Requires the helper functions findLayersToReplace, freezeWeights and
%   createLgraphUsingConnections to be on the MATLAB path (they are not
%   defined in the part of the file visible here).

% A "for counter=0:5 ... end" repetition loop around the training section
% was disabled in the original and is left disabled.

%% Load data
unzip('medData8.zip');
imds = imageDatastore('medData8', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames')   % no semicolon: echo datastore summary
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7);

%% Load pretrained network
net = googlenet;
analyzeNetwork(net)
net.Layers(1)                      % no semicolon: echo the input layer
inputSize = net.Layers(1).InputSize

if isa(net,'SeriesNetwork')
    lgraph = layerGraph(net.Layers);
else
    lgraph = layerGraph(net);
end
display('nihao');

%% Replace the final layers
[learnableLayer,classLayer] = findLayersToReplace(lgraph);
numClasses = numel(categories(imdsTrain.Labels));

if isa(learnableLayer,'nnet.cnn.layer.FullyConnectedLayer')
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        'Name','new_fc', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
elseif isa(learnableLayer,'nnet.cnn.layer.Convolution2DLayer')
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
else
    % Without this branch newLearnableLayer would be undefined and the
    % failure would surface later with a less helpful message.
    error('mednet:unsupportedLayer', ...
        'Cannot replace learnable layer of class %s.', class(learnableLayer));
end
lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);

figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]);
plot(lgraph)
ylim([0,10])

%% Freeze the initial layers
layers = lgraph.Layers;
connections = lgraph.Connections;
layers(1:10) = freezeWeights(layers(1:10));
lgraph = createLgraphUsingConnections(layers,connections);

%% Augmented datastores (resized to the network input size)
pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

%% Train
miniBatchSize = 10;
% Validate about once per epoch; clamp to 1 because a frequency of 0
% (fewer training images than one mini-batch) is rejected.
valFrequency = max(1,floor(numel(augimdsTrain.Files)/miniBatchSize));
% 'sgdm' replaces the original solver name 'FCGDCKS', which is not a
% solver accepted by trainingOptions.
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',3e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',true,'VerboseFrequency',19, ...
    'Plots','training-progress', ...
    'GradientThresholdMethod','l2norm');

net = trainNetwork(augimdsTrain,lgraph,options);

%% Evaluate on the validation set
[YPred,probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)   % no semicolon: echo result

% Command-form save: writes every variable in this workspace to net.mat.
save net

figure
plotconfusion(imdsValidation.Labels,YPred);

% ROC for the first class (one-vs-rest). Uses the true validation labels
% and the class that corresponds to score column 1; the original passed
% the undefined variable augimds and a numeric positive class.
classNames = net.Layers(end).Classes;
positiveClass = classNames(1);
figure
[xroc,yroc,troc,auc] = perfcurve(imdsValidation.Labels,probs(:,1),positiveClass); %#ok<ASGLU>
plot(xroc,yroc);
xlabel('False positive rate');
ylabel('True positive rate');

% Disabled in the original: show four random validation images with their
% predicted label and confidence.
% idx = randperm(numel(imdsValidation.Files),4);
% figure
% for i = 1:4
%     subplot(2,2,i)
%     I = readimage(imdsValidation,idx(i));
%     imshow(I)
%     label = YPred(idx(i));
%     title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
% end