/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.deployment;

import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentNodeService;
import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput;

/**
 * Cancellable task representing one trained model deployment.
 *
 * <p>The node-level work for {@link #stop}, {@link #infer}, {@link #modelStats}, {@link #clearCache}
 * and {@link #setFailed} is delegated to the {@link TrainedModelAssignmentNodeService} supplied at
 * construction. The task also starts (in {@link #init}) and stops (in {@link #markAsStopped})
 * license usage tracking for the deployed model.
 */
public class TrainedModelDeploymentTask extends CancellableTask implements StartTrainedModelDeploymentAction.TaskMatcher {
    private static final Logger logger = LogManager.getLogger(TrainedModelDeploymentTask.class);

    /** Deployment parameters; {@link #updateNumberOfAllocations} replaces this with a new instance. */
    private volatile StartTrainedModelDeploymentAction.TaskParams params;
    private final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService;
    private volatile boolean stopped;
    private volatile boolean failed;
    /** Reason passed to the first {@link #markAsStopped} call; later reasons are ignored. */
    private final SetOnce<String> stoppedReasonHolder = new SetOnce<>();
    /** Inference configuration passed to the first {@link #init} call; null until then. */
    private final SetOnce<InferenceConfig> inferenceConfigHolder = new SetOnce<>();
    private final XPackLicenseState licenseState;
    private final LicensedFeature.Persistent licensedFeature;

    /**
     * Creates the task.
     *
     * @param taskParams deployment parameters; must not be null
     * @param trainedModelAssignmentNodeService service performing the node-level work; must not be null
     * @param licenseState license state passed to the licensed feature's tracking calls
     * @param licensedFeature feature whose usage is tracked between {@link #init} and {@link #markAsStopped}
     * @throws NullPointerException if {@code taskParams} is null
     */
    public TrainedModelDeploymentTask(
        long id,
        String type,
        String action,
        TaskId parentTask,
        Map<String, String> headers,
        StartTrainedModelDeploymentAction.TaskParams taskParams,
        TrainedModelAssignmentNodeService trainedModelAssignmentNodeService,
        XPackLicenseState licenseState,
        LicensedFeature.Persistent licensedFeature
    ) {
        // taskParams is dereferenced while building the description, so the null check has to
        // happen here; a check after super(...) could never be reached with a null value.
        super(
            id,
            type,
            action,
            MlTasks.trainedModelAssignmentTaskDescription(Objects.requireNonNull(taskParams, "taskParams").getDeploymentId()),
            parentTask,
            headers
        );
        this.params = taskParams;
        this.trainedModelAssignmentNodeService = ExceptionsHelper.requireNonNull(
            trainedModelAssignmentNodeService,
            "trainedModelAssignmentNodeService"
        );
        this.licenseState = licenseState;
        this.licensedFeature = licensedFeature;
    }

    /**
     * Records the model's inference configuration. Only the first call stores a value. That same
     * first call also starts license usage tracking for this model; later calls do nothing.
     */
    void init(InferenceConfig inferenceConfig) {
        if (inferenceConfigHolder.trySet(inferenceConfig)) {
            licensedFeature.startTracking(licenseState, licenseContext());
        }
    }

    /**
     * Replaces the deployment parameters with a copy that differs only in the number of allocations.
     *
     * @param numberOfAllocations the new allocation count
     */
    public void updateNumberOfAllocations(int numberOfAllocations) {
        // Read the volatile field once so every copied value comes from the same snapshot.
        StartTrainedModelDeploymentAction.TaskParams current = params;
        ByteSizeValue cacheSize = current.getCacheSize().orElse(null);
        params = new StartTrainedModelDeploymentAction.TaskParams(
            current.getModelId(),
            current.getDeploymentId(),
            current.getModelBytes(),
            numberOfAllocations,
            current.getThreadsPerAllocation(),
            current.getQueueCapacity(),
            cacheSize,
            current.getPriority()
        );
    }

    public String getModelId() {
        return params.getModelId();
    }

    public String getDeploymentId() {
        return params.getDeploymentId();
    }

    /** Returns the memory estimate reported by the current deployment parameters. */
    public long estimateMemoryUsageBytes() {
        return params.estimateMemoryUsageBytes();
    }

    public StartTrainedModelDeploymentAction.TaskParams getParams() {
        return params;
    }

    /**
     * Asks the node service to stop this deployment.
     *
     * @param reason why the deployment is being stopped
     * @param listener passed to {@code TrainedModelAssignmentNodeService#stopDeploymentAndNotify}
     */
    public void stop(String reason, ActionListener<AcknowledgedResponse> listener) {
        trainedModelAssignmentNodeService.stopDeploymentAndNotify(this, reason, listener);
    }

    /**
     * Marks the task as stopped. It stops license tracking, records the first stop reason and sets
     * the stopped flag.
     *
     * <p>NOTE(review): license tracking is stopped on every call, even if {@link #init} never started
     * it. Presumably {@code stopTracking} tolerates that; confirm.
     *
     * @param reason why the task stopped; only the first reason supplied is retained
     */
    public void markAsStopped(String reason) {
        licensedFeature.stopTracking(licenseState, licenseContext());
        logger.debug("[{}] Stopping due to reason [{}]", getDeploymentId(), reason);
        stoppedReasonHolder.trySet(reason);
        stopped = true;
    }

    public boolean isStopped() {
        return stopped;
    }

    /** Returns the reason recorded by the first {@link #markAsStopped} call, or empty if none. */
    public Optional<String> stoppedReason() {
        return Optional.ofNullable(stoppedReasonHolder.get());
    }

    /**
     * Cancellation callback: logs the cancellation reason and stops the deployment with that reason.
     * Failures are logged rather than propagated.
     *
     * <p>NOTE(review): this looks like an override of {@link CancellableTask}'s cancellation hook
     * whose {@code @Override} was lost in decompilation. Confirm and restore the annotation.
     */
    protected void onCancelled() {
        String reason = getReasonCancelled();
        logger.info("[{}] task cancelled due to reason [{}]", getDeploymentId(), reason);
        stop(
            reason,
            ActionListener.wrap(
                acknowledgedResponse -> {},
                e -> logger.error(() -> "[" + getDeploymentId() + "] error stopping the deployment after task cancellation", e)
            )
        );
    }

    /**
     * Runs inference against this deployment.
     *
     * <p>Fails {@code listener} with a conflict status exception if {@link #init} has not been called.
     * Fails it with {@link RestStatus#FORBIDDEN} if {@code update} does not support the configured
     * inference task. Otherwise the update is applied to the stored configuration and the request
     * is delegated to the node service.
     */
    public void infer(
        NlpInferenceInput input,
        InferenceConfigUpdate update,
        boolean skipQueue,
        TimeValue timeout,
        CancellableTask parentActionTask,
        ActionListener<InferenceResults> listener
    ) {
        InferenceConfig config = inferenceConfigHolder.get();
        if (config == null) {
            listener.onFailure(
                ExceptionsHelper.conflictStatusException("Trained model deployment [{}] is not initialized", params.getDeploymentId())
            );
            return;
        }
        if (update.isSupported(config) == false) {
            listener.onFailure(
                new ElasticsearchStatusException(
                    "Trained model [{}] is configured for task [{}] but called with task [{}]",
                    RestStatus.FORBIDDEN,
                    params.getModelId(),
                    config.getName(),
                    update.getName()
                )
            );
            return;
        }
        trainedModelAssignmentNodeService.infer(this, update.apply(config), input, skipQueue, timeout, parentActionTask, listener);
    }

    public Optional<ModelStats> modelStats() {
        return trainedModelAssignmentNodeService.modelStats(this);
    }

    public void clearCache(ActionListener<AcknowledgedResponse> listener) {
        trainedModelAssignmentNodeService.clearCache(this, listener);
    }

    /**
     * Sets the failed flag, then asks the node service to fail this assignment.
     *
     * @param reason why the deployment failed
     */
    public void setFailed(String reason) {
        failed = true;
        trainedModelAssignmentNodeService.failAssignment(this, reason);
    }

    public boolean isFailed() {
        return failed;
    }

    /**
     * Context name used for both start and stop of license usage tracking.
     * NOTE(review): it is keyed by model id, not deployment id, so several deployments of the same
     * model share one context. Confirm this is intended.
     */
    private String licenseContext() {
        return "model-" + params.getModelId();
    }
}

