{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":5170,"status":"ok","timestamp":1645444094777,"user":{"displayName":"su haha","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"03290838865279118816"},"user_tz":-480},"id":"Mvkd2nrxXVBz","outputId":"76a26df5-7e14-4160-9939-c9debb3c85c3"},"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting dgl==0.4.1\n"," Downloading dgl-0.4.1-cp37-cp37m-manylinux1_x86_64.whl (2.4 MB)\n","\u001b[K |████████████████████████████████| 2.4 MB 5.4 MB/s \n","\u001b[?25hRequirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from dgl==0.4.1) (1.4.1)\n","Requirement already satisfied: numpy>=1.14.0 in /usr/local/lib/python3.7/dist-packages (from dgl==0.4.1) (1.21.5)\n","Requirement already satisfied: networkx>=2.1 in /usr/local/lib/python3.7/dist-packages (from dgl==0.4.1) (2.6.3)\n","Installing collected packages: dgl\n","Successfully installed dgl-0.4.1\n"]}],"source":["pip install dgl==0.4.1"]},{"cell_type":"code","source":["#pip install dgl-cu111 -f https://data.dgl.ai/wheels/repo.html"],"metadata":{"id":"T2jpko6PGGbT"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":19494,"status":"ok","timestamp":1645444114264,"user":{"displayName":"su haha","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"03290838865279118816"},"user_tz":-480},"id":"HIH8LFuxnnyS","outputId":"4c950c42-ef26-4885-8c99-aa6193e8b453"},"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/gdrive\n","/content/gdrive/MyDrive/HeterSumGraph/tools/logger.py\n","# -*- coding: utf-8 -*-\n","\n","import logging\n","import sys\n","\n","logger = logging.getLogger(\"Summarization logger\")\n","\n","formatter = logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s')\n","\n","console_handler = logging.StreamHandler(sys.stdout)\n","console_handler.formatter = formatter \n","console_handler.setLevel(logging.INFO)\n","logger.addHandler(console_handler)\n","\n","logger.setLevel(logging.DEBUG)\n"]}],"source":["from google.colab import drive\n","drive.mount('/content/gdrive')\n","!ls /content/gdrive/MyDrive/HeterSumGraph/tools/logger.py\n","!cat '/content/gdrive/MyDrive/HeterSumGraph/tools/logger.py'\n","import sys\n","sys.path.append('/content/gdrive/MyDrive/HeterSumGraph/tools')\n","from logger import *"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":4520,"status":"ok","timestamp":1645444118778,"user":{"displayName":"su haha","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"03290838865279118816"},"user_tz":-480},"id":"1g6msKaGN1BQ","outputId":"8cd5e0fc-5294-411b-cc3a-c3f748b3730d"},"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting rouge\n"," Downloading rouge-1.0.1-py3-none-any.whl (13 kB)\n","Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from rouge) (1.15.0)\n","Installing collected packages: rouge\n","Successfully installed rouge-1.0.1\n"]}],"source":["pip install rouge"]},{"cell_type":"code","source":["pip install 
git+https://github.com/bheinzerling/pyrouge"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"NuNzVUhHGEcn","executionInfo":{"status":"ok","timestamp":1645444127881,"user_tz":-480,"elapsed":9108,"user":{"displayName":"su haha","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"03290838865279118816"}},"outputId":"eeab9526-f1af-40cb-a6e4-5d57768e04ff"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting git+git://github.com/bheinzerling/pyrouge\n"," Cloning git://github.com/bheinzerling/pyrouge to /tmp/pip-req-build-kb0y5z88\n"," Running command git clone -q git://github.com/bheinzerling/pyrouge /tmp/pip-req-build-kb0y5z88\n","Building wheels for collected packages: pyrouge\n"," Building wheel for pyrouge (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Created wheel for pyrouge: filename=pyrouge-0.1.3-py3-none-any.whl size=191924 sha256=4a42ad02937db1368f03fdfe2b0f1b11f8ba90385793c258bd47bf2d8d88a0e4\n"," Stored in directory: /tmp/pip-ephem-wheel-cache-rispcffs/wheels/ad/5b/5e/2c225d0fead5b90f75e639f9015863c3a8d8fd2791e00d5b90\n","Successfully built pyrouge\n","Installing collected packages: pyrouge\n","Successfully installed pyrouge-0.1.3\n"]}]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":13841,"status":"ok","timestamp":1645444141715,"user":{"displayName":"su haha","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"03290838865279118816"},"user_tz":-480},"id":"8Io2HVagEsb9","outputId":"6e860608-c2b7-4cd1-8cef-5abb5d31b130"},"outputs":[{"output_type":"stream","name":"stdout","text":["[nltk_data] Downloading package stopwords to /root/nltk_data...\n","[nltk_data] Unzipping corpora/stopwords.zip.\n","Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount(\"/content/gdrive\", force_remount=True).\n","evaluation.py PositionEmbedding.py README.md\ttools\n","HiGraph.py PrepareDataset.sh script\ttrain.py\n","module\t __pycache__\t Tester.py\n"]}],"source":["import nltk\n","nltk.download('stopwords')\n","from google.colab import drive\n","drive.mount('/content/gdrive')\n","!ls /content/gdrive/MyDrive/HeterSumGraph\n","import sys\n","sys.path.append('/content/gdrive/MyDrive/HeterSumGraph')\n","from tools.logger import *\n","from HiGraph import HSumGraph, HSumDocGraph\n","from Tester import SLTester\n","from module.dataloader import ExampleSet, MultiExampleSet, graph_collate_fn\n","from module.embedding import Word_Embedding\n","from module.vocabulary import Vocab"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"4xagclsVM0a7"},"outputs":[],"source":["import argparse\n","import datetime\n","import os\n","import shutil\n","import time\n","import random\n","import dgl\n","import numpy as np\n","import torch\n","from rouge import Rouge"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"b7p9fT54Dy9Q"},"outputs":[],"source":["_DEBUG_FLAG_ = False"]},
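{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Editor's sketch (not part of the original pipeline): a toy sanity check of the rouge\n","# package imported above and used by run_eval() below. get_scores(hyps, refs, avg=True)\n","# returns averaged precision/recall/F1 dicts keyed by 'rouge-1', 'rouge-2' and 'rouge-l'.\n","toy_hyp = ['the cat sat on the mat']\n","toy_ref = ['a cat was sitting on the mat']\n","print(Rouge().get_scores(toy_hyp, toy_ref, avg=True)['rouge-1'])"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"x-m9l0DIkooQ"},"outputs":[],"source":["def save_model(model, save_file):\n"," with open(save_file, 'wb') as f:\n"," torch.save(model.state_dict(), f)\n"," logger.info('[INFO] Saving model to %s', save_file)\n","\n","\n","def setup_training(model, train_loader, valid_loader, valset, hps):\n"," \"\"\" Does setup before starting training 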
(run_training)\n"," \n"," :param model: the model\n"," :param train_loader: train dataset loader\n"," :param valid_loader: valid dataset loader\n"," :param valset: valid dataset which includes text and summary\n"," :param hps: hps for model\n"," :return: \n"," \"\"\"\n","\n"," train_dir = os.path.join(hps.save_root, \"train\")\n"," if os.path.exists(train_dir) and hps.restore_model != 'None':\n"," logger.info(\"[INFO] Restoring %s for training...\", hps.restore_model)\n"," bestmodel_file = os.path.join(train_dir, hps.restore_model)\n"," model.load_state_dict(torch.load(bestmodel_file))\n"," hps.save_root = hps.save_root + \"_reload\"\n"," else:\n"," logger.info(\"[INFO] Create new model for training...\")\n"," if os.path.exists(train_dir): shutil.rmtree(train_dir)\n"," os.makedirs(train_dir)\n","\n"," try:\n"," run_training(model, train_loader, valid_loader, valset, hps, train_dir)\n"," except KeyboardInterrupt:\n"," logger.error(\"[Error] Caught keyboard interrupt on worker. Stopping supervisor...\")\n"," save_model(model, os.path.join(train_dir, \"earlystop\"))\n","\n","\n","def run_training(model, train_loader, valid_loader, valset, hps, train_dir):\n"," ''' Repeatedly runs training iterations, logging loss to screen and log files\n"," \n"," :param model: the model\n"," :param train_loader: train dataset loader\n"," :param valid_loader: valid dataset loader\n"," :param valset: valid dataset which includes text and summary\n"," :param hps: hps for model\n"," :param train_dir: where to save checkpoints\n"," :return: \n"," '''\n"," logger.info(\"[INFO] Starting run_training\")\n","\n"," optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=hps.lr)\n","\n","\n"," criterion = torch.nn.CrossEntropyLoss(reduction='none')\n","\n"," best_train_loss = None\n"," best_loss = None\n"," best_F = None\n"," non_descent_cnt = 0\n"," saveNo = 0\n","\n"," for epoch in range(1, hps.n_epochs + 1):\n"," epoch_loss = 0.0\n"," train_loss = 0.0\n"," epoch_start_time = time.time()\n"," for i, (G, index) in enumerate(train_loader):\n"," iter_start_time = time.time()\n"," # if i > 10:\n"," # break\n"," model.train()\n","\n"," if hps.cuda:\n"," G.to(torch.device(\"cuda:0\"))\n"," G.to(torch.device(\"cuda:0\"))\n"," outputs = model.forward(G) # [n_snodes, 2]\n"," snode_id = G.filter_nodes(lambda nodes: nodes.data[\"dtype\"] == 1)\n"," label = G.ndata[\"label\"][snode_id].sum(-1) # [n_nodes]\n"," \n"," G.nodes[snode_id].data[\"loss\"] = criterion(outputs, label).unsqueeze(-1) # [n_nodes, 1]\n"," loss = dgl.sum_nodes(G, \"loss\") # [batch_size, 1]\n"," loss = loss.mean()\n","\n"," if not (np.isfinite(loss.data.cpu())).numpy():\n"," logger.error(\"train Loss is not finite. Stopping.\")\n"," logger.info(loss)\n"," for name, param in model.named_parameters():\n"," if param.requires_grad:\n"," logger.info(name)\n"," # logger.info(param.grad.data.sum())\n"," raise Exception(\"train Loss is not finite. 
Stopping.\")\n","\n"," optimizer.zero_grad()\n"," loss.backward()\n"," if hps.grad_clip:\n"," torch.nn.utils.clip_grad_norm_(model.parameters(), hps.max_grad_norm)\n","\n"," optimizer.step()\n","\n"," train_loss += float(loss.data)\n"," epoch_loss += float(loss.data)\n","\n"," if i % 100 == 0:\n"," if _DEBUG_FLAG_:\n"," for name, param in model.named_parameters():\n"," if param.requires_grad:\n"," logger.debug(name)\n"," logger.debug(param.grad.data.sum())\n"," logger.info(' | end of iter {:3d} | time: {:5.2f}s | train loss {:5.4f} | '\n"," .format(i, (time.time() - iter_start_time),float(train_loss / 100)))\n"," train_loss = 0.0\n","\n"," if hps.lr_descent:\n"," new_lr = max(5e-6, hps.lr / (epoch + 1))\n"," for param_group in list(optimizer.param_groups):\n"," param_group['lr'] = new_lr\n"," logger.info(\"[INFO] The learning rate now is %f\", new_lr)\n","\n"," epoch_avg_loss = epoch_loss / len(train_loader)\n"," logger.info(' | end of epoch {:3d} | time: {:5.2f}s | epoch train loss {:5.4f} | '\n"," .format(epoch, (time.time() - epoch_start_time), float(epoch_avg_loss)))\n","\n"," if not best_train_loss or epoch_avg_loss < best_train_loss:\n"," save_file = os.path.join(train_dir, \"bestmodel\")\n"," logger.info('[INFO] Found new best model with %.3f running_train_loss. Saving to %s', float(epoch_avg_loss),\n"," save_file)\n"," save_model(model, save_file)\n"," best_train_loss = epoch_avg_loss\n"," elif epoch_avg_loss >= best_train_loss:\n"," logger.error(\"[Error] training loss does not descent. Stopping supervisor...\")\n"," save_model(model, os.path.join(train_dir, \"earlystop\"))\n"," sys.exit(1)\n","\n"," best_loss, best_F, non_descent_cnt, saveNo = run_eval(model, valid_loader, valset, hps, best_loss, best_F, non_descent_cnt, saveNo)\n","\n"," if non_descent_cnt >= 3:\n"," logger.error(\"[Error] val loss does not descent for three times. Stopping supervisor...\")\n"," save_model(model, os.path.join(train_dir, \"earlystop\"))\n"," return\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MCfsK1liksLv"},"outputs":[],"source":["def run_eval(model, loader, valset, hps, best_loss, best_F, non_descent_cnt, saveNo):\n"," ''' \n"," Repeatedly runs eval iterations, logging to screen and writing summaries. 
Saves the model with the best loss seen so far.\n"," :param model: the model\n"," :param loader: valid dataset loader\n"," :param valset: valid dataset which includes text and summary\n"," :param hps: hps for model\n"," :param best_loss: best valid loss so far\n"," :param best_F: best valid F so far\n"," :param non_descent_cnt: the number of non-descent epochs so far (for early stopping)\n"," :param saveNo: the number of saved models (checkpoints rotate through three slots)\n"," :return: \n"," '''\n"," logger.info(\"[INFO] Starting eval for this model ...\")\n"," eval_dir = os.path.join(hps.save_root, \"eval\") # make a subdir of the root dir for eval data\n"," if not os.path.exists(eval_dir): os.makedirs(eval_dir)\n","\n"," model.eval()\n","\n"," iter_start_time = time.time()\n","\n"," with torch.no_grad():\n"," tester = SLTester(model, hps.m)\n"," for i, (G, index) in enumerate(loader):\n"," if hps.cuda:\n"," G.to(torch.device(\"cuda:0\"))\n"," tester.evaluation(G, index, valset)\n","\n"," running_avg_loss = tester.running_avg_loss\n","\n"," if len(tester.hyps) == 0 or len(tester.refer) == 0:\n"," logger.error(\"During testing, no hypotheses were selected!\")\n"," return best_loss, best_F, non_descent_cnt, saveNo\n"," rouge = Rouge()\n"," scores_all = rouge.get_scores(tester.hyps, tester.refer, avg=True)\n"," logger.info('[INFO] End of valid | time: {:5.2f}s | valid loss {:5.4f} | '.format((time.time() - iter_start_time), float(running_avg_loss)))\n","\n"," res = \"Rouge1:\\n\\tp:%.6f, r:%.6f, f:%.6f\\n\" % (\n"," scores_all['rouge-1']['p'], scores_all['rouge-1']['r'], scores_all['rouge-1']['f']) \\\n"," + \"Rouge2:\\n\\tp:%.6f, r:%.6f, f:%.6f\\n\" % (\n"," scores_all['rouge-2']['p'], scores_all['rouge-2']['r'], scores_all['rouge-2']['f']) \\\n"," + \"RougeL:\\n\\tp:%.6f, r:%.6f, f:%.6f\\n\" % (\n"," scores_all['rouge-l']['p'], scores_all['rouge-l']['r'], scores_all['rouge-l']['f'])\n"," logger.info(res)\n","\n"," tester.getMetric()\n"," F = tester.labelMetric\n","\n"," if best_loss is None or running_avg_loss < best_loss:\n"," bestmodel_save_path = os.path.join(eval_dir, 'bestmodel_%d' % (saveNo % 3)) # this is where checkpoints of best models are saved\n"," if best_loss is not None:\n"," logger.info(\n"," '[INFO] Found new best model with %.6f running_avg_loss. The original loss is %.6f, Saving to %s',\n"," float(running_avg_loss), float(best_loss), bestmodel_save_path)\n"," else:\n"," logger.info(\n"," '[INFO] Found new best model with %.6f running_avg_loss. The original loss is None, Saving to %s',\n"," float(running_avg_loss), bestmodel_save_path)\n"," with open(bestmodel_save_path, 'wb') as f:\n"," torch.save(model.state_dict(), f)\n"," best_loss = running_avg_loss\n"," non_descent_cnt = 0\n"," saveNo += 1\n"," else:\n"," non_descent_cnt += 1\n",
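" # Track the best label-F1 checkpoint ('bestFmodel') separately from the best-loss checkpoint above\n"," if best_F is None or best_F < F:\n"," bestmodel_save_path = os.path.join(eval_dir, 'bestFmodel') # this is where checkpoints of best models are saved\n"," if best_F is not None:\n"," logger.info('[INFO] Found new best model with %.6f F. The original F is %.6f, Saving to %s', float(F),\n"," float(best_F), bestmodel_save_path)\n"," else:\n"," logger.info('[INFO] Found new best model with %.6f F. 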
The original F is None, Saving to %s', float(F),\n"," bestmodel_save_path)\n"," with open(bestmodel_save_path, 'wb') as f:\n"," torch.save(model.state_dict(), f)\n"," best_F = F\n","\n"," return best_loss, best_F, non_descent_cnt, saveNo"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"sWJxRu2xkvfv"},"outputs":[],"source":["parser = argparse.ArgumentParser(description='HeterSumGraph Model')\n","\n","# Where to find data\n","parser.add_argument('--data_dir', type=str, default='/content/gdrive/MyDrive/MultiNews', help='The dataset directory.')\n","parser.add_argument('--cache_dir', type=str, default='/content/gdrive/MyDrive/MultiNews', help='The processed dataset directory')\n","parser.add_argument('--embedding_path', type=str, default='/content/gdrive/MyDrive/glove/glove.42B.300d.txt', help='Path expression to external word embedding.')\n","\n","# Important settings\n","parser.add_argument('--model', type=str, default='HSG', help='model structure [HSG|HDSG]')\n","parser.add_argument('--restore_model', type=str, default='None', help='Restore model for further training. [bestmodel/bestFmodel/earlystop/None]')\n","\n","# Where to save output\n","parser.add_argument('--save_root', type=str, default='save/', help='Root directory for all models.')\n","parser.add_argument('--log_root', type=str, default='log/', help='Root directory for all logging.')\n","\n","# Hyperparameters\n","parser.add_argument('--seed', type=int, default=666, help='set the random seed [default: 666]')\n","parser.add_argument('--gpu', type=str, default='0', help='GPU ID to use. 
[default: 0]')\n","parser.add_argument('--cuda', action='store_true', default=True, help='GPU or CPU [default: True]')\n","parser.add_argument('--vocab_size', type=int, default=50000, help='Size of vocabulary. [default: 50000]')\n","parser.add_argument('--n_epochs', type=int, default=20, help='Number of epochs [default: 20]')\n","parser.add_argument('--batch_size', type=int, default=32, help='Mini batch size [default: 32]')\n","parser.add_argument('--n_iter', type=int, default=1, help='iteration hop [default: 1]')\n","\n","parser.add_argument('--word_embedding', action='store_true', default=True, help='whether to use Word embedding [default: True]')\n","parser.add_argument('--word_emb_dim', type=int, default=300, help='Word embedding size [default: 300]')\n","parser.add_argument('--embed_train', action='store_true', default=False, help='whether to train Word embedding [default: False]')\n","parser.add_argument('--feat_embed_size', type=int, default=50, help='feature embedding size [default: 50]')\n","parser.add_argument('--n_layers', type=int, default=1, help='Number of GAT layers [default: 1]')\n","parser.add_argument('--lstm_hidden_state', type=int, default=128, help='size of lstm hidden state [default: 128]')\n","parser.add_argument('--lstm_layers', type=int, default=2, help='Number of lstm layers [default: 2]')\n","parser.add_argument('--bidirectional', action='store_true', default=True, help='whether to use bidirectional LSTM [default: True]')\n","parser.add_argument('--n_feature_size', type=int, default=128, help='size of node feature [default: 128]')\n","parser.add_argument('--hidden_size', type=int, default=64, help='hidden size [default: 64]')\n","parser.add_argument('--ffn_inner_hidden_size', type=int, default=512, help='PositionwiseFeedForward inner hidden size [default: 512]')\n","parser.add_argument('--n_head', type=int, default=8, help='multihead attention number [default: 8]')\n","parser.add_argument('--recurrent_dropout_prob', type=float, default=0.1, help='recurrent dropout prob [default: 0.1]')\n","parser.add_argument('--atten_dropout_prob', type=float, default=0.1, help='attention dropout prob [default: 0.1]')\n","parser.add_argument('--ffn_dropout_prob', type=float, default=0.1, help='PositionwiseFeedForward dropout prob [default: 0.1]')\n","parser.add_argument('--use_orthnormal_init', action='store_true', default=True, help='use orthonormal init for LSTM [default: True]')\n","parser.add_argument('--sent_max_len', type=int, default=100, help='max length of sentences (max source text sentence tokens)')\n","parser.add_argument('--doc_max_timesteps', type=int, default=50, help='max length of documents (max timesteps of documents)')\n","\n","# Training\n","parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')\n","parser.add_argument('--lr_descent', action='store_true', default=False, help='learning rate descent')\n","parser.add_argument('--grad_clip', action='store_true', default=False, help='for gradient clipping')\n","parser.add_argument('--max_grad_norm', type=float, default=1.0, help='for gradient clipping max gradient normalization')\n","\n","parser.add_argument('-m', type=int, default=3, help='decode summary length')\n","\n","# Parse an empty list so argparse ignores the notebook kernel's own argv\n","args = parser.parse_args(args=[])\n","\n","# set the seed\n","random.seed(args.seed)\n","np.random.seed(args.seed)\n","torch.manual_seed(args.seed)\n","\n","os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu\n","torch.set_printoptions(threshold=50000)\n","\n",
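"# The files below are the artifacts presumably produced by the repo's PrepareDataset.sh:\n","# labelled train/val jsonl in data_dir, plus vocab, filter_word.txt and TF-IDF edge files in cache_dir.\n","# File paths\n","DATA_FILE = os.path.join(args.data_dir, 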
\"train.label.jsonl\")\n","VALID_FILE = os.path.join(args.data_dir, \"val.label.jsonl\")\n","VOCAL_FILE = os.path.join(args.cache_dir, \"vocab\")\n","FILTER_WORD = os.path.join(args.cache_dir, \"filter_word.txt\")\n","LOG_PATH = args.log_root\n","\n","\n","# train_log setting\n","if not os.path.exists(LOG_PATH):\n"," os.makedirs(LOG_PATH)\n","nowTime = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')\n","log_path = os.path.join(LOG_PATH, \"train_\" + nowTime)\n","file_handler = logging.FileHandler(log_path)\n","file_handler.setFormatter(formatter)\n","logger.addHandler(file_handler)\n","\n","logger.info(\"Pytorch %s\", torch.__version__)\n","logger.info(\"[INFO] Create Vocab, vocab path is %s\", VOCAL_FILE)\n","vocab = Vocab(VOCAL_FILE, args.vocab_size)\n","embed = torch.nn.Embedding(vocab.size(), args.word_emb_dim, padding_idx=0)\n","if args.word_embedding:\n"," embed_loader = Word_Embedding(args.embedding_path, vocab)\n"," vectors = embed_loader.load_my_vecs(args.word_emb_dim)\n"," pretrained_weight = embed_loader.add_unknown_words_by_avg(vectors, args.word_emb_dim)\n"," embed.weight.data.copy_(torch.Tensor(pretrained_weight))\n"," embed.weight.requires_grad = args.embed_train\n","\n","hps = args\n","logger.info(hps)\n","\n","train_w2s_path = os.path.join(args.cache_dir, \"train.w2s.tfidf.jsonl\")\n","val_w2s_path = os.path.join(args.cache_dir, \"val.w2s.tfidf.jsonl\")\n","\n","if hps.model == \"HSG\":\n"," model = HSumGraph(hps, embed)\n"," logger.info(\"[MODEL] HeterSumGraph \")\n"," dataset = ExampleSet(DATA_FILE, vocab, hps.doc_max_timesteps, hps.sent_max_len, FILTER_WORD, train_w2s_path)\n"," train_loader = torch.utils.data.DataLoader(dataset, batch_size=hps.batch_size, shuffle=True, num_workers=32,collate_fn=graph_collate_fn)\n"," del dataset\n"," valid_dataset = ExampleSet(VALID_FILE, vocab, hps.doc_max_timesteps, hps.sent_max_len, FILTER_WORD, val_w2s_path)\n"," valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=hps.batch_size, shuffle=False, collate_fn=graph_collate_fn, num_workers=32)\n","elif hps.model == \"HDSG\":\n"," model = HSumDocGraph(hps, embed)\n"," logger.info(\"[MODEL] HeterDocSumGraph \")\n"," train_w2d_path = os.path.join(args.cache_dir, \"train.w2d.tfidf.jsonl\")\n"," dataset = MultiExampleSet(DATA_FILE, vocab, hps.doc_max_timesteps, hps.sent_max_len, FILTER_WORD, train_w2s_path, train_w2d_path)\n"," train_loader = torch.utils.data.DataLoader(dataset, batch_size=hps.batch_size, shuffle=True, num_workers=32,collate_fn=graph_collate_fn)\n"," del dataset\n"," val_w2d_path = os.path.join(args.cache_dir, \"val.w2d.tfidf.jsonl\")\n"," valid_dataset = MultiExampleSet(VALID_FILE, vocab, hps.doc_max_timesteps, hps.sent_max_len, FILTER_WORD, val_w2s_path, val_w2d_path)\n"," valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=hps.batch_size, shuffle=False,collate_fn=graph_collate_fn, num_workers=32) # Shuffle Must be False for ROUGE evaluation\n","else:\n"," logger.error(\"[ERROR] Invalid Model Type!\")\n"," raise NotImplementedError(\"Model Type has not been implemented\")\n","\n","\n","if args.cuda:\n"," model.to(torch.device(\"cuda:0\"))\n"," print(\"使用显卡\")\n"," logger.info(\"[INFO] Use cuda\")\n","\n","setup_training(model, train_loader, valid_loader, valid_dataset, hps)"]},{"cell_type":"code","source":["import torch\n","print(torch.version.cuda)\n","import platform\n","print(platform.python_version())"],"metadata":{"id":"cJ8SXwCpCGRZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["!ln -sf /opt/bin/nvidia-smi 
/usr/bin/nvidia-smi\n","!pip install gputil\n","!pip install psutil\n","!pip install humanize"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"E8O35jKAl_IN","executionInfo":{"status":"ok","timestamp":1645452210522,"user_tz":-480,"elapsed":13480,"user":{"displayName":"su haha","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"03290838865279118816"}},"outputId":"d1d78340-e68b-41dd-ee3c-c794c10327f0"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting gputil\n"," Downloading GPUtil-1.4.0.tar.gz (5.5 kB)\n","Building wheels for collected packages: gputil\n"," Building wheel for gputil (setup.py) ... \u001b[?25l\u001b[?25hdone\n"," Created wheel for gputil: filename=GPUtil-1.4.0-py3-none-any.whl size=7411 sha256=15f15081af18a0235440162b157cf80c65a52e3fd9516baffbbdb4a0ed2a598c\n"," Stored in directory: /root/.cache/pip/wheels/6e/f8/83/534c52482d6da64622ddbf72cd93c35d2ef2881b78fd08ff0c\n","Successfully built gputil\n","Installing collected packages: gputil\n","Successfully installed gputil-1.4.0\n","Requirement already satisfied: psutil in /usr/local/lib/python3.7/dist-packages (5.4.8)\n","Requirement already satisfied: humanize in /usr/local/lib/python3.7/dist-packages (0.5.1)\n"]}]},{"cell_type":"code","source":["import psutil\n","import humanize\n","import os\n","import GPUtil as GPU\n","\n","GPUs = GPU.getGPUs()\n","# Colab exposes at most one GPU, and even that is not guaranteed to be available\n","gpu = GPUs[0]\n","def printm():\n"," process = psutil.Process(os.getpid())\n"," print(\"Gen RAM Free: \" + humanize.naturalsize(psutil.virtual_memory().available), \" | Proc size: \" + humanize.naturalsize(process.memory_info().rss))\n"," print(\"GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB\".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))\n","printm()\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"zb9-pMaBmCh8","executionInfo":{"status":"ok","timestamp":1645452226154,"user_tz":-480,"elapsed":324,"user":{"displayName":"su haha","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"03290838865279118816"}},"outputId":"11984d36-7987-466d-f160-6cc487b7ad3f"},"execution_count":4,"outputs":[{"output_type":"stream","name":"stdout","text":["Gen RAM Free: 720.7 MB | Proc size: 141.5 MB\n","GPU RAM Free: 9368MB | Used: 2073MB | Util 18% | Total 11441MB\n"]}]}],"metadata":{"colab":{"collapsed_sections":[],"name":"PositionEmbedding.py","provenance":[],"mount_file_id":"1S5v5sC_x4J5YAO0Q8jABNUnEaOZjtjpe","authorship_tag":"ABX9TyNBCEHZ6/L1UJGvlccw6IuZ"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}