/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.rescorer;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.elasticsearch.index.query.ParsedQuery;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.rescore.Rescorer;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.rescorer.FeatureExtractor;
import org.elasticsearch.xpack.ml.inference.rescorer.FieldValueFeatureExtractor;
import org.elasticsearch.xpack.ml.inference.rescorer.QueryFeatureExtractor;

/**
 * Rescore context that carries a locally loaded model, its learn-to-rank configuration and the
 * search execution context, and turns them into the feature extractors and parsed queries
 * needed when rescoring.
 */
public class InferenceRescorerContext extends RescoreContext {
    final SearchExecutionContext executionContext;
    final LocalModel inferenceDefinition;
    final LearnToRankConfig inferenceConfig;

    /**
     * @param windowSize          forwarded unchanged to {@link RescoreContext}
     * @param rescorer            forwarded unchanged to {@link RescoreContext}
     * @param inferenceConfig     learn-to-rank configuration supplying the feature extractor builders
     * @param inferenceDefinition local model whose input fields drive field-value extraction
     * @param executionContext    context used to convert extractor queries into parsed queries
     */
    public InferenceRescorerContext(
        int windowSize,
        Rescorer rescorer,
        LearnToRankConfig inferenceConfig,
        LocalModel inferenceDefinition,
        SearchExecutionContext executionContext
    ) {
        super(windowSize, rescorer);
        this.inferenceConfig = inferenceConfig;
        this.inferenceDefinition = inferenceDefinition;
        this.executionContext = executionContext;
    }

    /**
     * Assembles the feature extractors for this context.
     *
     * <p>A {@link FieldValueFeatureExtractor} is added when the model declares at least one input
     * field. Every {@link QueryExtractorBuilder} in the configuration is converted to a query,
     * rewritten against {@code searcher} and turned into a {@link Weight}; if any such weights
     * exist they are wrapped, together with their feature names, in one
     * {@link QueryFeatureExtractor}.
     *
     * @param searcher searcher used to rewrite each query and create its weight
     * @return a new mutable list of extractors; empty when neither kind applies
     * @throws IOException declared for the query rewrite / weight creation calls
     */
    List<FeatureExtractor> buildFeatureExtractors(IndexSearcher searcher) throws IOException {
        assert this.inferenceDefinition != null && this.inferenceConfig != null;

        List<FeatureExtractor> extractors = new ArrayList<>();
        if (this.inferenceDefinition.inputFields().isEmpty() == false) {
            // Copy the input fields so the extractor gets its own list instance.
            List<String> fieldNames = new ArrayList<>(this.inferenceDefinition.inputFields());
            extractors.add(new FieldValueFeatureExtractor(fieldNames, this.executionContext));
        }

        // The two lists are filled in lock-step: names.get(i) labels queryWeights.get(i).
        List<Weight> queryWeights = new ArrayList<>();
        List<String> names = new ArrayList<>();
        for (LearnToRankFeatureExtractorBuilder builder : this.inferenceConfig.getFeatureExtractorBuilders()) {
            if (builder instanceof QueryExtractorBuilder) {
                QueryExtractorBuilder queryBuilder = (QueryExtractorBuilder) builder;
                Query luceneQuery = toParsedQuery(queryBuilder).query();
                Query rewritten = searcher.rewrite(luceneQuery);
                queryWeights.add(rewritten.createWeight(searcher, ScoreMode.COMPLETE, 1.0f));
                names.add(queryBuilder.featureName());
            }
        }
        if (queryWeights.isEmpty() == false) {
            extractors.add(new QueryFeatureExtractor(names, queryWeights));
        }
        return extractors;
    }

    /**
     * Returns the parsed form of every query-based feature extractor in the configuration.
     *
     * @return an immutable empty list when no configuration is set; otherwise a new mutable list
     *         with one {@link ParsedQuery} per {@link QueryExtractorBuilder}, in configuration order
     */
    public List<ParsedQuery> getParsedQueries() {
        if (this.inferenceConfig == null) {
            return List.of();
        }
        List<ParsedQuery> result = new ArrayList<>();
        for (LearnToRankFeatureExtractorBuilder builder : this.inferenceConfig.getFeatureExtractorBuilders()) {
            if (builder instanceof QueryExtractorBuilder) {
                result.add(toParsedQuery((QueryExtractorBuilder) builder));
            }
        }
        return result;
    }

    /** Converts the query held by {@code builder} into a {@link ParsedQuery} via the execution context. */
    private ParsedQuery toParsedQuery(QueryExtractorBuilder builder) {
        return this.executionContext.toQuery(builder.query().getParsedQuery());
    }
}

