"""DSMIM-MIA command-line entry point.

``--action 0`` trains networks (see ``train_networks``); ``--action 1`` runs
one of the trajectory-based membership-inference steps (see
``membership_inference_attack``).
"""
import argparse

import utils
import normal
import MIA


MODES = ['target', 'shadow', 'distill_poison_shadow', 'distill_poison_target',
         'distill_out_shadow', 'distill_out_target']
MODELS = ['resnet', 'mobilenet', 'vgg', 'wideresnet']
DATASETS = ['cinic10', 'cifar10', 'cifar100']
MIA_TYPES = ['build-poison-dataset', 'poison-black-box', 'out-black-box',
             'build-out-dataset']


def train_networks(args):
    """Create output/model directories and train the models for ``args.mode``.

    For modes containing ``'distill'`` two directories are prepared:
    ``networks/0/<last '_' token of mode>`` (always under seed 0) and
    ``networks/<seed>/<mode>``, which is passed as ``model_path_dis``. For any
    other mode only ``networks/<seed>/<mode>`` is created and
    ``model_path_dis`` is ``None``.

    Args:
        args: parsed CLI namespace; reads ``mode`` and ``seed`` here and is
            forwarded unchanged to ``normal.train_models``.
    """
    device = utils.get_pytorch_device()
    utils.create_path('./outputs')

    if 'distill' in args.mode:
        # e.g. 'distill_out_shadow' -> 'networks/0/shadow'
        model_path_tar = 'networks/{}/{}'.format(0, args.mode.split('_')[-1])
        utils.create_path(model_path_tar)
        model_path_dis = 'networks/{}/{}'.format(args.seed, args.mode)
        utils.create_path(model_path_dis)
    else:
        model_path_tar = 'networks/{}/{}'.format(args.seed, args.mode)
        utils.create_path(model_path_tar)
        model_path_dis = None

    # The original format string had no placeholder, so the seed was dropped
    # and every run wrote to the same log; include it so runs don't collide.
    utils.set_logger('outputs/train_models_{}'.format(args.seed))
    normal.train_models(args, model_path_tar, model_path_dis, device)


def membership_inference_attack(args):
    """Run the membership-inference step selected by ``args.mia_type``.

    The ``build-*`` steps read models from ``networks/0``; the black-box attack
    steps read models from ``networks/<args.seed>``.

    Args:
        args: parsed CLI namespace; reads ``mia_type`` and ``seed`` and is
            forwarded to the selected ``MIA`` function.

    Raises:
        ValueError: if ``args.mia_type`` is not one of ``MIA_TYPES``.
    """
    print(f'--------------{args.mia_type}-------------')
    device = utils.get_pytorch_device()

    if args.mia_type == 'build-out-dataset':
        models_path = 'networks/{}'.format(0)
        MIA.build_aug_out_trajectory_membership_dataset(args, models_path, device)
    elif args.mia_type == 'build-poison-dataset':
        models_path = 'networks/{}'.format(0)
        # NOTE: 'posion' is the spelling of the function in the MIA module.
        MIA.build_aug_posion_trajectory_membership_dataset(args, models_path, device)
    elif args.mia_type == 'poison-black-box':
        trained_models_path = 'networks/{}'.format(args.seed)
        MIA.poison_aug_trajectory_black_box_membership_inference_attack(
            args, trained_models_path, device)
    elif args.mia_type == 'out-black-box':
        trained_models_path = 'networks/{}'.format(args.seed)
        MIA.out_aug_trajectory_black_box_membership_inference_attack(
            args, trained_models_path, device)
    else:
        # Previously an unknown type silently did nothing.
        raise ValueError('unknown mia_type {!r}; expected one of {}'.format(
            args.mia_type, MIA_TYPES))


def _build_parser():
    """Return the argument parser for this script.

    ``help`` must be a string (argparse formats it with ``%``), so value lists
    are expressed as ``choices`` where they are exhaustive and as help text
    where they are only suggestions.
    """
    parser = argparse.ArgumentParser(description='DSMIM-MIA')
    parser.add_argument('--action', type=int, default=0, choices=[0, 1],
                        help='0: train networks, 1: membership inference attack')
    parser.add_argument('--aug_num', type=int, default=8,
                        help='number of augmentations (typically 1, 2, 4 or 8)')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--mode', type=str, default='target', choices=MODES)
    parser.add_argument('--model', type=str, default='resnet', choices=MODELS)
    parser.add_argument('--data', type=str, default='cifar10', choices=DATASETS)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--model_distill', type=str, default='resnet',
                        choices=MODELS)
    parser.add_argument('--epochs_distill', type=int, default=60)
    parser.add_argument('--nums_poison', type=int, default=16)
    parser.add_argument('--nums_valid', type=int, default=32)
    parser.add_argument('--nums_aug_valid', type=int, default=8)
    parser.add_argument('--nums_candidate', type=int, default=5000,
                        help='typically 100, 1000, 5000 or 10000')
    parser.add_argument('--mia_type', type=str, default='build-out-dataset',
                        choices=MIA_TYPES)
    return parser


def main():
    """Parse CLI arguments, seed RNGs and dispatch on ``--action``."""
    args = _build_parser().parse_args()
    utils.set_random_seeds(args.seed)
    print('random seed:{}'.format(args.seed))
    if args.action == 0:
        train_networks(args)
    elif args.action == 1:
        membership_inference_attack(args)


if __name__ == '__main__':
    main()