// ChatGPTRequest.java
package com.owlmaddie.chat;
2 3

import com.google.gson.Gson;
4
import com.owlmaddie.commands.ConfigurationHandler;
5
import com.owlmaddie.json.ChatGPTResponse;
6
import com.owlmaddie.network.ServerPackets;
7 8 9 10
import net.minecraft.resource.ResourceManager;
import net.minecraft.util.Identifier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
11

12
import java.io.*;
13 14
import java.net.HttpURLConnection;
import java.net.URL;
15
import java.nio.charset.StandardCharsets;
16
import java.util.ArrayList;
17
import java.util.Collections;
18 19
import java.util.List;
import java.util.Map;
20
import java.util.concurrent.CompletableFuture;
21
import java.util.regex.Pattern;
22

23 24 25 26
/**
 * The {@code ChatGPTRequest} class is used to send HTTP requests to our LLM to generate
 * messages.
 */
27
public class ChatGPTRequest {
28
    public static final Logger LOGGER = LoggerFactory.getLogger("creaturechat");
29
    public static String lastErrorMessage;
30

31 32 33 34 35 36 37 38 39 40 41 42 43
    /** One entry of the {@code messages} list sent with a request: who spoke and what was said. */
    static class ChatGPTRequestMessage {
        /** Speaker of the message: "system" or the lower-cased sender name of a history entry. */
        String role;
        /** Text of the message. */
        String content;

        /**
         * @param role    speaker of the message
         * @param content text of the message
         */
        public ChatGPTRequestMessage(String role, String content) {
            this.content = content;
            this.role = role;
        }
    }

    static class ChatGPTRequestPayload {
        String model;
        List<ChatGPTRequestMessage> messages;
44
        ResponseFormat response_format;
45
        float temperature;
46
        int max_tokens;
47

48
        public ChatGPTRequestPayload(String model, List<ChatGPTRequestMessage> messages, Boolean jsonMode, float temperature, int maxTokens) {
49 50
            this.model = model;
            this.messages = messages;
51
            this.temperature = temperature;
52
            this.max_tokens = maxTokens;
53 54 55 56 57
            if (jsonMode) {
                this.response_format = new ResponseFormat("json_object");
            } else {
                this.response_format = new ResponseFormat("text");
            }
58 59 60 61 62 63 64 65
        }
    }

    static class ResponseFormat {
        String type;

        public ResponseFormat(String type) {
            this.type = type;
66 67 68 69 70
        }
    }

    // This method should be called in an appropriate context where ResourceManager is available
    public static String loadPromptFromResource(ResourceManager resourceManager, String filePath) {
71
        Identifier fileIdentifier = new Identifier("creaturechat", filePath);
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
        try (InputStream inputStream = resourceManager.getResource(fileIdentifier).get().getInputStream();
             BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {

            StringBuilder contentBuilder = new StringBuilder();
            String line;
            while ((line = reader.readLine()) != null) {
                contentBuilder.append(line).append("\n");
            }
            return contentBuilder.toString();
        } catch (Exception e) {
            LOGGER.error("Failed to read prompt file", e);
        }
        return null;
    }

    /**
     * Substitutes every {@code {{key}}} placeholder in the template with the matching value and
     * then strips all double-quote characters from the result.
     * <p>
     * Values are inserted literally: characters such as {@code $} and {@code \} have no special
     * meaning. Entries are applied one after another in the map's iteration order, and
     * placeholders without a matching key are left in place.
     *
     * @param template     text containing {@code {{key}}} placeholders
     * @param replacements placeholder names mapped to the text that replaces them
     * @return the template with placeholders substituted and double quotes removed
     */
    public static String replacePlaceholders(String template, Map<String, String> replacements) {
        String result = template;
        for (Map.Entry<String, String> entry : replacements.entrySet()) {
            // String.replace works on literal text, so a value like "$5" or "C:\dir" is not
            // misread as a regex group reference or escape (which replaceAll would do).
            result = result.replace("{{" + entry.getKey() + "}}", entry.getValue());
        }
        return result.replace("\"", "");
    }

96 97 98 99 100
    // Rough estimate of how many OpenAI tokens a String occupies (about 3.5 characters per token)
    private static int estimateTokenSize(String text) {
        final double charsPerToken = 3.5;
        long roundedTokens = Math.round(text.length() / charsPerToken);
        return (int) roundedTokens;
    }

101
    public static CompletableFuture<String> fetchMessageFromChatGPT(String systemPrompt, Map<String, String> context, List<ChatDataManager.ChatMessage> messageHistory, Boolean jsonMode) {
102
        // Get config (api key, url, settings)
103
        ConfigurationHandler.Config config = new ConfigurationHandler(ServerPackets.serverInstance).loadConfig();
104 105 106 107 108 109 110 111 112

        // Init API & LLM details
        String apiUrl = config.getUrl();
        String apiKey = config.getApiKey();
        String modelName = config.getModel();
        int maxContextTokens = config.getMaxContextTokens();
        int maxOutputTokens = config.getMaxOutputTokens();
        double percentOfContext = config.getPercentOfContext();

113 114
        return CompletableFuture.supplyAsync(() -> {
            try {
115
                String systemMessage = "";
116
                if (systemPrompt != null && !systemPrompt.isEmpty()) {
117
                    systemMessage = loadPromptFromResource(ServerPackets.serverInstance.getResourceManager(), "prompts/" + systemPrompt);
118
                    systemMessage = replacePlaceholders(systemMessage, context);
119 120
                }

121
                URL url = new URL(apiUrl);
122 123 124
                HttpURLConnection connection = (HttpURLConnection) url.openConnection();
                connection.setRequestMethod("POST");
                connection.setRequestProperty("Content-Type", "application/json");
125
                connection.setRequestProperty("Authorization", "Bearer " + apiKey);
126
                connection.setDoOutput(true);
127 128
                connection.setConnectTimeout(10000); // 10 seconds connection timeout
                connection.setReadTimeout(10000); // 10 seconds read timeout
129

130
                // Create messages list (for chat history)
131
                List<ChatGPTRequestMessage> messages = new ArrayList<>();
132 133 134 135 136

                // Don't exceed a specific % of total context window (to limit message history in request)
                int remainingContextTokens = (int) ((maxContextTokens - maxOutputTokens) * percentOfContext);
                int usedTokens = estimateTokenSize("system: " + systemMessage);

137 138 139
                // Iterate backwards through the message history
                for (int i = messageHistory.size() - 1; i >= 0; i--) {
                    ChatDataManager.ChatMessage chatMessage = messageHistory.get(i);
140 141
                    String senderName = chatMessage.sender.toString().toLowerCase();
                    String messageText = replacePlaceholders(chatMessage.message, context);
142
                    int messageTokens = estimateTokenSize(senderName + ": " + messageText);
143

144
                    if (usedTokens + messageTokens > remainingContextTokens) {
145
                        break;  // If adding this message would exceed the token limit, stop adding more messages
146
                    }
147 148

                    // Add the message to the temporary list
149
                    messages.add(new ChatGPTRequestMessage(senderName, messageText));
150
                    usedTokens += messageTokens;
151
                }
152

153 154 155 156 157 158 159
                // Add system message
                messages.add(new ChatGPTRequestMessage("system", systemMessage));

                // Reverse the list to restore chronological order
                // This is needed since we build the list in reverse order for token restricting above
                Collections.reverse(messages);

160
                // Convert JSON to String
161
                ChatGPTRequestPayload payload = new ChatGPTRequestPayload(modelName, messages, jsonMode, 1.0f, maxOutputTokens);
162 163
                Gson gsonInput = new Gson();
                String jsonInputString = gsonInput.toJson(payload);
164 165

                try (OutputStream os = connection.getOutputStream()) {
166
                    byte[] input = jsonInputString.getBytes(StandardCharsets.UTF_8);
167 168 169
                    os.write(input, 0, input.length);
                }

170 171 172 173 174 175 176 177
                // Check for error message in response
                if (connection.getResponseCode() >= HttpURLConnection.HTTP_BAD_REQUEST) {
                    try (BufferedReader errorReader = new BufferedReader(new InputStreamReader(connection.getErrorStream(), StandardCharsets.UTF_8))) {
                        String errorLine;
                        StringBuilder errorResponse = new StringBuilder();
                        while ((errorLine = errorReader.readLine()) != null) {
                            errorResponse.append(errorLine.trim());
                        }
178
                        LOGGER.error("Error response from API: " + errorResponse);
179
                        lastErrorMessage = errorResponse.toString();
180 181
                    }
                    return null;
182 183
                } else {
                    lastErrorMessage = null;
184 185 186
                }

                try (BufferedReader br = new BufferedReader(new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8))) {
187 188 189 190 191 192
                    StringBuilder response = new StringBuilder();
                    String responseLine;
                    while ((responseLine = br.readLine()) != null) {
                        response.append(responseLine.trim());
                    }

193 194
                    Gson gsonOutput = new Gson();
                    ChatGPTResponse chatGPTResponse = gsonOutput.fromJson(response.toString(), ChatGPTResponse.class);
195
                    if (chatGPTResponse != null && chatGPTResponse.choices != null && !chatGPTResponse.choices.isEmpty()) {
196
                        String content = chatGPTResponse.choices.get(0).message.content;
197
                        LOGGER.info("Generated message: " + content);
198 199 200
                        return content;
                    }
                }
201 202
            } catch (IOException e) {
                LOGGER.error("Failed to fetch message from ChatGPT", e);
203 204 205 206 207 208
            }
            return null; // If there was an error or no response, return null
        });
    }
}