function [] = network_main_v22()
%NETWORK_MAIN_V22 Transfer learning on a pretrained YAMNet with 5-fold cross-validation.
%   Loads audio from 'onlineDataset_v5-3', extracts VGGish-style mel
%   spectrograms, fine-tunes the final layers of the pretrained network,
%   and plots a per-file confusion matrix over all folds.

ads = audioDatastore('onlineDataset_v5-3','IncludeSubfolders',true,'LabelSource','foldernames');
labelTable = countEachLabel(ads);
numLabels = size(labelTable,1);
numAudios = length(ads.Labels);

overlapPercentage = 75;
num_folds = 5;

for fold_idx = 1:num_folds
    fprintf('processing %d among %d folds \n',fold_idx,num_folds);

    % Test partition for this fold.
    test_idx = fold_idx:num_folds:numAudios;
    adsTest = subset(ads,test_idx);
    testLabels = adsTest.Labels;
    testAudios = adsTest.Files;
    % testFeatures = yammet_features(testAudios);
    [testFeatures,testLabels,segmentsPerFile] = vggishPreprocess(adsTest,overlapPercentage);

    % Remaining files are split 90/10 into training and validation sets.
    train_idx = setdiff(1:length(ads.Files),test_idx);
    adsTrainValidation = subset(ads,train_idx);
    [adsTrain, adsValidation] = splitEachLabel(adsTrainValidation,0.9,'randomized');

    trainLabels = adsTrain.Labels;
    uniqueLabels = unique(trainLabels);
    trainAudios = adsTrain.Files;
    % trainFeatures = yammet_features(trainAudios);
    [trainFeatures,trainLabels] = vggishPreprocess(adsTrain,overlapPercentage);

    validationLabels = adsValidation.Labels;
    validationAudios = adsValidation.Files;
    % validationFeatures = yammet_features(validationAudios);
    [validationFeatures,validationLabels] = vggishPreprocess(adsValidation,overlapPercentage);

    % Load the pretrained network and replace its output layers so it
    % classifies the labels of this dataset.
    clear net;
    net = load('yamnet1.mat');
    net = net.trainedNet;
    lgraph = layerGraph(net.Layers);
    lgraph = replaceLayer(lgraph,"dense",fullyConnectedLayer(numLabels,"Name","dense"));
    lgraph = replaceLayer(lgraph,"Sounds",classificationLayer("Name","Sounds","Classes",uniqueLabels));

    % Freeze the first 21 layers so only the new layers are fine-tuned.
    layers = lgraph.Layers;
    for i = 1:21
        if isprop(layers(i),'WeightLearnRateFactor')
            layers(i).WeightLearnRateFactor = 0;
        end
        if isprop(layers(i),'WeightL2Factor')
            layers(i).WeightL2Factor = 0;
        end
        if isprop(layers(i),'BiasLearnRateFactor')
            layers(i).BiasLearnRateFactor = 0;
        end
        if isprop(layers(i),'BiasL2Factor')
            layers(i).BiasL2Factor = 0;
        end
    end
    lgraph = layerGraph(layers);

    miniBatchSize = 32;
    options = trainingOptions('adam', ...
        'MaxEpochs',20, ...
        'MiniBatchSize',miniBatchSize, ...
        'Shuffle','every-epoch', ...
        'ValidationData',{single(validationFeatures),validationLabels}, ...
        'Verbose',false, ...
        'Plots','training-progress', ...
        'ValidationPatience',7, ...
        'ExecutionEnvironment','gpu', ...
        'InitialLearnRate',0.0001);

    [trainedNet,netInfo] = trainNetwork(single(trainFeatures),trainLabels,lgraph,options);
    [testPredictions,scores] = classify(trainedNet,single(testFeatures));

    save(['BW-yamnet.1',num2str(fold_idx),'.mat'],'trainedNet');
    save(['info',num2str(fold_idx),'.mat'],'netInfo');

    % Collapse segment-level predictions into one prediction per file.
    idx = 1;
    for ii = 1:numel(adsTest.Files)
        testPredictionsPerFile(1,test_idx(ii)) = mode(testPredictions(idx:idx+segmentsPerFile(ii)-1));
        % scoresPerFile(1,test_idx(ii)) = mode(scores(idx:idx+segmentsPerFile(ii)-1));
        scoresPerFile(test_idx(ii),:) = mode(scores(idx:idx+segmentsPerFile(ii)-1,:));
        idx = idx + segmentsPerFile(ii);
    end
end

actual_labels = ads.Labels;

% Confusion matrix over all folds (per-file predictions).
figure;
plotconfusion(actual_labels,testPredictionsPerFile')
title('');

% figure;
% [X,Y,T,AUC] = perfcurve(actual_labels,scoresPerFile,'brainwave entrainment');
% AUC
% plot(X,Y);
% xlabel('False positive rate')
% ylabel('True positive rate')
% title('ROC Curve')

% figure;
% [X1,Y1,T,AUC1] = perfcurve(actual_labels,scoresPerFile(:,1),'alpha');
% plot(X1,Y1);
% hold on
%
% [X2,Y2,T,AUC2] = perfcurve(actual_labels,scoresPerFile(:,2),'beta');
% plot(X2,Y2);
% hold on
%
% [X3,Y3,T,AUC3] = perfcurve(actual_labels,scoresPerFile(:,3),'delta');
% plot(X3,Y3);
% hold on
%
% [X4,Y4,T,AUC4] = perfcurve(actual_labels,scoresPerFile(:,6),'theta');
% plot(X4,Y4);
% hold on
% [X5,Y5,T,AUC5] = perfcurve(actual_labels,scoresPerFile(:,4),'gamma');
% plot(X5,Y5);
% hold off
% AUC = (AUC1+AUC2+AUC3+AUC4+AUC5)/5
% legend('alpha','beta','delta','theta','gamma','Location','SE');
% xlabel('False positive rate')
% ylabel('True positive rate')
% title('ROC Curve')
end
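
% --- Sketch only: assumed behavior of the vggishPreprocess helper ---
% The vggishPreprocess helper called above (datastore + overlap percentage,
% up to three outputs) is not defined in this file. The local function below
% is a minimal sketch of what it is assumed to do, following the usual
% MathWorks YAMNet/VGGish transfer-learning workflow: read each file in the
% datastore, compute mel spectrograms with the Audio Toolbox built-in
% vggishPreprocess(audioIn,fs,'OverlapPercentage',...) (R2021a+), replicate
% the file label once per spectrogram, and record how many segments each
% file produced. The name vggishPreprocessSketch is hypothetical, the
% function is not called by network_main_v22, and it assumes the built-in
% (not the project's own helper) resolves on the path.
function [features,labels,segmentsPerFile] = vggishPreprocessSketch(adsIn,overlapPercentage)
numFiles = numel(adsIn.Files);
features = [];
labels = categorical([]);
segmentsPerFile = zeros(numFiles,1);
for k = 1:numFiles
    [audioIn,fs] = audioread(adsIn.Files{k});
    audioIn = mean(audioIn,2);   % mix down to mono
    % Audio Toolbox built-in: returns 96-by-64-by-1-by-numSpectrograms.
    melSpecs = vggishPreprocess(audioIn,fs,'OverlapPercentage',overlapPercentage);
    segmentsPerFile(k) = size(melSpecs,4);
    features = cat(4,features,melSpecs);
    labels = [labels; repmat(adsIn.Labels(k),segmentsPerFile(k),1)];
end
end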