/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.io.IOException;
import java.util.List;
import java.util.Optional;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

/**
 * {@link NlpTask.Processor} for the {@code text_similarity} task.
 * <p>
 * Each request pairs the configured {@code text} (from {@link TextSimilarityConfig#getText()}) with exactly one
 * input text. The model output is expected to contain exactly one value per entry of the first result row. Those
 * values are folded into a single score by the configured {@link TextSimilarityConfig.SpanScoreFunction}: either
 * the maximum or the arithmetic mean.
 */
public class TextSimilarityProcessor extends NlpTask.Processor {

    /** Error message used when the supplied config is not a {@link TextSimilarityConfig}. */
    private static final String MISSING_CONFIG_MESSAGE =
        "please provide configuration update for text_similarity task including the desired [text]";

    TextSimilarityProcessor(NlpTokenizer tokenizer) {
        super(tokenizer);
    }

    /**
     * Performs no up-front validation.
     * <p>
     * The "exactly one input" requirement is enforced when the request is built, in
     * {@link RequestBuilder#buildRequest}.
     */
    @Override
    public void validateInputs(List<String> inputs) {
        // intentionally empty: the input count is checked in RequestBuilder#buildRequest
    }

    /**
     * Creates a request builder bound to this processor's tokenizer and the configured {@code text}.
     *
     * @param nlpConfig must be a {@link TextSimilarityConfig}
     * @return a builder that tokenizes the configured text together with the single request input
     * @throws org.elasticsearch.ElasticsearchStatusException (bad request) if {@code nlpConfig} is not a
     *         {@link TextSimilarityConfig}
     */
    @Override
    public NlpTask.RequestBuilder getRequestBuilder(NlpConfig nlpConfig) {
        if (nlpConfig instanceof TextSimilarityConfig textSimilarityConfig) {
            return new RequestBuilder(this.tokenizer, textSimilarityConfig.getText());
        }
        throw ExceptionsHelper.badRequestException(MISSING_CONFIG_MESSAGE);
    }

    /**
     * Creates a result processor configured from the given {@link TextSimilarityConfig}.
     *
     * @param nlpConfig must be a {@link TextSimilarityConfig}
     * @return a processor that folds per-span model output into one similarity score
     * @throws org.elasticsearch.ElasticsearchStatusException (bad request) if {@code nlpConfig} is not a
     *         {@link TextSimilarityConfig}
     */
    @Override
    public NlpTask.ResultProcessor getResultProcessor(NlpConfig nlpConfig) {
        if (nlpConfig instanceof TextSimilarityConfig textSimilarityConfig) {
            return new ResultProcessor(
                textSimilarityConfig.getText(),
                textSimilarityConfig.getResultsField(),
                textSimilarityConfig.getSpanScoreFunction()
            );
        }
        throw ExceptionsHelper.badRequestException(MISSING_CONFIG_MESSAGE);
    }

    /**
     * Maps the configured span score function to a fresh accumulator.
     * <p>
     * A new instance is returned on every call, because the accumulators are stateful.
     *
     * @param spanScoreFunction the configured function; {@code null} causes a {@link NullPointerException}
     * @return a new, empty accumulator
     */
    static SpanScoreFunction fromConfig(TextSimilarityConfig.SpanScoreFunction spanScoreFunction) {
        // Exhaustive enum switch expression: no default needed. The compiler inserts a runtime error for
        // constants added after compilation.
        return switch (spanScoreFunction) {
            case MAX -> new Max();
            case MEAN -> new Mean();
        };
    }

    /**
     * Builds a tokenized request from the configured {@code sequence} plus exactly one input text.
     *
     * @param tokenizer tokenizer used to tokenize the (sequence, input) pair
     * @param sequence  the configured {@code text} from {@link TextSimilarityConfig}
     */
    record RequestBuilder(NlpTokenizer tokenizer, String sequence) implements NlpTask.RequestBuilder {

        /**
         * Tokenizes {@link #sequence()} together with the single entry of {@code inputs}, then builds the request.
         *
         * @throws org.elasticsearch.ElasticsearchStatusException (bad request) if {@code inputs} is empty or has
         *         more than one element
         * @throws IOException propagated from building the request
         */
        @Override
        public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate, int span)
            throws IOException {
            // Guard the empty case explicitly; otherwise inputs.get(0) fails with IndexOutOfBoundsException.
            if (inputs.isEmpty()) {
                throw ExceptionsHelper.badRequestException("Unable to do text_similarity without a text input");
            }
            if (inputs.size() > 1) {
                throw ExceptionsHelper.badRequestException("Unable to do text_similarity on more than one text input at a time");
            }
            String context = inputs.get(0);
            // NOTE(review): the trailing 0 is forwarded to NlpTokenizer#tokenize unchanged. Its meaning (presumably
            // a starting sequence id/offset) is defined by the tokenizer; confirm there.
            List<TokenizationResult.Tokens> tokenizations = tokenizer.tokenize(sequence, context, truncate, span, 0);
            TokenizationResult result = tokenizer.buildTokenizationResult(tokenizations);
            return result.buildRequest(requestId, truncate);
        }
    }

    /**
     * Turns raw model output into a {@link TextSimilarityInferenceResults}.
     * <p>
     * Every entry of the first result row must hold exactly one value. Those values are combined with
     * {@code function}.
     *
     * @param question     the configured text; not read by {@link #processResult}
     * @param resultsField output field name; {@code "predicted_value"} is used when {@code null}
     * @param function     how per-span values are combined into the final score
     */
    record ResultProcessor(String question, String resultsField, TextSimilarityConfig.SpanScoreFunction function)
        implements
            NlpTask.ResultProcessor {

        /**
         * Folds the per-span values of {@code pyTorchResult} into a single score.
         *
         * @throws ElasticsearchStatusException (internal server error) if the result has no rows or no spans, or if
         *         any span does not contain exactly one value
         */
        @Override
        public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) {
            double[][][] inferenceResult = pyTorchResult.getInferenceResult();
            // An empty span list would otherwise yield NaN (MEAN) or -Infinity (MAX) as the "score".
            if (inferenceResult.length < 1 || inferenceResult[0].length < 1) {
                throw new ElasticsearchStatusException("text_similarity result has no data", RestStatus.INTERNAL_SERVER_ERROR);
            }
            SpanScoreFunction spanScoreFunction = TextSimilarityProcessor.fromConfig(function);
            for (double[] spanResult : inferenceResult[0]) {
                if (spanResult.length != 1) {
                    throw new ElasticsearchStatusException(
                        "Expected exactly [1] value in text_similarity result; got [{}]",
                        RestStatus.INTERNAL_SERVER_ERROR,
                        spanResult.length
                    );
                }
                spanScoreFunction.accept(spanResult[0]);
            }
            return new TextSimilarityInferenceResults(
                Optional.ofNullable(resultsField).orElse("predicted_value"),
                spanScoreFunction.score(),
                tokenization.anyTruncated()
            );
        }
    }

    /**
     * Accumulator that keeps the largest value seen.
     * <p>
     * Reports {@link Double#NEGATIVE_INFINITY} if nothing was accepted.
     */
    private static final class Max implements SpanScoreFunction {
        private double score = Double.NEGATIVE_INFINITY;

        @Override
        public void accept(double v) {
            score = Math.max(score, v);
        }

        @Override
        public double score() {
            return score;
        }
    }

    /**
     * Accumulator that reports the arithmetic mean of the values seen.
     * <p>
     * Reports {@code NaN} (0.0 / 0) if nothing was accepted.
     */
    private static final class Mean implements SpanScoreFunction {
        private double sum = 0.0;
        private int count = 0;

        @Override
        public void accept(double v) {
            sum += v;
            ++count;
        }

        @Override
        public double score() {
            return sum / (double) count;
        }
    }

    /** Stateful reducer that folds per-span values into one score. */
    private interface SpanScoreFunction {
        /** Adds one span value to the running aggregate. */
        void accept(double value);

        /** Returns the aggregate of all values accepted so far. */
        double score();
    }
}

