/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ad.ratelimit;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.action.get.MultiGetItemResponse;
import org.opensearch.action.get.MultiGetRequest;
import org.opensearch.action.get.MultiGetResponse;
import org.opensearch.ad.NodeStateManager;
import org.opensearch.ad.breaker.ADCircuitBreakerService;
import org.opensearch.ad.caching.CacheProvider;
import org.opensearch.ad.common.exception.EndRunException;
import org.opensearch.ad.indices.ADIndex;
import org.opensearch.ad.indices.AnomalyDetectionIndices;
import org.opensearch.ad.ml.CheckpointDao;
import org.opensearch.ad.ml.EntityModel;
import org.opensearch.ad.ml.ModelManager;
import org.opensearch.ad.ml.ModelState;
import org.opensearch.ad.ml.ThresholdingResult;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.AnomalyResult;
import org.opensearch.ad.model.Entity;
import org.opensearch.ad.ratelimit.BatchWorker;
import org.opensearch.ad.ratelimit.CheckpointWriteWorker;
import org.opensearch.ad.ratelimit.EntityColdStartWorker;
import org.opensearch.ad.ratelimit.EntityFeatureRequest;
import org.opensearch.ad.ratelimit.EntityRequest;
import org.opensearch.ad.ratelimit.RequestPriority;
import org.opensearch.ad.ratelimit.ResultWriteRequest;
import org.opensearch.ad.ratelimit.ResultWriteWorker;
import org.opensearch.ad.settings.AnomalyDetectorSettings;
import org.opensearch.ad.stats.ADStats;
import org.opensearch.ad.stats.StatNames;
import org.opensearch.ad.util.ExceptionUtil;
import org.opensearch.ad.util.ParseUtils;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.threadpool.ThreadPool;

/**
 * Batch worker that reads entity model checkpoints from the checkpoint index and restores the models.
 *
 * <p>Each batch of {@link EntityFeatureRequest}s is turned into a single {@link MultiGetRequest}.
 * The outcome of each item decides where the originating request goes next:
 * <ul>
 *   <li>checkpoint found: the model state is restored, the request's current feature is scored, a
 *       result is queued for writing when the RCF score is positive, and the model is offered to the
 *       cache (a low-priority checkpoint write is queued if the cache does not host it);</li>
 *   <li>checkpoint missing, checkpoint index missing, or model scoring reports an
 *       {@link IllegalArgumentException} (treated as likely corruption): the request is sent to the
 *       entity cold start queue;</li>
 *   <li>retryable read failure: the request is put back on this worker;</li>
 *   <li>overload failure: the worker enters its cool-down;</li>
 *   <li>any other read failure: an {@link EndRunException} is recorded for the detector.</li>
 * </ul>
 */
public class CheckpointReadWorker extends BatchWorker<EntityFeatureRequest, MultiGetRequest, MultiGetResponse> {
    private static final Logger LOG = LogManager.getLogger(CheckpointReadWorker.class);
    public static final String WORKER_NAME = "checkpoint-read";

    /**
     * Index that entity checkpoints are read from.
     * NOTE(review): presumably mirrors the project-wide checkpoint index name constant; keep in sync.
     */
    private static final String CHECKPOINT_INDEX_NAME = ".opendistro-anomaly-checkpoints";

    private final ModelManager modelManager;
    private final CheckpointDao checkpointDao;
    private final EntityColdStartWorker entityColdStartQueue;
    private final ResultWriteWorker resultWriteQueue;
    private final AnomalyDetectionIndices indexUtil;
    private final CacheProvider cacheProvider;
    private final CheckpointWriteWorker checkpointWriteQueue;
    private final ADStats adStats;

    /**
     * Creates the worker. Queue sizing, pruning, timing and state parameters are forwarded to
     * {@link BatchWorker}; concurrency and batch size are taken from
     * {@link AnomalyDetectorSettings#CHECKPOINT_READ_QUEUE_CONCURRENCY} and
     * {@link AnomalyDetectorSettings#CHECKPOINT_READ_QUEUE_BATCH_SIZE}.
     */
    public CheckpointReadWorker(
        long heapSizeInBytes,
        int singleRequestSizeInBytes,
        Setting<Float> maxHeapPercentForQueueSetting,
        ClusterService clusterService,
        Random random,
        ADCircuitBreakerService adCircuitBreakerService,
        ThreadPool threadPool,
        Settings settings,
        float maxQueuedTaskRatio,
        Clock clock,
        float mediumSegmentPruneRatio,
        float lowSegmentPruneRatio,
        int maintenanceFreqConstant,
        Duration executionTtl,
        ModelManager modelManager,
        CheckpointDao checkpointDao,
        EntityColdStartWorker entityColdStartQueue,
        ResultWriteWorker resultWriteQueue,
        NodeStateManager stateManager,
        AnomalyDetectionIndices indexUtil,
        CacheProvider cacheProvider,
        Duration stateTtl,
        CheckpointWriteWorker checkpointWriteQueue,
        ADStats adStats
    ) {
        super(
            WORKER_NAME,
            heapSizeInBytes,
            singleRequestSizeInBytes,
            maxHeapPercentForQueueSetting,
            clusterService,
            random,
            adCircuitBreakerService,
            threadPool,
            settings,
            maxQueuedTaskRatio,
            clock,
            mediumSegmentPruneRatio,
            lowSegmentPruneRatio,
            maintenanceFreqConstant,
            AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_CONCURRENCY,
            executionTtl,
            AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_BATCH_SIZE,
            stateTtl,
            stateManager
        );
        this.modelManager = modelManager;
        this.checkpointDao = checkpointDao;
        this.entityColdStartQueue = entityColdStartQueue;
        this.resultWriteQueue = resultWriteQueue;
        this.indexUtil = indexUtil;
        this.cacheProvider = cacheProvider;
        this.checkpointWriteQueue = checkpointWriteQueue;
        this.adStats = adStats;
    }

    /** Issues the batched checkpoint read through {@link CheckpointDao#batchRead}. */
    @Override
    protected void executeBatchRequest(MultiGetRequest request, ActionListener<MultiGetResponse> listener) {
        checkpointDao.batchRead(request, listener);
    }

    /**
     * Builds one multi-get with an item per request that has a model id; requests without a model id
     * are skipped.
     */
    @Override
    protected MultiGetRequest toBatchRequest(List<EntityFeatureRequest> toProcess) {
        MultiGetRequest multiGetRequest = new MultiGetRequest();
        for (EntityFeatureRequest entityRequest : toProcess) {
            entityRequest.getModelId()
                .ifPresent(modelId -> multiGetRequest.add(new MultiGetRequest.Item(CHECKPOINT_INDEX_NAME, modelId)));
        }
        return multiGetRequest;
    }

    /**
     * Returns the listener that dispatches a finished batch read. Any exception thrown while handling
     * the response is routed to the failure handler by {@link ActionListener#wrap}.
     */
    @Override
    protected ActionListener<MultiGetResponse> getResponseListener(List<EntityFeatureRequest> toProcess, MultiGetRequest batchRequest) {
        return ActionListener.wrap(response -> onBatchResponse(toProcess, response), exception -> onBatchFailure(toProcess, exception));
    }

    /**
     * Classifies every item of a multi-get response and routes the originating requests.
     *
     * @param toProcess  the requests that produced this batch
     * @param response   the multi-get response for the batch
     */
    private void onBatchResponse(List<EntityFeatureRequest> toProcess, MultiGetResponse response) {
        Map<String, MultiGetItemResponse> successfulRequests = new HashMap<>();
        Set<String> retryableRequests = new HashSet<>();
        Set<String> notFoundModels = new HashSet<>();
        Map<String, Exception> stopDetectorRequests = new HashMap<>();
        // Log only the first unexpected failure per batch to avoid flooding the log.
        boolean printedUnexpectedFailure = false;

        for (MultiGetItemResponse itemResponse : response.getResponses()) {
            String modelId = itemResponse.getId();
            if (itemResponse.isFailed()) {
                Exception failure = itemResponse.getFailure().getFailure();
                if (failure instanceof IndexNotFoundException) {
                    // Without the checkpoint index every read in the batch fails the same way:
                    // send all requests to cold start and stop processing this response.
                    for (EntityFeatureRequest origRequest : toProcess) {
                        entityColdStartQueue.put(origRequest);
                    }
                    return;
                } else if (ExceptionUtil.isRetryAble(failure)) {
                    retryableRequests.add(modelId);
                } else if (ExceptionUtil.isOverloaded(failure)) {
                    LOG.error("too many get AD model checkpoint requests or shard not available");
                    setCoolDownStart();
                } else {
                    if (!printedUnexpectedFailure) {
                        LOG.error("Unexpected failure", failure);
                        printedUnexpectedFailure = true;
                    }
                    stopDetectorRequests.put(modelId, failure);
                }
            } else if (!itemResponse.getResponse().isExists()) {
                notFoundModels.add(modelId);
            } else {
                successfulRequests.put(modelId, itemResponse);
            }
        }

        // No stored checkpoint: these models have to start from scratch.
        if (!notFoundModels.isEmpty()) {
            for (EntityFeatureRequest origRequest : toProcess) {
                Optional<String> modelId = origRequest.getModelId();
                if (modelId.isPresent() && notFoundModels.contains(modelId.get())) {
                    entityColdStartQueue.put(origRequest);
                }
            }
        }

        // Non-retryable, unclassified failures end the detector's run.
        if (!stopDetectorRequests.isEmpty()) {
            for (EntityFeatureRequest origRequest : toProcess) {
                Optional<String> modelId = origRequest.getModelId();
                if (modelId.isPresent() && stopDetectorRequests.containsKey(modelId.get())) {
                    String adID = origRequest.getDetectorId();
                    nodeStateManager
                        .setException(adID, new EndRunException(adID, "We might have bugs.", stopDetectorRequests.get(modelId.get()), false));
                }
            }
        }

        if (successfulRequests.isEmpty() && retryableRequests.isEmpty()) {
            return;
        }
        processCheckpointIteration(0, toProcess, successfulRequests, retryableRequests);
    }

    /**
     * Handles a failure of the whole batch read: cool down on overload, re-queue everything on a
     * retryable error, otherwise log.
     */
    private void onBatchFailure(List<EntityFeatureRequest> toProcess, Exception exception) {
        if (ExceptionUtil.isOverloaded(exception)) {
            LOG.error("too many get AD model checkpoint requests or shard not available");
            setCoolDownStart();
        } else if (ExceptionUtil.isRetryAble(exception)) {
            putAll(toProcess);
        } else {
            LOG.error("Fail to restore models", exception);
        }
    }

    /**
     * Processes {@code toProcess} one element at a time, starting at index {@code i}.
     *
     * <p>When element {@code i} has a usable checkpoint, the detector config is fetched asynchronously
     * and the callback continues with {@code i + 1}. In every other case (no model id, checkpoint not
     * usable, retryable failure, or nothing to do) the {@code finally} block advances synchronously.
     * Consequently at most one detector lookup from this batch is in flight at a time.
     *
     * @param i                   index of the element to process
     * @param toProcess           the batch's requests
     * @param successfulRequests  model id to successful multi-get item
     * @param retryableRequests   model ids whose read failed with a retryable error; may be null
     */
    private void processCheckpointIteration(int i, List<EntityFeatureRequest> toProcess, Map<String, MultiGetItemResponse> successfulRequests, Set<String> retryableRequests) {
        if (i >= toProcess.size()) {
            return;
        }
        // Set once the detector-lookup callback has taken over responsibility for element i + 1.
        boolean processNextInCallBack = false;
        try {
            EntityFeatureRequest origRequest = toProcess.get(i);
            Optional<String> modelIdOptional = origRequest.getModelId();
            if (!modelIdOptional.isPresent()) {
                return;
            }
            String detectorId = origRequest.getDetectorId();
            Entity entity = origRequest.getEntity();
            String modelId = modelIdOptional.get();

            MultiGetItemResponse checkpointResponse = successfulRequests.get(modelId);
            if (checkpointResponse != null) {
                Optional<Map.Entry<EntityModel, Instant>> checkpoint = checkpointDao.processGetResponse(checkpointResponse.getResponse(), modelId);
                if (!checkpoint.isPresent()) {
                    // The DAO could not turn the stored document into a checkpoint; skip it.
                    return;
                }
                nodeStateManager
                    .getAnomalyDetector(
                        detectorId,
                        onGetDetector(origRequest, i, detectorId, toProcess, successfulRequests, retryableRequests, checkpoint, entity, modelId)
                    );
                processNextInCallBack = true;
            } else if (retryableRequests != null && retryableRequests.contains(modelId)) {
                super.put(origRequest);
            }
        } finally {
            if (!processNextInCallBack) {
                processCheckpointIteration(i + 1, toProcess, successfulRequests, retryableRequests);
            }
        }
    }

    /**
     * Builds the callback that, once the detector config is known, restores the model from its
     * checkpoint, scores the current feature, and then continues with element {@code index + 1}.
     *
     * <p>An {@link IllegalArgumentException} from scoring is treated as likely model corruption: the
     * corruption stat is incremented, the checkpoint is deleted (best effort, outcome only logged) and
     * the request is sent to cold start.
     */
    private ActionListener<Optional<AnomalyDetector>> onGetDetector(EntityFeatureRequest origRequest, int index, String detectorId, List<EntityFeatureRequest> toProcess, Map<String, MultiGetItemResponse> successfulRequests, Set<String> retryableRequests, Optional<Map.Entry<EntityModel, Instant>> checkpoint, Entity entity, String modelId) {
        return ActionListener.wrap(detectorOptional -> {
            if (!detectorOptional.isPresent()) {
                LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId));
                processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests);
                return;
            }
            AnomalyDetector detector = detectorOptional.get();
            ModelState<EntityModel> modelState = modelManager.processEntityCheckpoint(checkpoint, entity, modelId, detectorId, detector.getShingleSize());

            ThresholdingResult result;
            try {
                result = modelManager.getAnomalyResultForEntity(origRequest.getCurrentFeature(), modelState, modelId, entity, detector.getShingleSize());
            } catch (IllegalArgumentException e) {
                // modelId is the unwrapped value of origRequest.getModelId(); log it directly rather
                // than the Optional wrapper.
                LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e);
                adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).increment();
                checkpointDao
                    .deleteModelCheckpoint(
                        modelId,
                        ActionListener.wrap(
                            r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", modelId)),
                            ex -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", modelId), ex)
                        )
                    );
                entityColdStartQueue.put(origRequest);
                processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests);
                return;
            }

            if (result != null && result.getRcfScore() > 0.0) {
                AnomalyResult resultToSave = result
                    .toAnomalyResult(
                        detector,
                        Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()),
                        Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + detector.getDetectorIntervalInMilliseconds()),
                        Instant.now(),
                        Instant.now(),
                        ParseUtils.getFeatureData(origRequest.getCurrentFeature(), detector),
                        entity,
                        indexUtil.getSchemaVersion(ADIndex.RESULT),
                        modelId,
                        null,
                        null
                    );
                // Results with a positive grade are written with higher priority.
                RequestPriority priority = result.getGrade() > 0.0 ? RequestPriority.HIGH : RequestPriority.MEDIUM;
                resultWriteQueue
                    .put(new ResultWriteRequest(origRequest.getExpirationEpochMs(), detectorId, priority, resultToSave, detector.getResultIndex()));
            }

            // If the cache does not host the model, queue its state for a low-priority checkpoint write.
            if (!cacheProvider.get().hostIfPossible(detector, modelState)) {
                checkpointWriteQueue.write(modelState, true, RequestPriority.LOW);
            }
            processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests);
        }, exception -> {
            LOG.error(new ParameterizedMessage("fail to get checkpoint [{}]", modelId), exception);
            nodeStateManager.setException(detectorId, exception);
            processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests);
        });
    }
}

