/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.ml;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.neuralsearch.util.RetryUtil;

public class MLCommonsClientAccessor {
    @Generated
    private static final Logger log = LogManager.getLogger(MLCommonsClientAccessor.class);
    /** Default response filter: keep only the {@code sentence_embedding} tensor of the model output. */
    private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
    private static final String EXCEPTION_MESSAGE_MODEL_PREDICT_FAILED = "failed while calling model, check error log for details";
    private static final String EXCEPTION_MESSAGE_PREFIX_MODEL_PREDICT_FAILED = "encountered following error while calling a model";
    /** Key under which a model may report an error in a tensor's data map when it produced no data. */
    private static final String MODEL_ERROR_MESSAGE_KEY = "message";
    /** Key of the text entry in the multimodal input map. */
    private static final String INPUT_TEXT_KEY = "inputText";
    /** Key of the (optional) image entry in the multimodal input map. */
    private static final String INPUT_IMAGE_KEY = "inputImage";

    private final MachineLearningNodeClient mlClient;

    /**
     * Runs a text-embedding prediction for a single sentence and returns its vector.
     *
     * @param modelId   id of the model to call
     * @param inputText the sentence to embed
     * @param listener  receives the single embedding vector; fails with {@link IllegalStateException}
     *                  if the model does not produce exactly one vector
     */
    public void inferenceSentence(@NonNull String modelId, @NonNull String inputText, @NonNull ActionListener<List<Float>> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> {
            if (response.size() != 1) {
                listener.onFailure(
                    new IllegalStateException(
                        "Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]"
                    )
                );
                return;
            }
            listener.onResponse(response.get(0));
        }, listener::onFailure));
    }

    /**
     * Runs a text-embedding prediction for a list of sentences using the default
     * {@code sentence_embedding} response filter.
     *
     * @param modelId   id of the model to call
     * @param inputText sentences to embed
     * @param listener  receives one vector per model output tensor
     */
    public void inferenceSentences(@NonNull String modelId, @NonNull List<String> inputText, @NonNull ActionListener<List<List<Float>>> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener);
    }

    /**
     * Runs a text-embedding prediction for a list of sentences with caller-supplied response filters.
     * Failures are handed to {@code RetryUtil.handleRetryOrFailure}, which decides whether to retry
     * (retry policy is defined there, not in this class).
     *
     * @param targetResponseFilters names of the model output tensors to keep
     * @param modelId               id of the model to call
     * @param inputText             sentences to embed
     * @param listener              receives one vector per model output tensor
     */
    public void inferenceSentences(
        @NonNull List<String> targetResponseFilters,
        @NonNull String modelId,
        @NonNull List<String> inputText,
        @NonNull ActionListener<List<List<Float>>> listener
    ) {
        Objects.requireNonNull(targetResponseFilters, "targetResponseFilters is marked non-null but is null");
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
    }

    /**
     * Runs a text-embedding prediction without a response filter and returns each output tensor's
     * data map as-is.
     *
     * @param modelId   id of the model to call
     * @param inputText sentences to send to the model
     * @param listener  receives one map per model tensor; fails with {@link IllegalStateException}
     *                  when the model returns no tensors
     */
    public void inferenceSentencesWithMapResult(
        @NonNull String modelId,
        @NonNull List<String> inputText,
        @NonNull ActionListener<List<Map<String, ?>>> listener
    ) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
    }

    /**
     * Runs a multimodal embedding prediction. The input map is read under the keys
     * {@code inputText} and (optionally) {@code inputImage}.
     *
     * @param modelId      id of the model to call
     * @param inputObjects map holding the text and optional image input
     * @param listener     receives the first produced vector, or an empty list if none was produced
     */
    public void inferenceSentences(@NonNull String modelId, @NonNull Map<String, String> inputObjects, @NonNull ActionListener<List<Float>> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputObjects, "inputObjects is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
    }

    /**
     * Runs a text-similarity prediction of {@code queryText} against each entry of {@code inputText}.
     *
     * @param modelId   id of the model to call
     * @param queryText the query text
     * @param inputText texts to score against the query
     * @param listener  receives the first value of every output tensor, one score per tensor
     */
    public void inferenceSimilarity(
        @NonNull String modelId,
        @NonNull String queryText,
        @NonNull List<String> inputText,
        @NonNull ActionListener<List<Float>> listener
    ) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(queryText, "queryText is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener);
    }

    /** Predicts with no response filter and returns each tensor's data map; retries via RetryUtil on failure. */
    private void retryableInferenceSentencesWithMapResult(
        String modelId,
        List<String> inputText,
        int retryTime,
        ActionListener<List<Map<String, ?>>> listener
    ) {
        MLInput mlInput = this.createMLTextInput(null, inputText);
        this.mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
            List<Map<String, ?>> result = this.buildMapResultFromResponse(mlOutput);
            listener.onResponse(result);
        },
            e -> RetryUtil.handleRetryOrFailure(
                e,
                retryTime,
                () -> this.retryableInferenceSentencesWithMapResult(modelId, inputText, retryTime + 1, listener),
                listener
            )
        ));
    }

    /** Predicts text embeddings with the given filters; retries via RetryUtil on failure. */
    private void retryableInferenceSentencesWithVectorResult(
        List<String> targetResponseFilters,
        String modelId,
        List<String> inputText,
        int retryTime,
        ActionListener<List<List<Float>>> listener
    ) {
        MLInput mlInput = this.createMLTextInput(targetResponseFilters, inputText);
        this.mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
            List<List<Float>> vector = this.buildVectorFromResponse(mlOutput);
            listener.onResponse(vector);
        },
            e -> RetryUtil.handleRetryOrFailure(
                e,
                retryTime,
                () -> this.retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTime + 1, listener),
                listener
            )
        ));
    }

    /** Predicts similarity scores (first value of each output tensor); retries via RetryUtil on failure. */
    private void retryableInferenceSimilarityWithVectorResult(
        String modelId,
        String queryText,
        List<String> inputText,
        int retryTime,
        ActionListener<List<Float>> listener
    ) {
        MLInput mlInput = this.createMLTextPairsInput(queryText, inputText);
        this.mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
            // NOTE(review): assumes every tensor carries at least one value; an empty one throws IndexOutOfBoundsException.
            List<Float> scores = this.buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
            listener.onResponse(scores);
        },
            e -> RetryUtil.handleRetryOrFailure(
                e,
                retryTime,
                () -> this.retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener),
                listener
            )
        ));
    }

    /**
     * Builds a TEXT_EMBEDDING input that asks for numeric (not byte) output restricted to the given
     * tensor names ({@code null} means no name filter is passed).
     */
    private MLInput createMLTextInput(List<String> targetResponseFilters, List<String> inputText) {
        ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
        MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
        return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
    }

    /** Builds a TEXT_SIMILARITY input pairing the query with every input text. */
    private MLInput createMLTextPairsInput(String query, List<String> inputText) {
        MLInputDataset inputDataset = new TextSimilarityInputDataSet(query, inputText);
        return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
    }

    /**
     * Flattens every tensor of every model output into a float vector, in output order.
     *
     * @throws IllegalStateException if any tensor has no numeric data; the model's own error
     *                               message is included when the tensor's data map provides one
     */
    private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
        List<List<Float>> vector = new ArrayList<>();
        ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
        for (ModelTensors tensors : modelTensorOutput.getMlModelOutputs()) {
            for (ModelTensor tensor : tensors.getMlModelTensors()) {
                if (Objects.isNull(tensor.getData())) {
                    throw modelPredictFailure(tensor);
                }
                vector.add(Arrays.stream(tensor.getData()).map(MLCommonsClientAccessor::toFloat).collect(Collectors.toList()));
            }
        }
        return vector;
    }

    /**
     * Converts one tensor value to {@link Float}. Uses {@link Number#floatValue()} rather than a cast so
     * that non-Float numeric types do not raise {@link ClassCastException}; {@code null} stays {@code null}.
     */
    private static Float toFloat(Number value) {
        return value == null ? null : value.floatValue();
    }

    /**
     * Builds the exception for a tensor that carries no data, using the model-provided error message
     * when one is present as non-blank text in the tensor's data map; otherwise logs the tensor.
     */
    private static IllegalStateException modelPredictFailure(ModelTensor tensor) {
        Map<String, ?> dataAsMap = tensor.getDataAsMap();
        Object errorFromModel = Objects.isNull(dataAsMap) ? null : dataAsMap.get(MODEL_ERROR_MESSAGE_KEY);
        // instanceof guard: a non-String value must not turn into a ClassCastException here.
        if (errorFromModel instanceof String && Strings.hasText((String) errorFromModel)) {
            return new IllegalStateException(
                String.format(Locale.ROOT, "%s: %s", EXCEPTION_MESSAGE_PREFIX_MODEL_PREDICT_FAILED, errorFromModel)
            );
        }
        log.error("Received following output tensor from a model, there is no detailed error message: {}", tensor);
        return new IllegalStateException(EXCEPTION_MESSAGE_MODEL_PREDICT_FAILED);
    }

    /**
     * Collects the data map of every tensor of every model output, in output order.
     *
     * @throws IllegalStateException if there are no outputs or the first output has no tensors
     */
    private List<Map<String, ?>> buildMapResultFromResponse(MLOutput mlOutput) {
        ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
        List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
        if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) {
            throw new IllegalStateException(
                "Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]"
            );
        }
        List<Map<String, ?>> resultMaps = new ArrayList<>();
        for (ModelTensors tensors : tensorOutputList) {
            for (ModelTensor tensor : tensors.getMlModelTensors()) {
                resultMaps.add(tensor.getDataAsMap());
            }
        }
        return resultMaps;
    }

    /** Returns the first vector of the response, or a new empty (mutable) list when there is none. */
    private List<Float> buildSingleVectorFromResponse(MLOutput mlOutput) {
        List<List<Float>> vector = this.buildVectorFromResponse(mlOutput);
        return vector.isEmpty() ? new ArrayList<>() : vector.get(0);
    }

    /** Predicts a multimodal embedding and returns its first vector; retries via RetryUtil on failure. */
    private void retryableInferenceSentencesWithSingleVectorResult(
        List<String> targetResponseFilters,
        String modelId,
        Map<String, String> inputObjects,
        int retryTime,
        ActionListener<List<Float>> listener
    ) {
        MLInput mlInput = this.createMLMultimodalInput(targetResponseFilters, inputObjects);
        this.mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
            List<Float> vector = this.buildSingleVectorFromResponse(mlOutput);
            log.debug("Inference Response for input sentence is : {} ", vector);
            listener.onResponse(vector);
        },
            e -> RetryUtil.handleRetryOrFailure(
                e,
                retryTime,
                () -> this.retryableInferenceSentencesWithSingleVectorResult(targetResponseFilters, modelId, inputObjects, retryTime + 1, listener),
                listener
            )
        ));
    }

    /**
     * Builds a TEXT_EMBEDDING input from a multimodal map: the text entry first (added even when
     * absent, as {@code null}), followed by the image entry when the key is present.
     */
    private MLInput createMLMultimodalInput(List<String> targetResponseFilters, Map<String, String> input) {
        List<String> inputText = new ArrayList<>();
        // NOTE(review): the text slot is always filled (possibly with null), presumably so that the model
        // can identify the image by position — confirm against the model connector before changing.
        inputText.add(input.get(INPUT_TEXT_KEY));
        if (input.containsKey(INPUT_IMAGE_KEY)) {
            inputText.add(input.get(INPUT_IMAGE_KEY));
        }
        return this.createMLTextInput(targetResponseFilters, inputText);
    }

    @Generated
    public MLCommonsClientAccessor(MachineLearningNodeClient mlClient) {
        this.mlClient = mlClient;
    }
}

