/*
 * Decompiled with CFR 0.152.
 */
package com.teamscale.core.ai.google;

import com.teamscale.commons.service.client.ServiceCallException;
import com.teamscale.core.ai.google.GoogleAiClientBase;
import com.teamscale.core.rest.client.IRetrofitApi;
import com.teamscale.core.rest.client.Retrofit;
import java.io.IOException;
import java.util.List;
import java.util.function.BiConsumer;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.POST;
import retrofit2.http.Path;

/**
 * LLM client that sends a single prompt to a Google-published Vertex AI model via the
 * {@code publishers/google/models/{model}:predict} endpoint and returns the first prediction.
 */
public class VertexAiClient
extends GoogleAiClientBase<IVertexAiService> {
    /** Model name substituted into the prediction path; also used to pick the output-token limit. */
    private final String model;

    /**
     * Creates a new client for the given model.
     *
     * @param model           model name used in the request path and for choosing the output-token limit
     * @param project         forwarded unchanged to {@link GoogleAiClientBase}
     * @param location        forwarded unchanged to {@link GoogleAiClientBase}
     * @param credentialsJson forwarded unchanged to {@link GoogleAiClientBase}
     * @param temperature     sampling temperature sent with every request
     * @param usageReporter   optional callback receiving (metric name, value) pairs after each completion
     * @throws IOException propagated from the base-class constructor
     */
    public VertexAiClient(String model, String project, String location, String credentialsJson, double temperature, @Nullable BiConsumer<String, Long> usageReporter) throws IOException {
        super(project, location, credentialsJson, temperature, usageReporter, IVertexAiService.class);
        this.model = model;
    }

    /**
     * Sends {@code prompt} as a single instance to the prediction endpoint and returns the content
     * of the first prediction. If a usage reporter is configured and the response carries token
     * metadata, the token and billable-character counts are reported first.
     *
     * @param prompt the prompt text
     * @return the {@code content} of the first prediction
     * @throws ServiceCallException if the call fails, returns no response, or returns no prediction
     */
    @Override
    public String complete(String prompt) throws ServiceCallException {
        VertexAiRequestParameters parameters = new VertexAiRequestParameters(this.temperature, this.determineMaxOutputTokens());
        VertexAiPredictRequest request = new VertexAiPredictRequest(List.of(new VertexAiRequestInstance(prompt)), parameters);
        VertexAiPredictResponse response = (VertexAiPredictResponse) Retrofit.executeServiceCall(((IVertexAiService) this.ensureService()).performPrediction(this.model, request))
                .orElseThrow(() -> new ServiceCallException("No response returned from LLM!"));

        // Metadata is optional in the parsed response; only report usage when it is actually present
        // instead of failing with a NullPointerException.
        VertexAiPredictResponseMetadata metadata = response.metadata();
        if (this.usageReporter != null && metadata != null && metadata.tokenMetadata() != null) {
            VertexAiClient.reportUsage(this.usageReporter, metadata.tokenMetadata());
        }

        // Surface a missing/empty prediction list as the declared exception type rather than an
        // unchecked NullPointerException / IndexOutOfBoundsException.
        List<VertexAiPrediction> predictions = response.predictions();
        if (predictions == null || predictions.isEmpty() || predictions.get(0) == null) {
            throw new ServiceCallException("No prediction returned from LLM!");
        }
        return predictions.get(0).content();
    }

    /**
     * Forwards the token counts and billable-character counts found in {@code tokenMetadata} to
     * {@code usageReporter}. Input and output counts are each reported only when present.
     *
     * @param usageReporter receiver of (metric name, value) pairs
     * @param tokenMetadata token metadata from the prediction response; must not be {@code null}
     */
    private static void reportUsage(@NonNull BiConsumer<String, Long> usageReporter, VertexAiPredictResponseTokenMetadata tokenMetadata) {
        VertexAiTokenCount input = tokenMetadata.inputTokenCount();
        if (input != null) {
            usageReporter.accept("input-tokens", Long.valueOf(input.totalTokens()));
            usageReporter.accept("input-billable-characters", Long.valueOf(input.totalBillableCharacters()));
        }
        VertexAiTokenCount output = tokenMetadata.outputTokenCount();
        if (output != null) {
            usageReporter.accept("output-tokens", Long.valueOf(output.totalTokens()));
            usageReporter.accept("output-billable-characters", Long.valueOf(output.totalBillableCharacters()));
        }
    }

    /**
     * Chooses the {@code maxOutputTokens} request parameter from substrings of the model name:
     * 1024 for names containing "unicorn", 8192 for names containing "32k", otherwise 2048.
     * NOTE(review): these presumably mirror per-model output limits documented by Vertex AI — confirm
     * when adding new models.
     */
    private int determineMaxOutputTokens() {
        if (this.model.contains("unicorn")) {
            return 1024;
        }
        if (this.model.contains("32k")) {
            return 8192;
        }
        return 2048;
    }

    /** Retrofit definition of the Vertex AI prediction endpoint. */
    interface IVertexAiService
    extends IRetrofitApi {
        /**
         * Requests a prediction from the given model.
         *
         * @param model   model name substituted into the path
         * @param request request body
         * @return the pending call
         */
        @POST(value="publishers/google/models/{model}:predict")
        Call<VertexAiPredictResponse> performPrediction(@Path(value="model") String model, @Body VertexAiPredictRequest request);
    }

    // NOTE(review): the record component names below presumably mirror the JSON field names of the
    // Vertex AI predict API and are mapped by the serializer — do not rename without checking that.

    /** Generation parameters sent with every request. */
    private record VertexAiRequestParameters(double temperature, int maxOutputTokens) {
    }

    /** Request body: the prompt instances plus generation parameters. */
    record VertexAiPredictRequest(List<VertexAiRequestInstance> instances, VertexAiRequestParameters parameters) {
    }

    /** A single prompt instance in the request. */
    private record VertexAiRequestInstance(String prompt) {
    }

    /** Response body: generated predictions plus optional metadata. */
    record VertexAiPredictResponse(List<VertexAiPrediction> predictions, VertexAiPredictResponseMetadata metadata) {
    }

    /** Response metadata wrapper holding token metadata. */
    private record VertexAiPredictResponseMetadata(VertexAiPredictResponseTokenMetadata tokenMetadata) {
    }

    /** Token usage for the input and output side; either may be absent. */
    private record VertexAiPredictResponseTokenMetadata(VertexAiTokenCount inputTokenCount, VertexAiTokenCount outputTokenCount) {
    }

    /** A single generated prediction. */
    private record VertexAiPrediction(String content) {
    }

    /** Token and billable-character counts for one side of the exchange. */
    private record VertexAiTokenCount(int totalTokens, int totalBillableCharacters) {
    }
}

