/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.rewrite.LopRewriteRule;
import org.apache.sysds.parser.StatementBlock;

/**
 * LOP rewrite rule that inserts {@link Checkpoint} lops into the LOP DAG of a
 * statement block.
 *
 * <p>A checkpoint is placed directly after every operator that
 * {@link OperatorOrderingUtils#isPersistableSparkOp} accepts and that is
 * reached from more than one of the Spark roots collected by
 * {@link OperatorOrderingUtils#collectSparkRoots}. The DAG is modified in
 * place; the statement block itself is returned unchanged.
 */
public class RewriteAddChkpointLop
extends LopRewriteRule {
    /**
     * Adds checkpoint lops to the LOP DAG of the given statement block.
     *
     * <p>The rewrite is a no-op when checkpointing is disabled in the
     * configuration, when no lop list is available for the block, or when no
     * Spark roots are found.
     *
     * @param sb statement block whose LOP DAG is rewritten in place
     * @return a single-element immutable list containing {@code sb}
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb) {
        if (!ConfigurationManager.isCheckpointEnabled()) {
            return List.of(sb);
        }
        List<Lop> lops = OperatorOrderingUtils.getLopList(sb);
        if (lops == null) {
            return List.of(sb);
        }

        // Collect the Spark roots reachable from the DAG roots. sparkOpCount is
        // only an output argument of the helper and is not read afterwards.
        ArrayList<Lop> sparkRoots = new ArrayList<>();
        HashMap<Long, Integer> sparkOpCount = new HashMap<>();
        List<Lop> roots = lops.stream()
            .filter(OperatorOrderingUtils::isLopRoot)
            .collect(Collectors.toList());
        for (Lop root : roots) {
            OperatorOrderingUtils.collectSparkRoots(root, sparkOpCount, sparkRoots);
        }
        if (sparkRoots.isEmpty()) {
            return List.of(sb);
        }

        // Count, per operator ID, from how many Spark roots it is reachable,
        // then checkpoint every operator that is counted more than once.
        Map<Long, Integer> operatorJobCount = new HashMap<>();
        markPersistableSparkOps(sparkRoots, operatorJobCount);
        // The returned node list is not needed here: the checkpoint lops are
        // wired into the DAG in place.
        addChkpointLop(lops, operatorJobCount);
        return List.of(sb);
    }

    /**
     * Rewrite over a list of statement blocks; this rule leaves the list as is.
     *
     * @param sbs statement blocks
     * @return {@code sbs}, unchanged
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
        return sbs;
    }

    /**
     * Inserts a {@link Checkpoint} lop after every node whose count in
     * {@code operatorJobCount} is greater than one, rewiring
     * {@code l -> out} to {@code l -> chkpoint -> out} for each existing
     * consumer {@code out}.
     *
     * @param nodes lops of the DAG, in the order they should appear in the result
     * @param operatorJobCount per-operator-ID counts produced by
     *        {@link #markPersistableSparkOps}
     * @return a new list containing all of {@code nodes} with each created
     *         checkpoint placed immediately after the lop it follows
     */
    private static List<Lop> addChkpointLop(List<Lop> nodes, Map<Long, Integer> operatorJobCount) {
        List<Lop> nodesWithChkpt = new ArrayList<>();
        for (Lop l : nodes) {
            nodesWithChkpt.add(l);
            if (operatorJobCount.getOrDefault(l.getID(), 0) > 1) {
                // Snapshot the consumers before constructing the checkpoint, so
                // that only pre-existing outputs are rewired (presumably the
                // Checkpoint constructor registers itself as an output of l --
                // TODO confirm) and so that l.removeOutput() below does not
                // modify the collection being iterated.
                List<Lop> oldOuts = new ArrayList<>(l.getOutputs());
                Checkpoint chkpoint = new Checkpoint(l, l.getDataType(), l.getValueType(),
                    Checkpoint.getDefaultStorageLevelString(), false);
                for (Lop out : oldOuts) {
                    chkpoint.addOutput(out);
                    out.replaceInput(l, chkpoint);
                    l.removeOutput(out);
                }
                nodesWithChkpt.add(chkpoint);
            }
        }
        return nodesWithChkpt;
    }

    /**
     * Traverses the sub-DAG below each Spark root and counts the persistable
     * Spark operators encountered. The visit status is reset after every root,
     * so an operator reachable from several roots is counted once per root.
     *
     * @param sparkRoots roots to start the traversals from
     * @param operatorJobCount output map from operator ID to the number of
     *        roots the operator was reached from
     */
    private static void markPersistableSparkOps(List<Lop> sparkRoots, Map<Long, Integer> operatorJobCount) {
        for (Lop root : sparkRoots) {
            collectPersistableSparkOps(root, operatorJobCount);
            root.resetVisitStatus();
        }
    }

    /**
     * Depth-first traversal over the inputs of {@code root} that increments
     * the count of every not-yet-visited node accepted by
     * {@link OperatorOrderingUtils#isPersistableSparkOp}. The broadcast input
     * of a node (as reported by {@code getBroadcastInput()}) is not followed.
     *
     * @param root node to start from; marked visited on return
     * @param operatorJobCount map from operator ID to count, updated in place
     */
    private static void collectPersistableSparkOps(Lop root, Map<Long, Integer> operatorJobCount) {
        if (root.isVisited()) {
            return;
        }
        for (Lop input : root.getInputs()) {
            // Identity comparison: skip exactly the lop used as broadcast input.
            if (root.getBroadcastInput() != input) {
                collectPersistableSparkOps(input, operatorJobCount);
            }
        }
        if (OperatorOrderingUtils.isPersistableSparkOp(root)) {
            operatorJobCount.merge(root.getID(), 1, Integer::sum);
        }
        root.setVisited();
    }
}

