/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.combination;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;

/**
 * Combines the per-sub-query scores of a hybrid query into a single score per document, rewrites
 * each shard's {@link CompoundTopDocs} with the combined and re-ordered hits, and builds the
 * explanation entries for the combination step.
 */
public class ScoreCombiner {
    @Generated
    private static final Logger log = LogManager.getLogger(ScoreCombiner.class);

    /** Score constant with value {@code 0.0f}; not used inside this class (presumably read by callers). */
    public static final Float MAX_SCORE_WHEN_NO_HITS_FOUND = 0.0f;

    /**
     * Tie breaker handed to {@link TopDocs#merge} when explicit sort criteria are supplied.
     * Orders by score ascending, then by doc id ascending.
     *
     * <p>NOTE(review): for two hits with the same score and the same doc id this returns {@code 1}
     * instead of {@code 0}, which violates the {@link Comparator} symmetry contract. It is only used
     * as a merge tie breaker; a non-zero result presumably keeps Lucene from falling back to its own
     * hit-index tie break when the same document appears in several sub-queries. Confirm against the
     * Lucene version in use before changing it to {@code 0}.
     */
    private static final Comparator<ScoreDoc> SORTING_TIE_BREAKER = (o1, o2) -> {
        int scoreComparison = Float.compare(o1.score, o2.score);
        if (scoreComparison != 0) {
            return scoreComparison;
        }
        int docIdComparison = Integer.compare(o1.doc, o2.doc);
        if (docIdComparison != 0) {
            return docIdComparison;
        }
        return 1;
    };

    /**
     * Combines the sub-query scores of every shard result held by {@code combineScoresDTO}, replacing
     * each result's score docs and total hits in place.
     *
     * @param combineScoresDTO carries the shard results, the combination technique and the optional sort
     */
    public void combineScores(CombineScoresDto combineScoresDTO) {
        ScoreCombinationTechnique scoreCombinationTechnique = combineScoresDTO.getScoreCombinationTechnique();
        Sort sort = combineScoresDTO.getSort();
        combineScoresDTO.getQueryTopDocs()
            .forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort));
    }

    /**
     * Combines the scores of one shard result. A {@code null} result, or one whose total hit count is
     * zero, is left untouched.
     */
    private void combineShardScores(ScoreCombinationTechnique scoreCombinationTechnique, CompoundTopDocs compoundQueryTopDocs, Sort sort) {
        if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0L) {
            return;
        }
        List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
        Map<Integer, float[]> normalizedScoresPerDoc = getNormalizedScoresPerDocument(topDocsPerSubQuery);
        Map<Integer, Float> combinedNormalizedScoresByDocId = combineScoresAndGetCombinedNormalizedScoresPerDocument(
            normalizedScoresPerDoc,
            scoreCombinationTechnique
        );
        Collection<Integer> sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId);
        // Sort values are read from the sub-query top docs before the score docs are replaced below.
        Map<Integer, Object[]> docIdSortFieldMap = getDocIdSortFieldsMap(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sort);
        updateQueryTopDocsWithCombinedScores(
            compoundQueryTopDocs,
            topDocsPerSubQuery,
            combinedNormalizedScoresByDocId,
            sortedDocsIds,
            docIdSortFieldMap,
            sort != null
        );
    }

    /** Returns {@code true} when {@code sort} is non-null and any of its fields sorts by score. */
    private boolean isSortOrderByScore(Sort sort) {
        if (sort == null) {
            return false;
        }
        for (SortField sortField : sort.getSort()) {
            if (SortField.Type.SCORE.equals(sortField.getType())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Casts the non-empty sub-query results to {@link TopFieldDocs}; returns {@code null} when no sort
     * is given. Empty sub-query results are skipped.
     */
    private List<TopFieldDocs> getTopFieldDocs(Sort sort, List<TopDocs> topDocsPerSubQuery) {
        if (sort == null) {
            return null;
        }
        List<TopFieldDocs> topFieldDocs = new ArrayList<>();
        for (TopDocs topDocs : topDocsPerSubQuery) {
            if (topDocs.scoreDocs.length > 0) {
                topFieldDocs.add((TopFieldDocs) topDocs);
            }
        }
        return topFieldDocs;
    }

    /**
     * Maps each doc id to the sort values to attach to its {@link FieldDoc}. Returns {@code null} when
     * no sort is given.
     *
     * <p>The first sub-query that contains a document supplies its values. When the sort includes a
     * score field, the values are replaced by a single-element array holding the combined score.
     */
    private Map<Integer, Object[]> getDocIdSortFieldsMap(CompoundTopDocs compoundTopDocs, Map<Integer, Float> combinedNormalizedScoresByDocId, Sort sort) {
        if (sort == null) {
            return null;
        }
        Map<Integer, Object[]> docIdSortFieldMap = new HashMap<>();
        boolean isSortByScore = isSortOrderByScore(sort);
        for (TopDocs topDocs : compoundTopDocs.getTopDocs()) {
            for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
                FieldDoc fieldDoc = (FieldDoc) scoreDoc;
                if (docIdSortFieldMap.get(fieldDoc.doc) != null) {
                    continue;
                }
                Object[] sortValues = isSortByScore
                    ? new Object[] { combinedNormalizedScoresByDocId.get(fieldDoc.doc) }
                    : fieldDoc.fields;
                docIdSortFieldMap.put(fieldDoc.doc, sortValues);
            }
        }
        return docIdSortFieldMap;
    }

    /** Returns all doc ids ordered by combined score, highest first. */
    private List<Integer> getSortedDocIds(Map<Integer, Float> combinedNormalizedScoresByDocId) {
        List<Integer> sortedDocsIds = new ArrayList<>(combinedNormalizedScoresByDocId.keySet());
        sortedDocsIds.sort((a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a)));
        return sortedDocsIds;
    }

    /**
     * Merges the sub-query results with {@link TopDocs#merge} using {@code sort}. Returns the distinct
     * doc ids in merged order; the first occurrence of a doc determines its position.
     *
     * @throws IllegalArgumentException if {@code topFieldDocs} is {@code null}
     */
    private Set<Integer> getSortedDocIdsBySortCriteria(List<TopFieldDocs> topFieldDocs, Sort sort) {
        if (Objects.isNull(topFieldDocs)) {
            throw new IllegalArgumentException("topFieldDocs cannot be null when sorting is enabled.");
        }
        int size = 0;
        for (TopFieldDocs topFieldDoc : topFieldDocs) {
            size += topFieldDoc.scoreDocs.length;
        }
        TopFieldDocs sortedTopDocs = TopDocs.merge(sort, 0, size, topFieldDocs.toArray(new TopFieldDocs[0]), SORTING_TIE_BREAKER);
        Set<Integer> uniqueDocIds = new LinkedHashSet<>();
        for (ScoreDoc scoreDoc : sortedTopDocs.scoreDocs) {
            uniqueDocIds.add(scoreDoc.doc);
        }
        return uniqueDocIds;
    }

    /**
     * Builds at most {@code maxHits} combined score docs in the order of {@code sortedScores}. The
     * shard index is taken from the first existing score doc, or {@code -1} when there is none.
     */
    private List<ScoreDoc> getCombinedScoreDocs(CompoundTopDocs compoundQueryTopDocs, Map<Integer, Float> combinedNormalizedScoresByDocId, Collection<Integer> sortedScores, long maxHits, Map<Integer, Object[]> docIdSortFieldMap, boolean isSortingEnabled) {
        int shardId = -1;
        if (!compoundQueryTopDocs.getScoreDocs().isEmpty()) {
            shardId = compoundQueryTopDocs.getScoreDocs().get(0).shardIndex;
        }
        List<ScoreDoc> scoreDocs = new ArrayList<>();
        for (Integer docId : sortedScores) {
            if (scoreDocs.size() >= maxHits) {
                break;
            }
            scoreDocs.add(getScoreDoc(isSortingEnabled, docId, shardId, combinedNormalizedScoresByDocId, docIdSortFieldMap));
        }
        return scoreDocs;
    }

    /**
     * Creates a {@link FieldDoc} carrying the sort values when sorting is enabled and a sort-value map
     * exists; otherwise creates a plain {@link ScoreDoc}. Both carry the combined score.
     */
    private ScoreDoc getScoreDoc(boolean isSortEnabled, int docId, int shardId, Map<Integer, Float> combinedNormalizedScoresByDocId, Map<Integer, Object[]> docIdSortFieldMap) {
        float combinedScore = combinedNormalizedScoresByDocId.get(docId);
        if (isSortEnabled && docIdSortFieldMap != null) {
            return new FieldDoc(docId, combinedScore, docIdSortFieldMap.get(docId), shardId);
        }
        return new ScoreDoc(docId, combinedScore, shardId);
    }

    /**
     * Collects, per doc id, an array holding the document's score in each sub-query (indexed by
     * sub-query position). A sub-query that did not return the doc leaves {@code 0.0f} in its slot.
     *
     * @param topDocsPerSubQuery one {@link TopDocs} per sub-query
     * @return doc id to per-sub-query score array
     */
    public Map<Integer, float[]> getNormalizedScoresPerDocument(List<TopDocs> topDocsPerSubQuery) {
        Map<Integer, float[]> normalizedScoresPerDoc = new HashMap<>();
        int numberOfSubQueries = topDocsPerSubQuery.size();
        for (int subQueryIndex = 0; subQueryIndex < numberOfSubQueries; subQueryIndex++) {
            for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(subQueryIndex).scoreDocs) {
                normalizedScoresPerDoc.computeIfAbsent(scoreDoc.doc, docId -> new float[numberOfSubQueries])[subQueryIndex] = scoreDoc.score;
            }
        }
        return normalizedScoresPerDoc;
    }

    /** Applies {@code scoreCombinationTechnique} to each document's per-sub-query scores. */
    private Map<Integer, Float> combineScoresAndGetCombinedNormalizedScoresPerDocument(Map<Integer, float[]> normalizedScoresPerDocument, ScoreCombinationTechnique scoreCombinationTechnique) {
        return normalizedScoresPerDocument.entrySet()
            .stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue())));
    }

    /**
     * Replaces the shard result's score docs with the combined ones. Total hits keep the original
     * value; the relation is recomputed from the sub-queries.
     */
    private void updateQueryTopDocsWithCombinedScores(CompoundTopDocs compoundQueryTopDocs, List<TopDocs> topDocsPerSubQuery, Map<Integer, Float> combinedNormalizedScoresByDocId, Collection<Integer> sortedScores, Map<Integer, Object[]> docIdSortFieldMap, boolean isSortingEnabled) {
        long maxHits = compoundQueryTopDocs.getTotalHits().value;
        compoundQueryTopDocs.setScoreDocs(
            getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits, docIdSortFieldMap, isSortingEnabled)
        );
        compoundQueryTopDocs.setTotalHits(getTotalHits(topDocsPerSubQuery, maxHits));
    }

    /**
     * Returns {@code maxHits} with relation {@code GREATER_THAN_OR_EQUAL_TO} if any sub-query reported
     * that relation, otherwise {@code EQUAL_TO}.
     */
    private TotalHits getTotalHits(List<TopDocs> topDocsPerSubQuery, long maxHits) {
        boolean anyLowerBound = topDocsPerSubQuery.stream()
            .anyMatch(topDocs -> topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
        TotalHits.Relation relation = anyLowerBound ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO : TotalHits.Relation.EQUAL_TO;
        return new TotalHits(maxHits, relation);
    }

    /**
     * Builds combination explanations for every shard result, keyed by shard. The first result seen
     * for a shard wins. {@code null} shard results are skipped, matching {@link #combineScores}.
     *
     * @param queryTopDocs shard results to explain
     * @param combinationTechnique technique used to combine scores; must also implement {@link ExplainableTechnique}
     * @param sort optional sort criteria, {@code null} to order by combined score
     * @return explanation details per shard
     */
    public Map<SearchShard, List<ExplanationDetails>> explain(List<CompoundTopDocs> queryTopDocs, ScoreCombinationTechnique combinationTechnique, Sort sort) {
        Map<SearchShard, List<ExplanationDetails>> explanations = new HashMap<>();
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                // Previously this dereferenced null via getSearchShard() and threw NPE.
                continue;
            }
            explanations.putIfAbsent(compoundQueryTopDocs.getSearchShard(), explainByShard(combinationTechnique, compoundQueryTopDocs, sort));
        }
        return explanations;
    }

    /**
     * Explains the combined score of each document of one shard result, in the same order used by
     * {@link #combineScores}. Returns an empty list for a {@code null} or zero-hit result.
     */
    private List<ExplanationDetails> explainByShard(ScoreCombinationTechnique scoreCombinationTechnique, CompoundTopDocs compoundQueryTopDocs, Sort sort) {
        if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0L) {
            return List.of();
        }
        Map<Integer, float[]> normalizedScoresPerDoc = getNormalizedScoresPerDocument(compoundQueryTopDocs.getTopDocs());
        Map<Integer, Float> combinedNormalizedScoresByDocId = combineScoresAndGetCombinedNormalizedScoresPerDocument(
            normalizedScoresPerDoc,
            scoreCombinationTechnique
        );
        Collection<Integer> sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId);
        List<ExplanationDetails> listOfExplanations = new ArrayList<>();
        String combinationDescription = String.format(
            Locale.ROOT,
            "%s combination of:",
            ((ExplainableTechnique) scoreCombinationTechnique).describe()
        );
        for (int docId : sortedDocsIds) {
            listOfExplanations.add(
                new ExplanationDetails(docId, List.of(Pair.of(combinedNormalizedScoresByDocId.get(docId), combinationDescription)))
            );
        }
        return listOfExplanations;
    }

    /**
     * Orders the shard's doc ids by {@code sort} when given, otherwise by combined score
     * (highest first).
     */
    private Collection<Integer> getSortedDocsIds(CompoundTopDocs compoundQueryTopDocs, Sort sort, Map<Integer, Float> combinedNormalizedScoresByDocId) {
        return sort != null
            ? getSortedDocIdsBySortCriteria(getTopFieldDocs(sort, compoundQueryTopDocs.getTopDocs()), sort)
            : getSortedDocIds(combinedNormalizedScoresByDocId);
    }
}

