# Biaffine Model Main Class
# Yu Zhang (adapted by Claudia Greenberg)
# UCT Honours Project: 
#   Meaning Representation Parsing: 
#   The Edge Prediction Component of a Semantic Graph Parser
#   by Claudia Greenberg, GRNCLA009
# September 2022
# This is the entry-point script invoked from the command line.
# -*- coding: utf-8 -*-

import argparse

from supar import BiaffineSemanticDependencyParser
from supar.cmds.cmd import init


def main():
    """Build the command-line interface and hand it to supar's runner.

    Defines three sub-commands on an ``argparse`` parser:

    * ``train``    -- train a parser (features, encoder, data paths, embeddings).
    * ``evaluate`` -- evaluate a trained parser on a dataset.
    * ``predict``  -- write predictions for a dataset.

    The chosen sub-command name is stored in ``args.mode``, and ``args.Parser``
    defaults to ``BiaffineSemanticDependencyParser``. The configured parser is
    then passed to ``supar.cmds.cmd.init``, which presumably parses
    ``sys.argv`` and runs the selected mode (confirm against supar).
    """
    # The class is left as "Biaffine Semantic Dependency Parser" to avoid further complications in the code.
    parser = argparse.ArgumentParser(description='Create Biaffine Semantic Dependency Parser.')
    parser.set_defaults(Parser=BiaffineSemanticDependencyParser)
    subparsers = parser.add_subparsers(title='Commands', dest='mode')

    # Single source of truth for the default test file. `evaluate` and `predict`
    # previously defaulted to upstream supar's `data/sdp/DM/test.conllu`, which
    # disagreed with the `data/` layout used by `train` below.
    test_path = 'data/test.conllu'

    # train
    subparser = subparsers.add_parser('train', help='Train a parser.')
    # An empty list (rather than argparse's implicit None) when -f is omitted;
    # presumably so downstream membership checks on args.feat do not fail -- confirm in supar.
    subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'lemma', 'elmo', 'bert'], default=[], nargs='*', help='features to use')
    subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
    subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='lstm', help='encoder to use')
    subparser.add_argument('--max-len', type=int, help='max length of the sentences')
    subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
    subparser.add_argument('--train', default='data/train.conllu', help='path to train file')
    subparser.add_argument('--dev', default='data/dev.conllu', help='path to dev file')
    subparser.add_argument('--test', default=test_path, help='path to test file')
    subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
    subparser.add_argument('--n-embed-proj', default=125, type=int, help='dimension of projected embeddings')
    subparser.add_argument('--bert', default='SpanBERT/spanbert-base-cased', help='which BERT model to use')

    # evaluate
    subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
    subparser.add_argument('--data', default=test_path, help='path to dataset')

    # predict
    subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
    subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
    subparser.add_argument('--data', default=test_path, help='path to dataset')
    subparser.add_argument('--pred', default='pred.conllu', help='path to predicted result')
    subparser.add_argument('--prob', action='store_true', help='whether to output probs')

    init(parser)

if __name__ == '__main__':
    # Run the command-line interface only when executed as a script, not on import.
    main()
