/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.jet.impl.execution.init;

import com.hazelcast.cluster.Address;
import com.hazelcast.internal.cluster.MemberInfo;
import com.hazelcast.internal.partition.IPartitionService;
import com.hazelcast.internal.util.ConcurrencyUtil;
import com.hazelcast.internal.util.Preconditions;
import com.hazelcast.internal.util.collection.IntHashSet;
import com.hazelcast.internal.util.executor.ManagedExecutorService;
import com.hazelcast.jet.JetException;
import com.hazelcast.jet.config.EdgeConfig;
import com.hazelcast.jet.config.JobConfig;
import com.hazelcast.jet.core.DAG;
import com.hazelcast.jet.core.Edge;
import com.hazelcast.jet.core.ProcessorMetaSupplier;
import com.hazelcast.jet.core.ProcessorSupplier;
import com.hazelcast.jet.core.Vertex;
import com.hazelcast.jet.function.RunnableEx;
import com.hazelcast.jet.impl.JetServiceBackend;
import com.hazelcast.jet.impl.JobClassLoaderService;
import com.hazelcast.jet.impl.deployment.JetDelegatingClassLoader;
import com.hazelcast.jet.impl.execution.init.Contexts;
import com.hazelcast.jet.impl.execution.init.EdgeDef;
import com.hazelcast.jet.impl.execution.init.ExecutionPlan;
import com.hazelcast.jet.impl.execution.init.VertexDef;
import com.hazelcast.jet.impl.util.ExceptionUtil;
import com.hazelcast.jet.impl.util.FixedCapacityIntArrayList;
import com.hazelcast.jet.impl.util.PrefixedLogger;
import com.hazelcast.jet.impl.util.Util;
import com.hazelcast.logging.ILogger;
import com.hazelcast.spi.impl.NodeEngine;
import com.hazelcast.spi.impl.NodeEngineImpl;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.Function;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.security.auth.Subject;

/**
 * Builds the per-member {@link ExecutionPlan}s for one job execution.
 * <p>
 * It decides which members take part in the execution and which partitions each of them handles. It
 * assigns ids to the DAG's vertices. It also initializes every vertex's {@link ProcessorMetaSupplier} to
 * obtain the {@link ProcessorSupplier} that each member will run.
 */
public final class ExecutionPlanBuilder {
    private ExecutionPlanBuilder() {
    }

    /**
     * Creates one {@link ExecutionPlan} for every member that takes part in the given job execution.
     * <p>
     * Members and partitions are selected as follows:
     * <ul>
     *     <li>If the job argument {@code "__sql.requiredPartitions"} is set, the DAG is analyzed with
     *     {@link #analyzeDagForPartitionPruning}. Only the members needed for those partitions, plus any
     *     member required by the DAG's edges, are included.</li>
     *     <li>Otherwise, if the DAG has a member selector (an isolated job), partitions are spread with
     *     {@link #getFairPartitionAssignment}.</li>
     *     <li>Otherwise partitions are assigned to their owners with {@link #getPartitionAssignment}.</li>
     * </ul>
     * For each vertex, {@link ProcessorMetaSupplier#init} and {@link ProcessorMetaSupplier#get} are called
     * with the job's class loader installed. This runs on the calling thread when
     * {@link ProcessorMetaSupplier#initIsCooperative()} returns {@code true}. Otherwise it runs on the
     * {@code "hz:jet-job-offloadable"} executor.
     *
     * @return a future that completes with the plans, keyed by member, once every vertex has been
     *         initialized; it completes exceptionally if initializing any vertex fails
     */
    public static CompletableFuture<Map<MemberInfo, ExecutionPlan>> createExecutionPlans(
            NodeEngineImpl nodeEngine,
            List<MemberInfo> memberInfos,
            DAG dag,
            long jobId,
            long executionId,
            JobConfig jobConfig,
            long lastSnapshotId,
            boolean isLightJob,
            Subject subject
    ) {
        @SuppressWarnings("unchecked")
        Set<Integer> requiredPartitions = (Set<Integer>) jobConfig.getArgument("__sql.requiredPartitions");
        boolean isIsolatedJob = dag.memberSelector() != null;

        Map<MemberInfo, int[]> partitionsByMember;
        if (requiredPartitions != null) {
            PartitionPruningAnalysisResult analysisResult = analyzeDagForPartitionPruning(nodeEngine, dag);
            partitionsByMember = getPartitionAssignment(nodeEngine, memberInfos,
                    analysisResult.allPartitionsRequired, requiredPartitions,
                    analysisResult.constantPartitionIds, analysisResult.requiredAddresses);
        } else if (isIsolatedJob) {
            partitionsByMember = getFairPartitionAssignment(nodeEngine, memberInfos);
        } else {
            partitionsByMember = getPartitionAssignment(nodeEngine, memberInfos, false, null, null, null);
        }

        Map<Address, int[]> partitionsByAddress = partitionsByMember.entrySet().stream()
                .collect(Collectors.toMap(en -> en.getKey().getAddress(), Map.Entry::getValue));
        int memberCount = partitionsByAddress.size();
        boolean isJobDistributed = memberCount > 1;
        VerticesIdAndOrder verticesIdAndOrder = VerticesIdAndOrder.assignVertexIds(dag);
        int defaultParallelism = nodeEngine.getConfig().getJetConfig().getCooperativeThreadCount();
        EdgeConfig defaultEdgeConfig = nodeEngine.getConfig().getJetConfig().getDefaultEdgeConfig();

        // The member index passed to each plan follows the iteration order of partitionsByMember.
        Map<MemberInfo, ExecutionPlan> plans = new HashMap<>();
        int memberIndex = 0;
        for (MemberInfo member : partitionsByMember.keySet()) {
            plans.put(member, new ExecutionPlan(partitionsByAddress, jobConfig, lastSnapshotId, memberIndex++,
                    memberCount, isLightJob, subject, verticesIdAndOrder.count()));
        }
        List<Address> addresses = Util.toList(partitionsByMember.keySet(), MemberInfo::getAddress);
        ManagedExecutorService initOffloadExecutor =
                nodeEngine.getExecutionService().getExecutor("hz:jet-job-offloadable");

        // One future per vertex. Each is stored at the vertex's position in the plan.
        CompletableFuture<?>[] futures = new CompletableFuture<?>[verticesIdAndOrder.count()];
        for (VertexIdPos entry : verticesIdAndOrder) {
            Vertex vertex = dag.getVertex(entry.vertexName);
            assert vertex != null;
            ProcessorMetaSupplier metaSupplier = vertex.getMetaSupplier();
            int vertexId = entry.vertexId;
            int localParallelism = vertex.determineLocalParallelism(defaultParallelism);
            int totalParallelism = localParallelism * memberCount;
            List<EdgeDef> inbound = toEdgeDefs(dag.getInboundEdges(vertex.getName()), defaultEdgeConfig,
                    e -> verticesIdAndOrder.idByName(e.getSourceName()), isJobDistributed);
            List<EdgeDef> outbound = toEdgeDefs(dag.getOutboundEdges(vertex.getName()), defaultEdgeConfig,
                    e -> verticesIdAndOrder.idByName(e.getDestName()), isJobDistributed);
            String prefix = PrefixedLogger.prefix(jobConfig.getName(), jobId, vertex.getName(), "#PMS");
            ILogger logger = PrefixedLogger.prefixedLogger(nodeEngine.getLogger(metaSupplier.getClass()), prefix);

            RunnableEx action = () -> {
                JetServiceBackend jetBackend = (JetServiceBackend) nodeEngine.getService("hz:impl:jetService");
                JobClassLoaderService jobClassLoaderService = jetBackend.getJobClassLoaderService();
                JetDelegatingClassLoader processorClassLoader = jobClassLoaderService.getClassLoader(jobId);
                try {
                    Util.doWithClassLoader((ClassLoader) processorClassLoader, () -> metaSupplier.init(
                            new Contexts.MetaSupplierCtx(nodeEngine, jobId, executionId, jobConfig, logger,
                                    vertex.getName(), localParallelism, totalParallelism, memberCount,
                                    isLightJob, partitionsByAddress, subject, processorClassLoader)));
                } catch (Exception e) {
                    throw com.hazelcast.internal.util.ExceptionUtil.sneakyThrow(ExceptionUtil.peel(e));
                }
                Function<? super Address, ? extends ProcessorSupplier> procSupplierFn =
                        Util.doWithClassLoader((ClassLoader) processorClassLoader, () -> metaSupplier.get(addresses));
                for (Map.Entry<MemberInfo, ExecutionPlan> memberAndPlan : plans.entrySet()) {
                    ProcessorSupplier processorSupplier = Util.doWithClassLoader((ClassLoader) processorClassLoader,
                            () -> procSupplierFn.apply(memberAndPlan.getKey().getAddress()));
                    // Light jobs skip the serializability check.
                    if (!isLightJob) {
                        Util.checkSerializable(processorSupplier, "ProcessorSupplier in vertex '" + vertex.getName() + "'");
                    }
                    VertexDef vertexDef = new VertexDef(vertexId, vertex.getName(), processorSupplier, localParallelism);
                    vertexDef.addInboundEdges(inbound);
                    vertexDef.addOutboundEdges(outbound);
                    memberAndPlan.getValue().setVertex(entry.requiredPosition, vertexDef);
                }
            };
            Executor executor = metaSupplier.initIsCooperative() ? ConcurrencyUtil.CALLER_RUNS : initOffloadExecutor;
            futures[entry.requiredPosition] = CompletableFuture.runAsync(action, executor);
        }
        return CompletableFuture.allOf(futures).thenApply(ignored -> plans);
    }

    /**
     * Scans the DAG's edges for information needed to prune the set of job members.
     * <ul>
     *     <li>Edges with a {@code distributedTo} address that are not distributed contribute that address
     *     to {@code requiredAddresses}.</li>
     *     <li>A {@link Edge.RoutingPolicy#PARTITIONED PARTITIONED} edge whose partitioner has a constant
     *     partitioning key contributes that key's partition id to {@code constantPartitionIds}.</li>
     *     <li>A {@code PARTITIONED} edge without a constant key sets {@code allPartitionsRequired}.</li>
     * </ul>
     */
    @Nonnull
    static PartitionPruningAnalysisResult analyzeDagForPartitionPruning(NodeEngine nodeEngine, DAG dag) {
        IPartitionService partitionService = nodeEngine.getPartitionService();
        int partitionCount = partitionService.getPartitionCount();
        Set<Address> requiredAddresses = new HashSet<>(1);
        IntHashSet constantPartitionIds = new IntHashSet(partitionCount, -1);
        boolean allPartitionsRequired = false;
        for (Iterator<Edge> it = dag.edgeIterator(); it.hasNext(); ) {
            Edge edge = it.next();
            if (edge.getDistributedTo() != null && !edge.isDistributed()) {
                requiredAddresses.add(edge.getDistributedTo());
            }
            if (edge.getRoutingPolicy() != Edge.RoutingPolicy.PARTITIONED) {
                continue;
            }
            assert edge.getPartitioner() != null : "PARTITIONED policy was used without partitioner";
            Object maybeConstantPartition = edge.getPartitioner().getConstantPartitioningKey();
            if (maybeConstantPartition != null) {
                constantPartitionIds.add(partitionService.getPartitionId(maybeConstantPartition));
            } else {
                allPartitionsRequired = true;
            }
        }
        return new PartitionPruningAnalysisResult(requiredAddresses, constantPartitionIds, allPartitionsRequired);
    }

    /**
     * Converts DAG edges to {@link EdgeDef}s.
     * <p>
     * An edge's own config is used when it has one; otherwise {@code defaultEdgeConfig} is used.
     *
     * @param oppositeVtxId maps an edge to the id of the vertex on its other end
     */
    private static List<EdgeDef> toEdgeDefs(
            List<Edge> edges, EdgeConfig defaultEdgeConfig, ToIntFunction<Edge> oppositeVtxId, boolean isJobDistributed
    ) {
        List<EdgeDef> list = new ArrayList<>(edges.size());
        for (Edge edge : edges) {
            EdgeConfig config = edge.getConfig() == null ? defaultEdgeConfig : edge.getConfig();
            list.add(new EdgeDef(edge, config, oppositeVtxId.applyAsInt(edge), isJobDistributed));
        }
        return list;
    }

    /**
     * Assigns partitions to members and thereby decides which members take part in the job. Only members
     * present in the returned map are included.
     * <p>
     * A partition normally goes to its owner. If the owner is not in {@code memberList}, the partition goes to
     * a member of {@code memberList} chosen round-robin.
     *
     * @param memberList                   candidate members
     * @param allPartitionsRequired        with pruning, also spread every partition not in {@code dataPartitions}
     *                                     over the selected members; requires {@code dataPartitions != null}
     * @param dataPartitions               {@code null} to assign all partitions (no pruning); otherwise only these
     *                                     partitions are assigned to their owners
     * @param routingPartitions            required when {@code dataPartitions != null}; the ones not in
     *                                     {@code dataPartitions} are spread round-robin over the selected members
     * @param extraRequiredMemberAddresses required when {@code dataPartitions != null}; each of these members is
     *                                     included even if it receives no partition
     * @return partition ids per member, in ascending order
     * @throws JetException if an address in {@code extraRequiredMemberAddresses} is not in {@code memberList}
     */
    public static Map<MemberInfo, int[]> getPartitionAssignment(
            NodeEngine nodeEngine,
            List<MemberInfo> memberList,
            boolean allPartitionsRequired,
            @Nullable Set<Integer> dataPartitions,
            @Nullable Set<Integer> routingPartitions,
            @Nullable Set<Address> extraRequiredMemberAddresses
    ) {
        if (allPartitionsRequired) {
            Preconditions.checkNotNull(dataPartitions);
        }
        IPartitionService partitionService = nodeEngine.getPartitionService();
        Map<Address, MemberInfo> membersByAddress = indexByAddress(memberList);
        Map<MemberInfo, FixedCapacityIntArrayList> partitionsForMember = new HashMap<>();
        int partitionCount = partitionService.getPartitionCount();
        // A single round-robin cursor, shared by all round-robin fallbacks in this method.
        int memberIndex = 0;

        if (dataPartitions == null) {
            for (int partitionId = 0; partitionId < partitionCount; partitionId++) {
                MemberInfo member = ownerMember(partitionService, membersByAddress, partitionId);
                if (member == null) {
                    member = memberList.get(memberIndex++ % memberList.size());
                }
                partitionsForMember.computeIfAbsent(member, ignored -> new FixedCapacityIntArrayList(partitionCount))
                        .add(partitionId);
            }
        } else {
            for (int partitionId : dataPartitions) {
                MemberInfo member = ownerMember(partitionService, membersByAddress, partitionId);
                if (member == null) {
                    member = memberList.get(memberIndex++ % memberList.size());
                }
                partitionsForMember.computeIfAbsent(member, ignored -> new FixedCapacityIntArrayList(partitionCount))
                        .add(partitionId);
            }
        }

        if (dataPartitions != null) {
            Set<Address> requiredAddresses = Preconditions.checkNotNull(extraRequiredMemberAddresses);
            Set<Integer> routingPartitionIds = Preconditions.checkNotNull(routingPartitions);
            for (Address requiredMemberAddr : requiredAddresses) {
                MemberInfo requiredMemberInfo = membersByAddress.get(requiredMemberAddr);
                if (requiredMemberInfo == null) {
                    throw new JetException("Member with address " + requiredMemberAddr + " not present in the cluster");
                }
                partitionsForMember.computeIfAbsent(requiredMemberInfo, ignored -> {
                    nodeEngine.getLogger(ExecutionPlanBuilder.class).fine("Adding required member "
                            + requiredMemberAddr + " to partition-pruned job members");
                    return new FixedCapacityIntArrayList(partitionCount);
                });
            }
            if (allPartitionsRequired || !routingPartitionIds.isEmpty()) {
                Set<Integer> partitionsToAssign = allPartitionsRequired
                        ? new HashSet<Integer>(Util.range(0, partitionCount))
                        : new HashSet<Integer>(routingPartitionIds);
                partitionsToAssign.removeAll(dataPartitions);
                // Spread the remaining partitions over the members selected so far. If no member has been
                // selected yet, fall back to all candidates. The modulo below would otherwise divide by zero.
                List<MemberInfo> targetMembers = partitionsForMember.isEmpty()
                        ? memberList
                        : new ArrayList<MemberInfo>(partitionsForMember.keySet());
                for (int partitionId : partitionsToAssign) {
                    MemberInfo member = targetMembers.get(memberIndex++ % targetMembers.size());
                    partitionsForMember.computeIfAbsent(member, ignored -> new FixedCapacityIntArrayList(partitionCount))
                            .add(partitionId);
                }
            }
        }
        // Without pruning, partitions were added in ascending order already.
        return toPartitionArrays(partitionsForMember, dataPartitions != null);
    }

    /**
     * Assigns all partitions to members while capping the number of partitions per member.
     * <p>
     * If no member in {@code memberList} is a lite member, this is the same as
     * {@code getPartitionAssignment(nodeEngine, memberList, false, null, null, null)}. Otherwise each member
     * receives at most {@code ceil(partitionCount / memberList.size())} partitions. A partition goes to its
     * owner while the owner has room. Otherwise it goes round-robin to the next member of {@code memberList}
     * that still has room.
     * <p>
     * NOTE(review): presumably the cap exists because lite members own no partitions and would otherwise
     * get none. TODO confirm.
     *
     * @return partition ids per member; members that received no partition are absent
     */
    public static Map<MemberInfo, int[]> getFairPartitionAssignment(NodeEngine nodeEngine, List<MemberInfo> memberList) {
        if (memberList.stream().noneMatch(MemberInfo::isLiteMember)) {
            return getPartitionAssignment(nodeEngine, memberList, false, null, null, null);
        }
        IPartitionService partitionService = nodeEngine.getPartitionService();
        int partitionCount = partitionService.getPartitionCount();
        Map<Address, MemberInfo> membersByAddress = indexByAddress(memberList);
        Map<MemberInfo, FixedCapacityIntArrayList> partitionsForMember = new HashMap<>();
        Set<MemberInfo> membersAbleToAcceptPartitions = new HashSet<>(memberList);
        // Ceiling division. Total capacity is at least partitionCount, so the search loop below terminates.
        int fairPartitionSliceSize = (partitionCount + memberList.size() - 1) / memberList.size();
        int memberIndex = 0;
        for (int partitionId = 0; partitionId < partitionCount; partitionId++) {
            MemberInfo member = ownerMember(partitionService, membersByAddress, partitionId);
            while (member == null || !membersAbleToAcceptPartitions.contains(member)) {
                member = memberList.get(memberIndex++ % memberList.size());
            }
            FixedCapacityIntArrayList partitions =
                    partitionsForMember.computeIfAbsent(member, ignored -> new FixedCapacityIntArrayList(partitionCount));
            partitions.add(partitionId);
            if (partitions.size() >= fairPartitionSliceSize) {
                membersAbleToAcceptPartitions.remove(member);
            }
        }
        return toPartitionArrays(partitionsForMember, false);
    }

    /** Indexes members by their address. */
    private static Map<Address, MemberInfo> indexByAddress(List<MemberInfo> memberList) {
        Map<Address, MemberInfo> membersByAddress = new HashMap<>();
        for (MemberInfo memberInfo : memberList) {
            membersByAddress.put(memberInfo.getAddress(), memberInfo);
        }
        return membersByAddress;
    }

    /** Returns the member owning {@code partitionId}, or {@code null} if the owner is not in {@code membersByAddress}. */
    @Nullable
    private static MemberInfo ownerMember(
            IPartitionService partitionService, Map<Address, MemberInfo> membersByAddress, int partitionId
    ) {
        return membersByAddress.get(partitionService.getPartitionOwnerOrWait(partitionId));
    }

    /** Converts the per-member partition lists to arrays, sorting each one if {@code sort} is true. */
    private static Map<MemberInfo, int[]> toPartitionArrays(
            Map<MemberInfo, FixedCapacityIntArrayList> partitionsForMember, boolean sort
    ) {
        Map<MemberInfo, int[]> partitionAssignment = new HashMap<>();
        for (Map.Entry<MemberInfo, FixedCapacityIntArrayList> memberWithPartitions : partitionsForMember.entrySet()) {
            int[] partitions = memberWithPartitions.getValue().asArray();
            if (sort) {
                Arrays.sort(partitions);
            }
            partitionAssignment.put(memberWithPartitions.getKey(), partitions);
        }
        return partitionAssignment;
    }

    /** Result of {@link #analyzeDagForPartitionPruning}. */
    static class PartitionPruningAnalysisResult {
        /** Addresses of members that some non-distributed edge targets via {@code distributedTo}. */
        final Set<Address> requiredAddresses;
        /** Partition ids of the constant partitioning keys of {@code PARTITIONED} edges. */
        final Set<Integer> constantPartitionIds;
        /** Whether some {@code PARTITIONED} edge has no constant partitioning key. */
        final boolean allPartitionsRequired;

        PartitionPruningAnalysisResult(Set<Address> requiredAddresses, Set<Integer> constantPartitionIds, boolean allPartitionsRequired) {
            this.requiredAddresses = requiredAddresses;
            this.constantPartitionIds = constantPartitionIds;
            this.allPartitionsRequired = allPartitionsRequired;
        }
    }

    /**
     * Assigns sequential ids to the DAG's vertices in {@code DAG.forEach} order. It also records each vertex
     * id's position in that order.
     */
    private static final class VerticesIdAndOrder implements Iterable<VertexIdPos> {
        private final LinkedHashMap<String, Integer> vertexIdMap;
        private final Map<Integer, Integer> vertexPosById;

        private VerticesIdAndOrder(LinkedHashMap<String, Integer> vertexIdMap) {
            this.vertexIdMap = vertexIdMap;
            this.vertexPosById = new LinkedHashMap<>(vertexIdMap.size());
            int index = 0;
            for (Integer vertexId : vertexIdMap.values()) {
                this.vertexPosById.put(vertexId, index++);
            }
        }

        /** Returns the id of the named vertex, or {@code null} if there is no such vertex. */
        private Integer idByName(String vertexName) {
            return this.vertexIdMap.get(vertexName);
        }

        private static VerticesIdAndOrder assignVertexIds(DAG dag) {
            LinkedHashMap<String, Integer> vertexIdMap = new LinkedHashMap<>();
            int[] vertexId = {0};
            dag.forEach(v -> vertexIdMap.put(v.getName(), vertexId[0]++));
            return new VerticesIdAndOrder(vertexIdMap);
        }

        private int count() {
            return this.vertexIdMap.size();
        }

        @Override
        @Nonnull
        public Iterator<VertexIdPos> iterator() {
            return this.vertexIdMap.entrySet().stream()
                    .map(e -> new VertexIdPos(e.getValue(), e.getKey(), this.vertexPosById.get(e.getValue())))
                    .iterator();
        }
    }

    /** A vertex's id, name and its position in the plan's vertex list. */
    private static final class VertexIdPos {
        private final int vertexId;
        private final String vertexName;
        private final int requiredPosition;

        private VertexIdPos(int vertexId, String vertexName, int position) {
            this.vertexId = vertexId;
            this.vertexName = vertexName;
            this.requiredPosition = position;
        }
    }
}

