/*
 * Decompiled with CFR 0.152.
 */
package heros.solver;

import com.google.common.collect.Maps;
import heros.EdgeFunction;
import heros.EdgeFunctions;
import heros.FlowFunction;
import heros.FlowFunctions;
import heros.IDETabulationProblem;
import heros.InterproceduralCFG;
import heros.JoinLattice;
import heros.solver.CountingThreadPoolExecutor;
import heros.solver.IDESolver;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

public class BiDiIDESolver<N, D, M, V, I extends InterproceduralCFG<N, M>> {
    /** The forward problem, wrapped so that each fact is paired with a source statement. */
    private final IDETabulationProblem<N, AbstractionWithSourceStmt, M, V, I> forwardProblem;
    /** The backward problem, wrapped in the same way as the forward one. */
    private final IDETabulationProblem<N, AbstractionWithSourceStmt, M, V, I> backwardProblem;
    /** One executor handed to both directional solvers via {@code getExecutor()}. */
    private final CountingThreadPoolExecutor sharedExecutor;
    /** Forward-direction solver; created by {@link #solve()}. */
    protected SingleDirectionSolver fwSolver;
    /** Backward-direction solver; created by {@link #solve()}. */
    protected SingleDirectionSolver bwSolver;

    /**
     * Builds a bidirectional solver from a forward and a backward problem.
     *
     * @param forwardProblem  the forward problem; its {@code numThreads()} (at least 1) is the
     *                        maximum pool size of the shared executor
     * @param backwardProblem the backward problem
     * @throws IllegalArgumentException if either problem's {@code followReturnsPastSeeds()} is false
     */
    public BiDiIDESolver(IDETabulationProblem<N, D, M, V, I> forwardProblem, IDETabulationProblem<N, D, M, V, I> backwardProblem) {
        // short-circuits exactly like the negated form: the backward check is skipped if the forward one fails
        boolean bothFollowReturnsPastSeeds = forwardProblem.followReturnsPastSeeds() && backwardProblem.followReturnsPastSeeds();
        if (!bothFollowReturnsPastSeeds) {
            throw new IllegalArgumentException("This solver is only meant for bottom-up problems, so followReturnsPastSeeds() should return true.");
        }
        this.forwardProblem = new AugmentedTabulationProblem(forwardProblem);
        this.backwardProblem = new AugmentedTabulationProblem(backwardProblem);
        int maxPoolSize = Math.max(1, forwardProblem.numThreads());
        this.sharedExecutor = new CountingThreadPoolExecutor(1, maxPoolSize, 30L, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>());
    }

    /**
     * Runs the bidirectional analysis.
     *
     * <p>Creates the forward ("FW") and backward ("BW") solvers, links each to the other, submits
     * the backward solver's initial seeds and then runs the forward solver.
     */
    public void solve() {
        SingleDirectionSolver forward = this.createSingleDirectionSolver(this.forwardProblem, "FW");
        this.fwSolver = forward;
        SingleDirectionSolver backward = this.createSingleDirectionSolver(this.backwardProblem, "BW");
        this.bwSolver = backward;

        // each solver consults the other before letting an unbalanced return through
        forward.otherSolver = backward;
        backward.otherSolver = forward;

        backward.submitInitialSeeds();
        forward.solve();
    }

    /**
     * Factory hook for the per-direction solvers; subclasses may override it to return a
     * customized {@link SingleDirectionSolver}.
     *
     * @param problem   the augmented problem the solver works on
     * @param debugName short label used in the solver's debug output ("FW" or "BW" from {@link #solve()})
     * @return a new solver for {@code problem}
     */
    protected SingleDirectionSolver createSingleDirectionSolver(IDETabulationProblem<N, AbstractionWithSourceStmt, M, V, I> problem, String debugName) {
        SingleDirectionSolver solver = new SingleDirectionSolver(problem, debugName);
        return solver;
    }

    /**
     * Wraps a user-supplied problem so that every fact is an {@link AbstractionWithSourceStmt}.
     *
     * <p>Flow and edge functions are delegated on the bare abstractions. Normal, return and
     * call-to-return flow keep the incoming fact's source statement. Call flow resets it to
     * {@code null}. Initial seeds are tagged with their seed statement. All configuration
     * methods forward to the delegate unchanged.
     */
    private class AugmentedTabulationProblem
    implements IDETabulationProblem<N, AbstractionWithSourceStmt, M, V, I> {
        private final IDETabulationProblem<N, D, M, V, I> delegate;
        private final AbstractionWithSourceStmt ZERO;
        private final FlowFunctions<N, D, M> originalFunctions;

        /**
         * @param delegate the problem to wrap; its flow functions are fetched once, here
         */
        public AugmentedTabulationProblem(IDETabulationProblem<N, D, M, V, I> delegate) {
            this.delegate = delegate;
            this.originalFunctions = this.delegate.flowFunctions();
            // the zero value carries no source statement
            this.ZERO = new AbstractionWithSourceStmt(delegate.zeroValue(), null);
        }

        /**
         * Returns flow functions that apply the delegate's flow functions to the wrapped abstractions.
         * The delegate's function for a given edge is resolved once, when the wrapper is requested,
         * rather than on every {@code computeTargets} call.
         */
        @Override
        public FlowFunctions<N, AbstractionWithSourceStmt, M> flowFunctions() {
            return new FlowFunctions<N, AbstractionWithSourceStmt, M>() {

                @Override
                public FlowFunction<AbstractionWithSourceStmt> getNormalFlowFunction(N curr, N succ) {
                    return preservingSourceStmt(originalFunctions.getNormalFlowFunction(curr, succ));
                }

                @Override
                public FlowFunction<AbstractionWithSourceStmt> getCallFlowFunction(N callStmt, M destinationMethod) {
                    final FlowFunction<D> original = originalFunctions.getCallFlowFunction(callStmt, destinationMethod);
                    return new FlowFunction<AbstractionWithSourceStmt>() {

                        @Override
                        public Set<AbstractionWithSourceStmt> computeTargets(AbstractionWithSourceStmt source) {
                            // passing into the callee resets the source statement; the solver's
                            // restoreContextOnReturnedFact re-attaches the caller's statement on return
                            return wrapAll(original.computeTargets(source.getAbstraction()), null);
                        }
                    };
                }

                @Override
                public FlowFunction<AbstractionWithSourceStmt> getReturnFlowFunction(N callSite, M calleeMethod, N exitStmt, N returnSite) {
                    return preservingSourceStmt(originalFunctions.getReturnFlowFunction(callSite, calleeMethod, exitStmt, returnSite));
                }

                @Override
                public FlowFunction<AbstractionWithSourceStmt> getCallToReturnFlowFunction(N callSite, N returnSite) {
                    return preservingSourceStmt(originalFunctions.getCallToReturnFlowFunction(callSite, returnSite));
                }
            };
        }

        /**
         * Adapts {@code original} to wrapped facts: it is applied to the bare abstraction, and every
         * target inherits the incoming fact's source statement.
         */
        private FlowFunction<AbstractionWithSourceStmt> preservingSourceStmt(final FlowFunction<D> original) {
            return new FlowFunction<AbstractionWithSourceStmt>() {

                @Override
                public Set<AbstractionWithSourceStmt> computeTargets(AbstractionWithSourceStmt source) {
                    return wrapAll(original.computeTargets(source.getAbstraction()), source.getSourceStmt());
                }
            };
        }

        /** Pairs every abstraction in {@code abstractions} with {@code sourceStmt} (which may be null). */
        private Set<AbstractionWithSourceStmt> wrapAll(Set<D> abstractions, N sourceStmt) {
            Set<AbstractionWithSourceStmt> res = new HashSet<AbstractionWithSourceStmt>();
            for (D d : abstractions) {
                res.add(new AbstractionWithSourceStmt(d, sourceStmt));
            }
            return res;
        }

        @Override
        public boolean followReturnsPastSeeds() {
            return this.delegate.followReturnsPastSeeds();
        }

        @Override
        public boolean autoAddZero() {
            return this.delegate.autoAddZero();
        }

        @Override
        public int numThreads() {
            return this.delegate.numThreads();
        }

        @Override
        public boolean computeValues() {
            return this.delegate.computeValues();
        }

        @Override
        public I interproceduralCFG() {
            return this.delegate.interproceduralCFG();
        }

        /**
         * Returns the delegate's seeds with each fact tagged by the statement it is seeded at.
         */
        @Override
        public Map<N, Set<AbstractionWithSourceStmt>> initialSeeds() {
            Map<N, Set<D>> originalSeeds = this.delegate.initialSeeds();
            Map<N, Set<AbstractionWithSourceStmt>> res = new HashMap<N, Set<AbstractionWithSourceStmt>>();
            for (Map.Entry<N, Set<D>> entry : originalSeeds.entrySet()) {
                N stmt = entry.getKey();
                res.put(stmt, wrapAll(entry.getValue(), stmt));
            }
            return res;
        }

        @Override
        public AbstractionWithSourceStmt zeroValue() {
            return this.ZERO;
        }

        /**
         * Returns edge functions that forward to the delegate's edge functions on the bare
         * abstractions. The delegate's {@code edgeFunctions()} is fetched once per call of this
         * method rather than on every lookup.
         */
        @Override
        public EdgeFunctions<N, AbstractionWithSourceStmt, M, V> edgeFunctions() {
            final EdgeFunctions<N, D, M, V> originalEdges = this.delegate.edgeFunctions();
            return new EdgeFunctions<N, AbstractionWithSourceStmt, M, V>() {

                @Override
                public EdgeFunction<V> getNormalEdgeFunction(N curr, AbstractionWithSourceStmt currNode, N succ, AbstractionWithSourceStmt succNode) {
                    return originalEdges.getNormalEdgeFunction(curr, currNode.getAbstraction(), succ, succNode.getAbstraction());
                }

                @Override
                public EdgeFunction<V> getCallEdgeFunction(N callStmt, AbstractionWithSourceStmt srcNode, M destinationMethod, AbstractionWithSourceStmt destNode) {
                    return originalEdges.getCallEdgeFunction(callStmt, srcNode.getAbstraction(), destinationMethod, destNode.getAbstraction());
                }

                @Override
                public EdgeFunction<V> getReturnEdgeFunction(N callSite, M calleeMethod, N exitStmt, AbstractionWithSourceStmt exitNode, N returnSite, AbstractionWithSourceStmt retNode) {
                    return originalEdges.getReturnEdgeFunction(callSite, calleeMethod, exitStmt, exitNode.getAbstraction(), returnSite, retNode.getAbstraction());
                }

                @Override
                public EdgeFunction<V> getCallToReturnEdgeFunction(N callSite, AbstractionWithSourceStmt callNode, N returnSite, AbstractionWithSourceStmt returnSideNode) {
                    return originalEdges.getCallToReturnEdgeFunction(callSite, callNode.getAbstraction(), returnSite, returnSideNode.getAbstraction());
                }
            };
        }

        @Override
        public JoinLattice<V> joinLattice() {
            return this.delegate.joinLattice();
        }

        @Override
        public EdgeFunction<V> allTopFunction() {
            return this.delegate.allTopFunction();
        }
    }

    /**
     * A fact of the wrapped problem paired with a source statement.
     *
     * <p>The source statement is {@code null} for the zero value and for facts passed into a
     * callee. Initial seeds carry the statement they are seeded at. Facts propagated along an
     * unbalanced return carry the related call site. Equality and hashing consider both parts,
     * and both tolerate {@code null}.
     */
    public class AbstractionWithSourceStmt {
        protected final D abstraction;
        protected final N source;

        private AbstractionWithSourceStmt(D abstraction, N source) {
            this.abstraction = abstraction;
            this.source = source;
        }

        /** @return the wrapped abstraction of the original problem */
        public D getAbstraction() {
            return this.abstraction;
        }

        /** @return the associated source statement, possibly {@code null} */
        public N getSourceStmt() {
            return this.source;
        }

        @Override
        public String toString() {
            if (this.source != null) {
                return String.valueOf(this.abstraction) + "-@-" + this.source;
            }
            // String.valueOf tolerates a null abstraction, matching equals/hashCode
            return String.valueOf(this.abstraction);
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = 31 * result + (this.abstraction == null ? 0 : this.abstraction.hashCode());
            result = 31 * result + (this.source == null ? 0 : this.source.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (this.getClass() != obj.getClass()) {
                return false;
            }
            AbstractionWithSourceStmt other = (AbstractionWithSourceStmt) obj;
            if (this.abstraction == null ? other.abstraction != null : !this.abstraction.equals(other.abstraction)) {
                return false;
            }
            return this.source == null ? other.source == null : this.source.equals(other.source);
        }
    }

    /**
     * The IDE solver for one direction of the bidirectional analysis.
     *
     * <p>An unbalanced return is keyed by (source statement, related call site). It is
     * propagated only once the other direction's solver has also leaked through the same key.
     * Until then it is stored as a {@link PausedEdge}, and it is released when the other solver
     * leaks.
     */
    protected class SingleDirectionSolver
    extends IDESolver<N, AbstractionWithSourceStmt, M, V, I> {
        private final String debugName;
        /** The opposite-direction solver; linked by {@code BiDiIDESolver.solve()}. */
        private SingleDirectionSolver otherSolver;
        /** Keys for which this solver has reached an unbalanced return. */
        private final Set<LeakKey<N>> leakedSources;
        /** Unbalanced returns of this solver waiting for the other solver to leak through the same key. */
        private final ConcurrentMap<LeakKey<N>, Set<PausedEdge>> pausedPathEdges;

        public SingleDirectionSolver(IDETabulationProblem<N, AbstractionWithSourceStmt, M, V, I> ifdsProblem, String debugName) {
            super(ifdsProblem);
            // explicit type arguments keep inference working on pre-Java-8 compilers
            this.leakedSources = Collections.newSetFromMap(Maps.<LeakKey<N>, Boolean>newConcurrentMap());
            this.pausedPathEdges = Maps.newConcurrentMap();
            this.debugName = debugName;
        }

        /**
         * Propagates an unbalanced return if the other solver already leaked through the same
         * (source statement, call site) key. Otherwise it pauses the edge until that happens.
         */
        @Override
        protected void propagateUnbalancedReturnFlow(N retSiteC, AbstractionWithSourceStmt targetVal, EdgeFunction<V> edgeFunction, N relatedCallSite) {
            N sourceStmt = targetVal.getSourceStmt();
            LeakKey<N> leakKey = new LeakKey<N>(sourceStmt, relatedCallSite);
            // publish our leak before checking the other side, so at least one side sees both
            this.leakedSources.add(leakKey);
            if (this.otherSolver.hasLeaked(leakKey)) {
                // both directions leaked: release the other side's paused edges and continue here
                this.otherSolver.unpausePathEdgesForSource(leakKey);
                super.propagateUnbalancedReturnFlow(retSiteC, targetVal, edgeFunction, relatedCallSite);
            } else {
                Set<PausedEdge> newPausedEdges = Collections.newSetFromMap(Maps.<PausedEdge, Boolean>newConcurrentMap());
                Set<PausedEdge> existingPausedEdges = this.pausedPathEdges.putIfAbsent(leakKey, newPausedEdges);
                if (existingPausedEdges == null) {
                    existingPausedEdges = newPausedEdges;
                }
                PausedEdge edge = new PausedEdge(retSiteC, targetVal, edgeFunction, relatedCallSite);
                existingPausedEdges.add(edge);
                // the other solver may have leaked since the check above. Whoever removes the edge
                // from the set (here, or the other solver's unpause) is the one that propagates it.
                if (this.otherSolver.hasLeaked(leakKey) && existingPausedEdges.remove(edge)) {
                    super.propagateUnbalancedReturnFlow(retSiteC, targetVal, edgeFunction, relatedCallSite);
                }
                logger.debug(" ++ PAUSE {}: {}", this.debugName, edge);
            }
        }

        /**
         * Re-anchors facts that travel along an unbalanced return at the related call site, then
         * delegates to the standard propagation.
         */
        @Override
        protected void propagate(AbstractionWithSourceStmt sourceVal, N target, AbstractionWithSourceStmt targetVal, EdgeFunction<V> f, N relatedCallSite, boolean isUnbalancedReturn) {
            if (isUnbalancedReturn) {
                assert (sourceVal.getSourceStmt() == null) : "source value should have no statement attached";
                targetVal = new AbstractionWithSourceStmt(targetVal.getAbstraction(), relatedCallSite);
            }
            super.propagate(sourceVal, target, targetVal, f, relatedCallSite, isUnbalancedReturn);
        }

        /** Keeps the returned abstraction ({@code d5}) but restores the caller-side source statement of {@code d4}. */
        @Override
        protected AbstractionWithSourceStmt restoreContextOnReturnedFact(AbstractionWithSourceStmt d4, AbstractionWithSourceStmt d5) {
            return new AbstractionWithSourceStmt(d5.getAbstraction(), d4.getSourceStmt());
        }

        /** @return whether this solver has reached an unbalanced return for {@code leakKey} */
        private boolean hasLeaked(LeakKey<N> leakKey) {
            return this.leakedSources.contains(leakKey);
        }

        /** Propagates, in this solver, every edge it paused under {@code leakKey}. */
        private void unpausePathEdgesForSource(LeakKey<N> leakKey) {
            Set<PausedEdge> pausedEdges = this.pausedPathEdges.get(leakKey);
            if (pausedEdges == null) {
                return;
            }
            for (PausedEdge edge : pausedEdges) {
                // only the caller whose remove() succeeds propagates the edge, so no edge is released twice
                if (!pausedEdges.remove(edge)) {
                    continue;
                }
                if (DEBUG) {
                    logger.debug("-- UNPAUSE {}: {}", this.debugName, edge);
                }
                super.propagateUnbalancedReturnFlow(edge.retSiteC, edge.targetVal, edge.edgeFunction, edge.relatedCallSite);
            }
        }

        /** Both directional solvers run on the executor shared by the enclosing {@code BiDiIDESolver}. */
        @Override
        protected CountingThreadPoolExecutor getExecutor() {
            return BiDiIDESolver.this.sharedExecutor;
        }

        @Override
        protected String getDebugName() {
            return this.debugName;
        }
    }

    /**
     * Identifies an unbalanced return by the fact's source statement and the related call site.
     * Instances are stored in concurrent sets and used as map keys, so they are immutable.
     * Either component may be {@code null}.
     */
    private static class LeakKey<N> {
        private final N sourceStmt;
        private final N relatedCallSite;

        public LeakKey(N sourceStmt, N relatedCallSite) {
            this.sourceStmt = sourceStmt;
            this.relatedCallSite = relatedCallSite;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = 31 * result + (this.relatedCallSite == null ? 0 : this.relatedCallSite.hashCode());
            result = 31 * result + (this.sourceStmt == null ? 0 : this.sourceStmt.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            // instanceof is false for null, so no separate null check is needed
            if (!(obj instanceof LeakKey)) {
                return false;
            }
            LeakKey<?> other = (LeakKey<?>) obj;
            return nullSafeEquals(this.relatedCallSite, other.relatedCallSite)
                    && nullSafeEquals(this.sourceStmt, other.sourceStmt);
        }

        /** {@code a.equals(b)}, treating two nulls as equal. */
        private static boolean nullSafeEquals(Object a, Object b) {
            return a == null ? b == null : a.equals(b);
        }
    }

    /**
     * The arguments of an unbalanced return that is held back until the other direction's solver
     * leaks through the same key.
     *
     * <p>This class deliberately keeps {@code Object}'s identity {@code equals}/{@code hashCode}.
     * The pause/unpause logic removes the exact instance it added to a concurrent set.
     */
    private class PausedEdge {
        private final N retSiteC;
        private final AbstractionWithSourceStmt targetVal;
        private final EdgeFunction<V> edgeFunction;
        private final N relatedCallSite;

        public PausedEdge(N retSiteC, AbstractionWithSourceStmt targetVal, EdgeFunction<V> edgeFunction, N relatedCallSite) {
            this.retSiteC = retSiteC;
            this.targetVal = targetVal;
            this.edgeFunction = edgeFunction;
            this.relatedCallSite = relatedCallSite;
        }

        /** Human-readable form used by the PAUSE/UNPAUSE debug log messages. */
        @Override
        public String toString() {
            return "PausedEdge[retSite=" + this.retSiteC
                    + ", targetVal=" + this.targetVal
                    + ", edgeFunction=" + this.edgeFunction
                    + ", relatedCallSite=" + this.relatedCallSite + "]";
        }
    }
}

