/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.component.srl;

import com.carrotsearch.hppc.IntOpenHashSet;
import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.classification.vector.StringFeatureVector;
import com.googlecode.clearnlp.component.AbstractStatisticalComponent;
import com.googlecode.clearnlp.dependency.DEPArc;
import com.googlecode.clearnlp.dependency.DEPNode;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.dependency.srl.SRLLib;
import com.googlecode.clearnlp.feature.xml.FtrToken;
import com.googlecode.clearnlp.feature.xml.JointFtrXml;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTOutput;
import com.googlecode.clearnlp.util.map.Prob1DMap;
import com.googlecode.clearnlp.util.pair.StringIntPair;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

public class CSRLabeler
extends AbstractStatisticalComponent {
    /** Zip entry name of the default configuration. */
    private final String ENTRY_CONFIGURATION = "srl_CONFIGURATION";
    /** Zip entry name prefix of the feature templates (suffixed by an index). */
    private final String ENTRY_FEATURE = "srl_FEATURE";
    /** Zip entry name of the path lexica. */
    private final String ENTRY_LEXICA = "srl_LEXICA";
    /** Zip entry name prefix of the statistical models (suffixed by an index). */
    private final String ENTRY_MODEL = "srl_MODEL";
    /** Index of the up-path set in the array returned by {@link #getLexica()}. */
    protected final int LEXICA_PATH_UP = 0;
    /** Index of the down-path set in the array returned by {@link #getLexica()}. */
    protected final int LEXICA_PATH_DOWN = 1;
    protected final int PATH_ALL = 0;
    protected final int PATH_UP = 1;
    protected final int PATH_DOWN = 2;
    protected final int SUBCAT_ALL = 0;
    protected final int SUBCAT_LEFT = 1;
    protected final int SUBCAT_RIGHT = 2;
    /** Label meaning "this node is not an argument of the current predicate". */
    protected final String LB_NO_ARG = "N";
    /** Current lowest-common-ancestor candidate while walking up from the predicate. */
    protected DEPNode d_lca;
    /** Node ids already visited (labeled or excluded) for the current predicate. */
    protected IntOpenHashSet s_skip;
    /** Numbered-argument labels assigned so far to the current predicate, in order. */
    protected List<String> l_argns;
    /** Gold semantic heads per token id (label + predicate id), captured in init(). */
    protected StringIntPair[][] g_heads;
    /** Leftmost dependent of each node, when it lies to the node's left. */
    protected DEPNode[] lm_deps;
    /** Rightmost dependent of each node, when it lies to the node's right. */
    protected DEPNode[] rm_deps;
    /** Left-adjacent sibling of each node in its head's dependent list. */
    protected DEPNode[] ln_sibs;
    /** Right-adjacent sibling of each node in its head's dependent list. */
    protected DEPNode[] rn_sibs;
    /** Id of the predicate currently being labeled. */
    protected int i_pred;
    /** Id of the argument candidate currently being classified. */
    protected int i_arg;
    /** Counts of downward paths (collected when i_flag == 0). */
    protected Prob1DMap m_down;
    /** Counts of upward paths (collected when i_flag == 0). */
    protected Prob1DMap m_up;
    /** Downward paths that allow labelDown() to recurse below an argument. */
    protected Set<String> s_down;
    /** Upward path set; only loaded/saved/exposed in the visible code. */
    protected Set<String> s_up;

    /**
     * Creates a labeler that collects path lexica; the path counters are only
     * allocated by this constructor.
     */
    public CSRLabeler(JointFtrXml[] xmls) {
        super(xmls);
        this.m_down = new Prob1DMap();
        this.m_up = new Prob1DMap();
    }

    public CSRLabeler(JointFtrXml[] xmls, StringTrainSpace[] spaces, Object[] lexica) {
        super(xmls, spaces, lexica);
    }

    public CSRLabeler(JointFtrXml[] xmls, StringModel[] models, Object[] lexica) {
        super(xmls, models, lexica);
    }

    public CSRLabeler(ZipInputStream in) {
        super(in);
    }

    public CSRLabeler(JointFtrXml[] xmls, StringTrainSpace[] spaces, StringModel[] models, Object[] lexica) {
        super(xmls, spaces, models, lexica);
    }

    /**
     * Installs the path lexica in the layout produced by {@link #getLexica()}.
     * <p>
     * The indices must agree with {@link #getLexica()}: the down-path set is at
     * {@code LEXICA_PATH_DOWN} and the up-path set is at {@code LEXICA_PATH_UP}.
     * The previous implementation read the down set from index 0 and the up set
     * from index 1, the opposite of what getLexica() writes, so the two sets were
     * swapped on every round trip.
     * <p>
     * NOTE(review): presumably invoked from the superclass constructor, since the
     * lexica are handed to {@code super(...)} and nothing in this class calls it.
     * The simple names below are compile-time constant variables, so the compiler
     * inlines them and they are safe to use before this class's initializers run.
     * Do not rewrite them as {@code this.LEXICA_PATH_*}.
     *
     * @param lexica array whose elements are {@code Set<String>} path sets
     */
    @Override
    @SuppressWarnings("unchecked")
    protected void initLexia(Object[] lexica) {
        this.s_down = (Set<String>)lexica[LEXICA_PATH_DOWN];
        this.s_up = (Set<String>)lexica[LEXICA_PATH_UP];
    }

    /**
     * Loads configuration, feature templates, statistical models and lexica from
     * the entries of the given zip stream. Unknown entries are ignored. Errors
     * are printed and otherwise swallowed, as in the original implementation.
     *
     * @param zin zip stream positioned before its first entry
     */
    @Override
    public void loadModels(ZipInputStream zin) {
        int fLen = ENTRY_FEATURE.length();
        int mLen = ENTRY_MODEL.length();
        this.f_xmls = new JointFtrXml[1];
        this.s_models = null;
        try {
            ZipEntry zEntry;
            while ((zEntry = zin.getNextEntry()) != null) {
                String entry = zEntry.getName();
                if (entry.equals(ENTRY_CONFIGURATION)) {
                    this.loadDefaultConfiguration(zin);
                } else if (entry.startsWith(ENTRY_FEATURE)) {
                    // the suffix after the prefix is the template index
                    this.loadFeatureTemplates(zin, Integer.parseInt(entry.substring(fLen)));
                } else if (entry.startsWith(ENTRY_MODEL)) {
                    // the suffix after the prefix is the model index
                    this.loadStatisticalModels(zin, Integer.parseInt(entry.substring(mLen)));
                } else if (entry.equals(ENTRY_LEXICA)) {
                    this.loadLexica(zin);
                }
            }
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads the down-path set followed by the up-path set, in the same order
     * {@link #saveLexica(ZipOutputStream)} writes them. The reader is deliberately
     * not closed because that would close the enclosing zip stream.
     */
    private void loadLexica(ZipInputStream zin) throws Exception {
        BufferedReader fin = new BufferedReader(new InputStreamReader(zin));
        System.out.println("Loading lexica.");
        this.s_down = UTInput.getStringSet(fin);
        this.s_up = UTInput.getStringSet(fin);
    }

    /**
     * Writes configuration, feature templates, lexica and models to the zip
     * stream and closes it. Errors are printed and otherwise swallowed.
     */
    @Override
    public void saveModels(ZipOutputStream zout) {
        try {
            this.saveDefaultConfiguration(zout, ENTRY_CONFIGURATION);
            this.saveFeatureTemplates(zout, ENTRY_FEATURE);
            this.saveLexica(zout);
            this.saveStatisticalModels(zout, ENTRY_MODEL);
            zout.close();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Writes the down-path set then the up-path set as a single zip entry. */
    private void saveLexica(ZipOutputStream zout) throws Exception {
        zout.putNextEntry(new ZipEntry(ENTRY_LEXICA));
        PrintStream fout = UTOutput.createPrintBufferedStream(zout);
        System.out.println("Saving lexica.");
        UTOutput.printSet(fout, this.s_down);
        fout.flush();
        UTOutput.printSet(fout, this.s_up);
        fout.flush();
        zout.closeEntry();
    }

    /**
     * Returns the path lexica indexed by {@code LEXICA_PATH_DOWN} and
     * {@code LEXICA_PATH_UP}. When {@code i_flag == 0} the sets are built from the
     * collected counts using the cutoffs of the first feature template. Otherwise
     * the currently installed sets are returned.
     */
    @Override
    public Object[] getLexica() {
        Object[] lexica = new Object[2];
        if (this.i_flag == 0) {
            lexica[LEXICA_PATH_DOWN] = this.m_down.toSet(this.f_xmls[0].getPathDownCutoff());
            lexica[LEXICA_PATH_UP] = this.m_up.toSet(this.f_xmls[0].getPathUpCutoff());
        } else {
            lexica[LEXICA_PATH_DOWN] = this.s_down;
            lexica[LEXICA_PATH_UP] = this.s_up;
        }
        return lexica;
    }

    /**
     * Accumulates labeled-arc counts for the current tree.
     *
     * @param counts in/out: [0] += correct, [1] += predicted total, [2] += gold total
     */
    @Override
    public void countAccuracy(int[] counts) {
        int pTotal = 0;
        int rTotal = 0;
        int correct = 0;
        for (int i = 1; i < this.t_size; ++i) {
            List<DEPArc> sHeads = this.d_tree.get(i).getSHeads();
            StringIntPair[] gHeads = this.g_heads[i];
            pTotal += sHeads.size();
            rTotal += gHeads.length;
            for (StringIntPair p : gHeads) {
                for (DEPArc arc : sHeads) {
                    if (arc.getNode().id == p.i && arc.isLabel(p.s)) {
                        ++correct;
                    }
                }
            }
        }
        counts[0] = counts[0] + correct;
        counts[1] = counts[1] + pTotal;
        counts[2] = counts[2] + rTotal;
    }

    /**
     * Prepares per-tree state. Unless {@code i_flag == 2}, the gold semantic heads
     * are captured and then cleared from the tree. With {@code i_flag == 2} the
     * semantic heads are freshly initialized instead.
     */
    protected void init(DEPTree tree) {
        this.d_tree = tree;
        this.t_size = tree.size();
        this.i_pred = this.getNextPredId(0);
        this.s_skip = new IntOpenHashSet();
        this.l_argns = new ArrayList<String>();
        if (this.i_flag != 2) {
            this.g_heads = tree.getSHeads();
            tree.clearSHeads();
        } else {
            tree.initSHeads();
        }
        this.initArcs();
    }

    /** Returns the id of the next predicate after {@code prevId}, or the tree size if none. */
    private int getNextPredId(int prevId) {
        DEPNode pred = this.d_tree.getNextPredicate(prevId);
        return pred != null ? pred.id : this.d_tree.size();
    }

    /**
     * Builds the leftmost/rightmost-dependent and adjacent-sibling tables from
     * each node's dependent list (assumed ordered by id, as the comparisons imply).
     */
    private void initArcs() {
        this.lm_deps = new DEPNode[this.t_size];
        this.rm_deps = new DEPNode[this.t_size];
        this.ln_sibs = new DEPNode[this.t_size];
        this.rn_sibs = new DEPNode[this.t_size];
        this.d_tree.setDependents();
        for (int i = 1; i < this.t_size; ++i) {
            List<DEPArc> deps = this.d_tree.get(i).getDependents();
            if (deps.isEmpty()) {
                continue;
            }
            int len = deps.size();
            DEPArc lmd = deps.get(0);
            DEPArc rmd = deps.get(len - 1);
            if (lmd.getNode().id < i) {
                this.lm_deps[i] = lmd.getNode();
            }
            if (rmd.getNode().id > i) {
                this.rm_deps[i] = rmd.getNode();
            }
            for (int j = 1; j < len; ++j) {
                DEPNode curr = deps.get(j).getNode();
                DEPNode prev = deps.get(j - 1).getNode();
                if (this.ln_sibs[curr.id] == null || this.ln_sibs[curr.id].id < prev.id) {
                    this.ln_sibs[curr.id] = prev;
                }
            }
            for (int j = 0; j < len - 1; ++j) {
                DEPNode curr = deps.get(j).getNode();
                DEPNode next = deps.get(j + 1).getNode();
                if (this.rn_sibs[curr.id] == null || this.rn_sibs[curr.id].id > next.id) {
                    this.rn_sibs[curr.id] = next;
                }
            }
        }
    }

    /** Collects down/up path counts for every predicate of the tree (i_flag == 0). */
    private void addLexica(DEPTree tree) {
        DEPNode pred = tree.getNextPredicate(0);
        tree.setDependents();
        while (pred != null) {
            for (DEPArc arc : pred.getGrandDependents()) {
                this.collectDown(pred, arc.getNode());
            }
            DEPNode head = pred.getHead();
            if (head != null) {
                this.collectUp(pred, head.getHead());
            }
            pred = tree.getNextPredicate(pred.id);
        }
    }

    /**
     * Recursively visits {@code arg}'s subtree. For each argument of {@code pred},
     * counts the paths from pred down to each ancestor of the argument.
     */
    private void collectDown(DEPNode pred, DEPNode arg) {
        if (arg.isArgumentOf(pred)) {
            for (String path : this.getDUPathList(pred, arg.getHead())) {
                this.m_down.add(path);
            }
        }
        for (DEPArc arc : arg.getDependents()) {
            this.collectDown(pred, arc.getNode());
        }
    }

    /**
     * Walks up from {@code head}. For every dependent of a visited node that is
     * an argument of {@code pred}, counts the paths between that node and pred.
     * NOTE(review): the paths are added once per matching dependent; confirm that
     * repeated counting is intended.
     */
    private void collectUp(DEPNode pred, DEPNode head) {
        if (head == null) {
            return;
        }
        for (DEPArc arc : head.getDependents()) {
            if (!arc.getNode().isArgumentOf(pred)) {
                continue;
            }
            for (String path : this.getDUPathList(head, pred)) {
                this.m_up.add(path);
            }
        }
        this.collectUp(pred, head.getHead());
    }

    /** Dependency-label path from {@code bottom} up to (excluding) {@code top}. */
    private String getDUPath(DEPNode top, DEPNode bottom) {
        return this.getPathAux(top, bottom, "d", "|", true);
    }

    /** Paths from {@code top} to {@code bottom} and to each ancestor of bottom below top. */
    private List<String> getDUPathList(DEPNode top, DEPNode bottom) {
        ArrayList<String> paths = new ArrayList<String>();
        while (bottom != top) {
            paths.add(this.getDUPath(top, bottom));
            bottom = bottom.getHead();
        }
        return paths;
    }

    /** Returns the collected down paths above the cutoff (only valid after collection). */
    public Set<String> getDownSet(int cutoff) {
        return this.m_down.toSet(cutoff);
    }

    /** Returns the collected up paths above the cutoff (only valid after collection). */
    public Set<String> getUpSet(int cutoff) {
        return this.m_up.toSet(cutoff);
    }

    /** Collects lexica when {@code i_flag == 0}; otherwise labels the tree's arguments. */
    @Override
    public void process(DEPTree tree) {
        if (this.i_flag == 0) {
            this.addLexica(tree);
        } else {
            this.init(tree);
            this.label();
        }
    }

    /**
     * For each predicate, walks from the predicate up through its ancestors. At
     * each level it classifies the current ancestor and its dependents as
     * argument candidates.
     */
    private void label() {
        while (this.i_pred < this.t_size) {
            DEPNode pred = this.d_tree.get(this.i_pred);
            this.s_skip.clear();
            this.s_skip.add(this.i_pred);
            this.s_skip.add(0);
            this.l_argns.clear();
            this.d_lca = pred;
            do {
                this.labelAux(pred, this.d_lca);
                this.d_lca = this.d_lca.getHead();
            } while (this.d_lca != null);
            this.i_pred = this.getNextPredId(this.i_pred);
        }
    }

    /** Classifies {@code head} (unless already visited) and then its dependents. */
    private void labelAux(DEPNode pred, DEPNode head) {
        if (!this.s_skip.contains(head.id)) {
            this.i_arg = head.id;
            this.addArgument(this.getLabel(this.getDirIndex()));
        }
        this.labelDown(pred, head.getDependents());
    }

    /**
     * Classifies each unvisited dependent. It recurses below a dependent only
     * while the predicate is its own LCA and the down path is in {@code s_down}.
     */
    private void labelDown(DEPNode pred, List<DEPArc> arcs) {
        for (DEPArc arc : arcs) {
            DEPNode arg = arc.getNode();
            if (this.s_skip.contains(arg.id)) {
                continue;
            }
            this.i_arg = arg.id;
            this.addArgument(this.getLabel(this.getDirIndex()));
            if (this.i_pred == this.d_lca.id && this.s_down.contains(this.getDUPath(pred, arg))) {
                this.labelDown(pred, arg.getDependents());
            }
        }
    }

    /** 0 when the argument precedes the predicate, 1 otherwise; selects the model/space. */
    private int getDirIndex() {
        return this.i_arg < this.i_pred ? 0 : 1;
    }

    /**
     * Produces a label for the current (predicate, argument) pair according to
     * {@code i_flag}:
     * <ul>
     * <li>1: gold label, with a training instance added;</li>
     * <li>2 or 4: predicted label;</li>
     * <li>3: predicted label, with a gold training instance added.</li>
     * </ul>
     * Any other flag yields {@code null}.
     */
    private String getLabel(int idx) {
        StringFeatureVector vector = this.getFeatureVector(this.f_xmls[0]);
        String label = null;
        if (this.i_flag == 1) {
            label = this.getGoldLabel();
            this.s_spaces[idx].addInstance(label, vector);
        } else if (this.i_flag == 2 || this.i_flag == 4) {
            label = this.getAutoLabel(idx, vector);
        } else if (this.i_flag == 3) {
            label = this.getAutoLabel(idx, vector);
            this.s_spaces[idx].addInstance(this.getGoldLabel(), vector);
        }
        return label;
    }

    /** Gold label of the current argument for the current predicate, or {@code LB_NO_ARG}. */
    private String getGoldLabel() {
        for (StringIntPair head : this.g_heads[this.i_arg]) {
            if (head.i == this.i_pred) {
                return head.s;
            }
        }
        return LB_NO_ARG;
    }

    /** Best label predicted by the model for the given direction. */
    private String getAutoLabel(int idx, StringFeatureVector vector) {
        return this.s_models[idx].predictBest(vector).label;
    }

    /**
     * Marks the current argument as visited. Unless the label is
     * {@code LB_NO_ARG}, it also attaches a semantic head and records numbered
     * arguments.
     */
    private void addArgument(String label) {
        this.s_skip.add(this.i_arg);
        if (!label.equals(LB_NO_ARG)) {
            DEPNode pred = this.d_tree.get(this.i_pred);
            DEPNode arg = this.d_tree.get(this.i_arg);
            arg.addSHead(pred, label);
            if (SRLLib.isNumberedArgument(label)) {
                this.l_argns.add(label);
            }
        }
    }

    /** Extracts a single-valued feature for the given token, or {@code null} if absent. */
    @Override
    protected String getField(FtrToken token) {
        DEPNode node = this.getNode(token);
        if (node == null) {
            return null;
        }
        if (token.isField("f")) {
            return node.form;
        }
        if (token.isField("m")) {
            return node.lemma;
        }
        if (token.isField("p")) {
            return node.pos;
        }
        if (token.isField("d")) {
            return node.getLabel();
        }
        if (token.isField("n")) {
            return this.getDistance(node);
        }
        Matcher m = JointFtrXml.P_ARGN.matcher(token.field);
        if (m.find()) {
            // group(1) counts back from the most recently added numbered argument
            int idx = this.l_argns.size() - Integer.parseInt(m.group(1)) - 1;
            return idx >= 0 ? this.l_argns.get(idx) : null;
        }
        m = JointFtrXml.P_PATH.matcher(token.field);
        if (m.find()) {
            String type = m.group(1);
            int dir = Integer.parseInt(m.group(2));
            return this.getPath(type, dir);
        }
        m = JointFtrXml.P_SUBCAT.matcher(token.field);
        if (m.find()) {
            String type = m.group(1);
            int dir = Integer.parseInt(m.group(2));
            return this.getSubcat(node, type, dir);
        }
        m = JointFtrXml.P_FEAT.matcher(token.field);
        if (m.find()) {
            return node.getFeat(m.group(1));
        }
        m = JointFtrXml.P_BOOLEAN.matcher(token.field);
        if (m.find()) {
            DEPNode pred = this.d_tree.get(this.i_pred);
            int field = Integer.parseInt(m.group(1));
            switch (field) {
                case 0: {
                    return node.isDependentOf(pred) ? token.field : null;
                }
                case 1: {
                    return pred.isDependentOf(node) ? token.field : null;
                }
                case 2: {
                    return pred.isDependentOf(this.d_lca) ? token.field : null;
                }
                case 3: {
                    return pred == this.d_lca ? token.field : null;
                }
                case 4: {
                    return node == this.d_lca ? token.field : null;
                }
            }
        }
        return null;
    }

    /** Extracts a multi-valued feature (dependent / grand-dependent label sets). */
    @Override
    protected String[] getFields(FtrToken token) {
        DEPNode node = this.getNode(token);
        if (node == null) {
            return null;
        }
        if (token.isField("ds")) {
            return this.getDeprelSet(node.getDependents());
        }
        if (token.isField("gds")) {
            return this.getDeprelSet(node.getGrandDependents());
        }
        return null;
    }

    /** Distinct labels of the given arcs, or {@code null} when there are none. */
    private String[] getDeprelSet(List<DEPArc> deps) {
        if (deps.isEmpty()) {
            return null;
        }
        HashSet<String> set = new HashSet<String>();
        for (DEPArc arc : deps) {
            set.add(arc.getLabel());
        }
        String[] fields = new String[set.size()];
        set.toArray(fields);
        return fields;
    }

    /** Buckets the token distance to the predicate: &lt;=5, &lt;=10, &lt;=15, farther. */
    private String getDistance(DEPNode node) {
        int dist = Math.abs(this.i_pred - node.id);
        if (dist <= 5) {
            return "0";
        }
        if (dist <= 10) {
            return "1";
        }
        if (dist <= 15) {
            return "2";
        }
        return "3";
    }

    /**
     * Path feature between predicate and argument.
     *
     * @param type "p" (POS tags), "d" (dependency labels) or "n" (length)
     * @param dir 1: LCA to predicate; 2: LCA to argument; otherwise the full path
     * @return the path string, or {@code null} when the selected segment is empty
     */
    private String getPath(String type, int dir) {
        DEPNode pred = this.d_tree.get(this.i_pred);
        DEPNode arg = this.d_tree.get(this.i_arg);
        if (dir == 1) {
            if (this.d_lca != pred) {
                return this.getPathAux(this.d_lca, pred, type, "^", true);
            }
        } else if (dir == 2) {
            if (this.d_lca != arg) {
                return this.getPathAux(this.d_lca, arg, type, "|", true);
            }
        } else {
            if (pred == this.d_lca) {
                return this.getPathAux(pred, arg, type, "|", true);
            }
            if (pred.isDescendentOf(arg)) {
                return this.getPathAux(arg, pred, type, "^", true);
            }
            String path = this.getPathAux(this.d_lca, pred, type, "^", true);
            path = path + this.getPathAux(this.d_lca, arg, type, "|", false);
            return path;
        }
        return null;
    }

    /**
     * Walks from {@code bottom} up to (excluding) {@code top}. Requires
     * {@code top} to be a proper ancestor of {@code bottom}.
     * <ul>
     * <li>"p": appends delim+POS per node, plus top's POS if {@code includeTop}.</li>
     * <li>"d": appends delim+label per node.</li>
     * <li>"n": appends delim+node count.</li>
     * </ul>
     * Returns {@code null} for an empty result.
     */
    private String getPathAux(DEPNode top, DEPNode bottom, String type, String delim, boolean includeTop) {
        StringBuilder build = new StringBuilder();
        DEPNode head = bottom;
        int dist = 0;
        do {
            if (type.equals("p")) {
                build.append(delim);
                build.append(head.pos);
            } else if (type.equals("d")) {
                build.append(delim);
                build.append(head.getLabel());
            } else if (type.equals("n")) {
                ++dist;
            }
        } while ((head = head.getHead()) != top);
        if (type.equals("p")) {
            if (includeTop) {
                build.append(delim);
                build.append(top.pos);
            }
        } else if (type.equals("n")) {
            build.append(delim);
            build.append(dist);
        }
        return build.length() == 0 ? null : build.toString();
    }

    /**
     * Subcategorization feature of {@code node}: "_"-joined POS tags or labels of
     * its dependents.
     *
     * @param dir 1: left dependents (scanned left to right); 2: right dependents
     *            (scanned right to left); otherwise all dependents
     */
    private String getSubcat(DEPNode node, String type, int dir) {
        List<DEPArc> deps = node.getDependents();
        StringBuilder build = new StringBuilder();
        int size = deps.size();
        if (dir == 1) {
            for (int i = 0; i < size; ++i) {
                DEPNode dep = deps.get(i).getNode();
                if (dep.id > node.id) {
                    break;
                }
                this.getSubcatAux(build, dep, type);
            }
        } else if (dir == 2) {
            for (int i = size - 1; i >= 0; --i) {
                DEPNode dep = deps.get(i).getNode();
                if (dep.id < node.id) {
                    break;
                }
                this.getSubcatAux(build, dep, type);
            }
        } else {
            for (int i = 0; i < size; ++i) {
                this.getSubcatAux(build, deps.get(i).getNode(), type);
            }
        }
        // drop the leading "_" separator
        return build.length() == 0 ? null : build.substring("_".length());
    }

    /** Appends "_" followed by the POS ("p") or label ("d") of {@code node}. */
    private void getSubcatAux(StringBuilder build, DEPNode node, String type) {
        build.append("_");
        if (type.equals("p")) {
            build.append(node.pos);
        } else if (type.equals("d")) {
            build.append(node.getLabel());
        }
    }

    /**
     * Resolves the node a feature token refers to. The source is 'p' for the
     * predicate or 'a' for the argument. An optional relation then maps it to
     * its head, leftmost/rightmost dependent, or left/right sibling.
     *
     * @return the node, or {@code null} if the source is unknown or the relation
     *         does not exist
     */
    private DEPNode getNode(FtrToken token) {
        DEPNode node;
        switch (token.source) {
            case 'p': {
                node = this.d_tree.get(this.i_pred);
                break;
            }
            case 'a': {
                node = this.d_tree.get(this.i_arg);
                break;
            }
            default: {
                node = null;
                break;
            }
        }
        // Guard: an unknown source used to reach node.getHead()/node.id and throw an NPE.
        if (node == null || token.relation == null) {
            return node;
        }
        if (token.isRelation("h")) {
            node = node.getHead();
        } else if (token.isRelation("lmd")) {
            node = this.lm_deps[node.id];
        } else if (token.isRelation("rmd")) {
            node = this.rm_deps[node.id];
        } else if (token.isRelation("lns")) {
            node = this.ln_sibs[node.id];
        } else if (token.isRelation("rns")) {
            node = this.rn_sibs[node.id];
        }
        return node;
    }
}

