ChatGPTRequest.java 9.53 KB
Newer Older
1
package com.owlmaddie.chat;
2 3

import com.google.gson.Gson;
4
import com.google.gson.JsonSyntaxException;
5
import com.owlmaddie.commands.ConfigurationHandler;
6
import com.owlmaddie.json.ChatGPTResponse;
7 8
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
9

10
import java.io.*;
11 12
import java.net.HttpURLConnection;
import java.net.URL;
13
import java.nio.charset.StandardCharsets;
14
import java.util.*;
15
import java.util.concurrent.CompletableFuture;
16
import java.util.regex.Pattern;
17

18 19 20 21
/**
 * The {@code ChatGPTRequest} class is used to send HTTP requests to our LLM to generate
 * messages.
 */
22
public class ChatGPTRequest {
23
    public static final Logger LOGGER = LoggerFactory.getLogger("creaturechat");
24
    public static String lastErrorMessage;
25

26 27 28 29 30 31 32 33 34 35 36 37 38
    static class ChatGPTRequestMessage {
        String role;
        String content;

        public ChatGPTRequestMessage(String role, String content) {
            this.role = role;
            this.content = content;
        }
    }

    static class ChatGPTRequestPayload {
        String model;
        List<ChatGPTRequestMessage> messages;
39
        ResponseFormat response_format;
40
        float temperature;
41
        int max_tokens;
42
        boolean stream;
43

44
        public ChatGPTRequestPayload(String model, List<ChatGPTRequestMessage> messages, Boolean jsonMode, float temperature, int maxTokens) {
45 46
            this.model = model;
            this.messages = messages;
47
            this.temperature = temperature;
48
            this.max_tokens = maxTokens;
49
            this.stream = false;
50 51 52 53 54
            if (jsonMode) {
                this.response_format = new ResponseFormat("json_object");
            } else {
                this.response_format = new ResponseFormat("text");
            }
55 56 57 58 59 60 61 62
        }
    }

    static class ResponseFormat {
        String type;

        public ResponseFormat(String type) {
            this.type = type;
63 64 65
        }
    }

66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
    public static String removeQuotes(String str) {
        if (str != null && str.length() > 1 && str.startsWith("\"") && str.endsWith("\"")) {
            return str.substring(1, str.length() - 1);
        }
        return str;
    }

    // Class to represent the error response structure
    public static class ErrorResponse {
        Error error;

        static class Error {
            String message;
            String type;
            String code;
        }
    }

    public static String parseAndLogErrorResponse(String errorResponse) {
        try {
            Gson gson = new Gson();
            ErrorResponse response = gson.fromJson(errorResponse, ErrorResponse.class);

            if (response.error != null) {
                LOGGER.error("Error Message: " + response.error.message);
                LOGGER.error("Error Type: " + response.error.type);
                LOGGER.error("Error Code: " + response.error.code);
                return response.error.message;
            } else {
                LOGGER.error("Unknown error response: " + errorResponse);
96
                return "Unknown: " + errorResponse;
97 98 99 100 101 102 103 104 105 106
            }
        } catch (JsonSyntaxException e) {
            LOGGER.warn("Failed to parse error response as JSON, falling back to plain text");
            LOGGER.error("Error response: " + errorResponse);
        } catch (Exception e) {
            LOGGER.error("Failed to parse error response", e);
        }
        return removeQuotes(errorResponse);
    }

107 108 109 110 111 112 113 114 115
    // Function to replace placeholders in the template
    public static String replacePlaceholders(String template, Map<String, String> replacements) {
        String result = template;
        for (Map.Entry<String, String> entry : replacements.entrySet()) {
            result = result.replaceAll(Pattern.quote("{{" + entry.getKey() + "}}"), entry.getValue());
        }
        return result.replace("\"", "") ;
    }

116 117 118 119 120
    // Function to roughly estimate # of OpenAI tokens in String
    private static int estimateTokenSize(String text) {
        return (int) Math.round(text.length() / 3.5);
    }

121
    public static CompletableFuture<String> fetchMessageFromChatGPT(ConfigurationHandler.Config config, String systemPrompt, Map<String, String> contextData, List<ChatMessage> messageHistory, Boolean jsonMode) {
122 123 124 125
        // Init API & LLM details
        String apiUrl = config.getUrl();
        String apiKey = config.getApiKey();
        String modelName = config.getModel();
126
        Integer timeout = config.getTimeout() * 1000;
127 128 129 130
        int maxContextTokens = config.getMaxContextTokens();
        int maxOutputTokens = config.getMaxOutputTokens();
        double percentOfContext = config.getPercentOfContext();

131 132
        return CompletableFuture.supplyAsync(() -> {
            try {
133 134
                // Replace placeholders
                String systemMessage = replacePlaceholders(systemPrompt, contextData);
135

136
                URL url = new URL(apiUrl);
137 138 139
                HttpURLConnection connection = (HttpURLConnection) url.openConnection();
                connection.setRequestMethod("POST");
                connection.setRequestProperty("Content-Type", "application/json");
140
                connection.setRequestProperty("Authorization", "Bearer " + apiKey);
141
                connection.setDoOutput(true);
142 143
                connection.setConnectTimeout(timeout); // 10 seconds connection timeout
                connection.setReadTimeout(timeout); // 10 seconds read timeout
144

145
                // Create messages list (for chat history)
146
                List<ChatGPTRequestMessage> messages = new ArrayList<>();
147 148 149 150 151

                // Don't exceed a specific % of total context window (to limit message history in request)
                int remainingContextTokens = (int) ((maxContextTokens - maxOutputTokens) * percentOfContext);
                int usedTokens = estimateTokenSize("system: " + systemMessage);

152 153
                // Iterate backwards through the message history
                for (int i = messageHistory.size() - 1; i >= 0; i--) {
154
                    ChatMessage chatMessage = messageHistory.get(i);
155
                    String senderName = chatMessage.sender.toString().toLowerCase(Locale.ENGLISH);
156
                    String messageText = replacePlaceholders(chatMessage.message, contextData);
157
                    int messageTokens = estimateTokenSize(senderName + ": " + messageText);
158

159
                    if (usedTokens + messageTokens > remainingContextTokens) {
160
                        break;  // If adding this message would exceed the token limit, stop adding more messages
161
                    }
162 163

                    // Add the message to the temporary list
164
                    messages.add(new ChatGPTRequestMessage(senderName, messageText));
165
                    usedTokens += messageTokens;
166
                }
167

168 169 170 171 172 173 174
                // Add system message
                messages.add(new ChatGPTRequestMessage("system", systemMessage));

                // Reverse the list to restore chronological order
                // This is needed since we build the list in reverse order for token restricting above
                Collections.reverse(messages);

175
                // Convert JSON to String
176
                ChatGPTRequestPayload payload = new ChatGPTRequestPayload(modelName, messages, jsonMode, 1.0f, maxOutputTokens);
177 178
                Gson gsonInput = new Gson();
                String jsonInputString = gsonInput.toJson(payload);
179 180

                try (OutputStream os = connection.getOutputStream()) {
181
                    byte[] input = jsonInputString.getBytes(StandardCharsets.UTF_8);
182 183 184
                    os.write(input, 0, input.length);
                }

185 186 187 188 189 190 191 192
                // Check for error message in response
                if (connection.getResponseCode() >= HttpURLConnection.HTTP_BAD_REQUEST) {
                    try (BufferedReader errorReader = new BufferedReader(new InputStreamReader(connection.getErrorStream(), StandardCharsets.UTF_8))) {
                        String errorLine;
                        StringBuilder errorResponse = new StringBuilder();
                        while ((errorLine = errorReader.readLine()) != null) {
                            errorResponse.append(errorLine.trim());
                        }
193 194 195 196 197 198

                        // Parse and log the error response using Gson
                        String cleanError = parseAndLogErrorResponse(errorResponse.toString());
                        lastErrorMessage = cleanError;
                    } catch (Exception e) {
                        LOGGER.error("Failed to read error response", e);
199 200
                    }
                    return null;
201 202
                } else {
                    lastErrorMessage = null;
203 204 205
                }

                try (BufferedReader br = new BufferedReader(new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) {
206 207 208 209 210 211
                    StringBuilder response = new StringBuilder();
                    String responseLine;
                    while ((responseLine = br.readLine()) != null) {
                        response.append(responseLine.trim());
                    }

212 213
                    Gson gsonOutput = new Gson();
                    ChatGPTResponse chatGPTResponse = gsonOutput.fromJson(response.toString(), ChatGPTResponse.class);
214
                    if (chatGPTResponse != null && chatGPTResponse.choices != null && !chatGPTResponse.choices.isEmpty()) {
215
                        String content = chatGPTResponse.choices.get(0).message.content;
216 217 218
                        return content;
                    }
                }
219 220
            } catch (IOException e) {
                LOGGER.error("Failed to fetch message from ChatGPT", e);
221 222 223 224 225 226
            }
            return null; // If there was an error or no response, return null
        });
    }
}