############################################################################################## TRAIN T5-CQR ############################################################################################## export PROJECT=${PROJECT} export ZONE=${ZONE} export BUCKET="gs://FOLDER/" export TPU_NAME=${TPU_NAME} export TPU_SIZE=v3-8 export DATA_DIR="gs://FOLDER/data" export MODEL_DIR="gs://FOLDER/models" t5_mesh_transformer \ --tpu="${TPU_NAME}" \ --gcp_project="${PROJECT}" \ --tpu_zone="${ZONE}" \ --model_dir="gs://FOLDER/model_large/" \ --gin_param="init_checkpoint = 'gs://t5-data/pretrained_models/large/model.ckpt-1000700'" \ --gin_file="dataset.gin" \ --gin_file="models/bi_v1.gin" \ --gin_file="gs://t5-data/pretrained_models/large/operative_config.gin" \ --gin_param="utils.tpu_mesh_shape.model_parallelism = 1" \ --gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'" \ --gin_param="utils.run.train_dataset_fn = @t5.models.mesh_transformer.tsv_dataset_fn" \ --gin_param="tsv_dataset_fn.filename = 'gs://FOLDER/data/history_query_pairs.train.tsv'" \ --gin_file="learning_rate_schedules/constant_0_001.gin" \ --gin_param="run.train_steps = 1001000" \ ############################################################################################## RUN T5-CQR ############################################################################################## t5_mesh_transformer \ --tpu="${TPU_NAME}" \ --gcp_project="${PROJECT}" \ --tpu_zone="${ZONE}" \ --model_dir="gs://FOLDER/model_large/" \ --gin_file="gs://t5-data/pretrained_models/large/operative_config.gin" \ --gin_file="infer.gin" \ --gin_file="beam_search.gin" \ --gin_file="sample_decode.gin" \ --gin_param="utils.tpu_mesh_shape.tpu_topology = '${TPU_SIZE}'" \ --gin_param="infer_checkpoint_step = 1001000" \ --gin_param="utils.run.sequence_length = {'inputs': 512, 'targets': 64}" \ --gin_param="Bitransformer.decode.max_decode_length = 64" \ --gin_param="input_filename = 'gs://FOLDER/data/CAST_2021_raw_again_T5_in.txt'" \ --gin_param="output_filename = 'gs://FOLDER/data/pred_CAST_2021_raw_again_T5_out.txt'" \ --gin_param="Bitransformer.decode.temperature = 0.0" \ --gin_param="tokens_per_batch = 131072" \ --gin_param="Unitransformer.sample_autoregressive.sampling_keep_top_k = -1" ##############################################################################################