/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.deployment;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.deployment.AbstractPyTorchAction;
import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;

class InferencePyTorchAction
extends AbstractPyTorchAction<InferenceResults> {
    private static final Logger logger = LogManager.getLogger(InferencePyTorchAction.class);
    private final InferenceConfig config;
    private final NlpInferenceInput input;
    @Nullable
    private final CancellableTask parentActionTask;

    InferencePyTorchAction(String deploymentId, long requestId, TimeValue timeout, DeploymentManager.ProcessContext processContext, InferenceConfig config, NlpInferenceInput input, ThreadPool threadPool, @Nullable CancellableTask parentActionTask, ActionListener<InferenceResults> listener) {
        super(deploymentId, requestId, timeout, processContext, threadPool, listener);
        this.config = config;
        this.input = input;
        this.parentActionTask = parentActionTask;
    }

    private boolean isCancelled() {
        if (this.parentActionTask != null) {
            try {
                this.parentActionTask.ensureNotCancelled();
            }
            catch (TaskCancelledException ex) {
                logger.warn(() -> Strings.format((String)"[%s] %s", (Object[])new Object[]{this.getDeploymentId(), ex.getMessage()}));
                return true;
            }
        }
        return false;
    }

    protected void doRun() throws Exception {
        if (this.isNotified()) {
            logger.debug(() -> Strings.format((String)"[%s] skipping inference on request [%s] as it has timed out", (Object[])new Object[]{this.getDeploymentId(), this.getRequestId()}));
            return;
        }
        if (this.isCancelled()) {
            this.onFailure("inference task cancelled");
            return;
        }
        String requestIdStr = String.valueOf(this.getRequestId());
        try {
            List<String> text = Collections.singletonList(this.input.extractInput((TrainedModelInput)this.getProcessContext().getModelInput().get()));
            NlpTask.Processor processor = (NlpTask.Processor)this.getProcessContext().getNlpTaskProcessor().get();
            processor.validateInputs(text);
            assert (this.config instanceof NlpConfig);
            NlpConfig nlpConfig = (NlpConfig)this.config;
            NlpTask.Request request = processor.getRequestBuilder(nlpConfig).buildRequest(text, requestIdStr, nlpConfig.getTokenization().getTruncate(), nlpConfig.getTokenization().getSpan());
            logger.debug(() -> Strings.format((String)"handling request [%s]", (Object[])new Object[]{requestIdStr}));
            if (this.isCancelled()) {
                this.onFailure("inference task cancelled");
                return;
            }
            this.getProcessContext().getResultProcessor().registerRequest(requestIdStr, (ActionListener<PyTorchResult>)ActionListener.wrap(result -> this.processResult((PyTorchResult)result, request.tokenization(), processor.getResultProcessor(nlpConfig)), this::onFailure));
            ((PyTorchProcess)this.getProcessContext().getProcess().get()).writeInferenceRequest(request.processInput());
        }
        catch (IOException e) {
            logger.error(() -> "[" + this.getDeploymentId() + "] error writing to inference process", (Throwable)e);
            this.onFailure((Exception)((Object)ExceptionsHelper.serverError((String)"Error writing to inference process", (Throwable)e)));
        }
        catch (Exception e) {
            logger.error(() -> "[" + this.getDeploymentId() + "] error running inference", (Throwable)e);
            this.onFailure(e);
        }
    }

    private void processResult(PyTorchResult pyTorchResult, TokenizationResult tokenization, NlpTask.ResultProcessor inferenceResultsProcessor) {
        if (pyTorchResult.isError()) {
            this.onFailure(pyTorchResult.errorResult().error());
            return;
        }
        logger.debug(() -> Strings.format((String)"[%s] retrieved result for request [%s]", (Object[])new Object[]{this.getDeploymentId(), this.getRequestId()}));
        if (this.isNotified()) {
            logger.debug(() -> Strings.format((String)"[%s] skipping result processing for request [%s] as the request has timed out", (Object[])new Object[]{this.getDeploymentId(), this.getRequestId()}));
            return;
        }
        if (this.isCancelled()) {
            this.onFailure("inference task cancelled");
            return;
        }
        InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult.inferenceResult());
        logger.debug(() -> Strings.format((String)"[%s] processed result for request [%s]", (Object[])new Object[]{this.getDeploymentId(), this.getRequestId()}));
        this.onSuccess(results);
    }

    @Override
    protected Logger getLogger() {
        return logger;
    }
}

