function [trainedNet, validationPredictions] = network_main()
%NETWORK_MAIN K-fold transfer learning of VGGish on the 'onlineDataset' folder.
%   [trainedNet, validationPredictions] = NETWORK_MAIN() builds an
%   audioDatastore over every file under 'onlineDataset' (labels taken from
%   the subfolder names) and runs 5-fold cross-validation. In each fold a
%   new fully-connected/softmax/classification head is trained on top of
%   the pretrained VGGish layers, whose learn-rate and L2 factors are set
%   to 0. Segment-level predictions on the held-out files are reduced to
%   one label per file by majority vote (mode). A confusion matrix over all
%   files is plotted at the end.
%
%   Outputs (both come from the LAST fold only):
%     trainedNet            - network returned by trainNetwork in the final fold.
%     validationPredictions - segment-level classify() output for the final
%                             fold's held-out files.
%
%   Side effects: for each fold k, saves net<k>.mat (variable trainedNet)
%   and info<k>.mat (variable netInfo) to the current folder. Also opens a
%   training-progress plot per fold and one confusion-matrix figure.
%
%   Depends on vggishPreprocess (defined outside this function) and the
%   pretrained VGGish model. Training is forced onto the GPU.

ads = audioDatastore('onlineDataset', 'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');
labelTable = countEachLabel(ads);
numClasses = size(labelTable, 1);
numAudios = numel(ads.Files);

numFolds = 5;
overlapPercentage = 75;
miniBatchSize = 32;

% The transfer graph is identical for every fold, so build it once.
% FCFinal has no preset weights, so each trainNetwork call initialises it
% afresh.
lgraph = buildTransferGraph(numClasses);

% One majority-vote label per file, filled in fold by fold. Every file
% index lands in exactly one test fold, so all entries are set after the
% loop. Starting as <undefined> keeps the label categories of ads.Labels.
validationPredictionsPerFile = ads.Labels.';
validationPredictionsPerFile(:) = missing;

for foldIdx = 1:numFolds
    fprintf('processing %d among %d folds \n', foldIdx, numFolds);

    % Interleaved split: files foldIdx, foldIdx+numFolds, ... are held out.
    testIdx = foldIdx:numFolds:numAudios;
    trainIdx = setdiff(1:numAudios, testIdx);
    adsTest = subset(ads, testIdx);
    adsTrain = subset(ads, trainIdx);

    [trainFeatures, trainLabels] = vggishPreprocess(adsTrain, overlapPercentage);
    [validationFeatures, validationLabels, segmentsPerFile] = ...
        vggishPreprocess(adsTest, overlapPercentage);

    % NOTE(review): the held-out fold is also used as ValidationData, so
    % early stopping (ValidationPatience) is driven by the same data that
    % is later scored. Consider carving a separate validation split out of
    % adsTrain.
    options = trainingOptions('adam', ...
        'MaxEpochs', 20, 'MiniBatchSize', miniBatchSize, ...
        'Shuffle', 'every-epoch', ...
        'ValidationData', {validationFeatures, validationLabels}, ...
        'Verbose', false, ...
        'Plots', 'training-progress', ...
        'ValidationPatience', 7, 'ExecutionEnvironment', 'gpu', ...
        'InitialLearnRate', 0.001);

    [trainedNet, netInfo] = trainNetwork(trainFeatures, trainLabels, lgraph, options);

    validationPredictions = classify(trainedNet, validationFeatures);
    validationPredictionsPerFile(testIdx) = ...
        majorityVotePerFile(validationPredictions, segmentsPerFile);

    save(['net', num2str(foldIdx), '.mat'], 'trainedNet');
    save(['info', num2str(foldIdx), '.mat'], 'netInfo');
end

actual_labels = ads.Labels;

% Confusion Matrix
figure;
plotconfusion(actual_labels, validationPredictionsPerFile.');
title('Confusion Matrix: VGGish');
end

function lgraph = buildTransferGraph(numClasses)
%BUILDTRANSFERGRAPH VGGish with its regression output swapped for a frozen-backbone classifier.
%   Loads the pretrained VGGish network and removes its 'regressionoutput'
%   layer. Sets every weight/bias learn-rate and L2 factor of the remaining
%   layers to 0, then appends FCFinal -> softmax -> classOut after the
%   'EmbeddingBatch' layer.
net = vggish;
lgraph = layerGraph(net.Layers);
lgraph = removeLayers(lgraph, 'regressionoutput');

% lgraph.Layers returns a copy of the layer array. Editing that copy alone
% never reaches the graph, so each modified layer is written back with
% replaceLayer.
pretrainedLayers = lgraph.Layers;
for k = 1:numel(pretrainedLayers)
    [frozenLayer, changed] = zeroLearnFactors(pretrainedLayers(k));
    if changed
        lgraph = replaceLayer(lgraph, frozenLayer.Name, frozenLayer);
    end
end

lgraph = addLayers(lgraph, fullyConnectedLayer(numClasses, 'Name', 'FCFinal'));
lgraph = addLayers(lgraph, softmaxLayer('Name', 'softmax'));
lgraph = addLayers(lgraph, classificationLayer('Name', 'classOut'));
lgraph = connectLayers(lgraph, 'EmbeddingBatch', 'FCFinal');
lgraph = connectLayers(lgraph, 'FCFinal', 'softmax');
lgraph = connectLayers(lgraph, 'softmax', 'classOut');
end

function [layer, changed] = zeroLearnFactors(layer)
%ZEROLEARNFACTORS Set every weight/bias learn-rate and L2 factor the layer has to 0.
%   changed is true when at least one such property exists on the layer.
factorNames = {'WeightLearnRateFactor', 'WeightL2Factor', ...
    'BiasLearnRateFactor', 'BiasL2Factor'};
changed = false;
for p = 1:numel(factorNames)
    if isprop(layer, factorNames{p})
        layer.(factorNames{p}) = 0;
        changed = true;
    end
end
end

function votes = majorityVotePerFile(segmentPredictions, segmentsPerFile)
%MAJORITYVOTEPERFILE Collapse per-segment predictions to one label per file via mode.
%   Assumes (as the original indexing did) that segmentPredictions lists
%   each file's segments contiguously, in the same file order as
%   segmentsPerFile. Returns a numel(segmentsPerFile)-by-1 categorical.
votes = segmentPredictions([]);   % empty, but keeps the same categories
startIdx = 1;
for k = 1:numel(segmentsPerFile)
    endIdx = startIdx + segmentsPerFile(k) - 1;
    votes(k, 1) = mode(segmentPredictions(startIdx:endIdx));
    startIdx = endIdx + 1;
end
end