Commit 9b5f8870 by Jonathan Thomas

Small refactor to ChatGPT requests, separating all the key variables to make it easier to modify. Also, adding in token estimation and limits: 200 output tokens, 75% of the 16k context window, so super long conversations exceeding 12k tokens will be trimmed to the most recent messages.
parent 2a36b9ad
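
With the constants this commit introduces, the history budget works out to (16385 - 200) * 0.75 ≈ 12,138 tokens, which is the "12k" figure mentioned in the message above. A minimal sketch of that arithmetic, reusing the diff's values (the class name and print statement here are illustrative only):

```java
// Sketch: where the ~12k-token history budget comes from.
// The three constants mirror the values added in the diff below.
public class TokenBudgetSketch {
    private static final int maxContextTokens = 16385;   // gpt-3.5-turbo 16k context window
    private static final int maxOutputTokens = 200;      // tokens reserved for the reply
    private static final double percentOfContext = 0.75; // share of the window given to history

    public static void main(String[] args) {
        int remainingContextTokens = (int) ((maxContextTokens - maxOutputTokens) * percentOfContext);
        System.out.println(remainingContextTokens); // 12138
    }
}
```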
@@ -25,6 +25,14 @@ import java.util.regex.Pattern;
 public class ChatGPTRequest {
     public static final Logger LOGGER = LoggerFactory.getLogger("mobgpt");
 
+    // Init API & LLM details
+    private static final String apiUrl = "https://api.openai.com/v1/chat/completions";
+    private static final String apiKey = "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";
+    private static final String modelName = "gpt-3.5-turbo";
+    private static final int maxContextTokens = 16385;
+    private static final int maxOutputTokens = 200;
+    private static final double percentOfContext = 0.75;
+
     static class ChatGPTRequestMessage {
         String role;
         String content;
@@ -40,11 +48,13 @@ public class ChatGPTRequest {
         List<ChatGPTRequestMessage> messages;
         ResponseFormat response_format;
         float temperature;
+        int max_tokens;
 
-        public ChatGPTRequestPayload(String model, List<ChatGPTRequestMessage> messages, Boolean jsonMode, float temperature) {
+        public ChatGPTRequestPayload(String model, List<ChatGPTRequestMessage> messages, Boolean jsonMode, float temperature, int maxTokens) {
            this.model = model;
            this.messages = messages;
            this.temperature = temperature;
+           this.max_tokens = maxTokens;
            if (jsonMode) {
                this.response_format = new ResponseFormat("json_object");
            } else {
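
A note on the field naming in this hunk: Gson uses the Java field names as JSON keys by default, so declaring the field as max_tokens (rather than maxTokens) makes the serialized payload carry exactly the key the Chat Completions API expects. A small self-contained sketch of that behavior, using stand-in classes rather than the mod's real ones:

```java
import com.google.gson.Gson;
import java.util.List;

// Sketch: Gson keeps snake_case Java field names as JSON keys,
// so max_tokens serializes in the form the OpenAI API expects.
// These stand-in classes only mirror the shape of the payload in the diff.
public class PayloadSketch {
    static class Message {
        String role;
        String content;
        Message(String role, String content) { this.role = role; this.content = content; }
    }

    static class Payload {
        String model = "gpt-3.5-turbo";
        List<Message> messages = List.of(new Message("system", "You are a villager."));
        float temperature = 1.0f;
        int max_tokens = 200;
    }

    public static void main(String[] args) {
        // Prints something like:
        // {"model":"gpt-3.5-turbo","messages":[{"role":"system","content":"You are a villager."}],"temperature":1.0,"max_tokens":200}
        System.out.println(new Gson().toJson(new Payload()));
    }
}
```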
@@ -88,6 +98,11 @@ public class ChatGPTRequest {
         return result.replace("\"", "");
     }
 
+    // Function to roughly estimate # of OpenAI tokens in String
+    private static int estimateTokenSize(String text) {
+        return (int) Math.round(text.length() / 3.5);
+    }
+
     public static CompletableFuture<String> fetchMessageFromChatGPT(String systemPrompt, Map<String, String> context, List<ChatDataManager.ChatMessage> messageHistory, Boolean jsonMode) {
         return CompletableFuture.supplyAsync(() -> {
             try {
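
The estimateTokenSize() helper added above avoids pulling in a real tokenizer by assuming roughly 3.5 characters per token, a common rule of thumb for English text. A standalone sketch of the same heuristic (the class name and sample strings are hypothetical):

```java
// Sketch of the character-count heuristic behind estimateTokenSize().
// It trades accuracy for simplicity; a real tokenizer would give different counts.
public class TokenEstimateSketch {
    private static int estimateTokenSize(String text) {
        return (int) Math.round(text.length() / 3.5);
    }

    public static void main(String[] args) {
        System.out.println(estimateTokenSize("Hello there, traveler!"));                         // 22 chars -> 6
        System.out.println(estimateTokenSize("system: You are a friendly Minecraft villager.")); // 46 chars -> 13
    }
}
```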
@@ -97,25 +112,37 @@ public class ChatGPTRequest {
                     systemMessage = replacePlaceholders(systemMessage, context);
                 }
 
-                URL url = new URL("https://api.openai.com/v1/chat/completions");
+                URL url = new URL(apiUrl);
                 HttpURLConnection connection = (HttpURLConnection) url.openConnection();
                 connection.setRequestMethod("POST");
                 connection.setRequestProperty("Content-Type", "application/json");
-                connection.setRequestProperty("Authorization", "Bearer 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");
+                connection.setRequestProperty("Authorization", "Bearer " + apiKey);
                 connection.setDoOutput(true);
                 connection.setConnectTimeout(10000); // 10 seconds connection timeout
                 connection.setReadTimeout(10000); // 10 seconds read timeout
 
+                // Add system message
                 List<ChatGPTRequestMessage> messages = new ArrayList<>();
                 messages.add(new ChatGPTRequestMessage("system", systemMessage));
 
+                // Don't exceed a specific % of total context window (to limit message history in request)
+                int remainingContextTokens = (int) ((maxContextTokens - maxOutputTokens) * percentOfContext);
+                int usedTokens = estimateTokenSize("system: " + systemMessage);
+
+                // Loop through entity chat history (add messages to request, if we have enough context window)
                 for (ChatDataManager.ChatMessage chatMessage : messageHistory) {
                     String senderName = chatMessage.sender.toString().toLowerCase();
                     String messageText = replacePlaceholders(chatMessage.message, context);
+                    int messageTokens = estimateTokenSize(senderName + ": " + messageText);
+                    if (usedTokens + messageTokens > remainingContextTokens) {
+                        break;
+                    }
                     messages.add(new ChatGPTRequestMessage(senderName, messageText));
+                    usedTokens += messageTokens;
                 }
 
                 // Convert JSON to String
-                ChatGPTRequestPayload payload = new ChatGPTRequestPayload("gpt-3.5-turbo", messages, jsonMode, 1.0f);
+                ChatGPTRequestPayload payload = new ChatGPTRequestPayload(modelName, messages, jsonMode, 1.0f, maxOutputTokens);
                 Gson gsonInput = new Gson();
                 String jsonInputString = gsonInput.toJson(payload);
...
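
Read in isolation, the new history handling amounts to: seed the running estimate with the system message, then append history entries until the next one would push the estimate past the ~12k budget, dropping whatever remains. A simplified, self-contained version of that loop; the ChatMessage record, buildMessages() helper, and sample data below are stand-ins, not the mod's actual types:

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of the budget-limited history loop added in the diff, using stand-in types.
public class HistoryTrimSketch {
    record ChatMessage(String sender, String message) {}

    private static int estimateTokenSize(String text) {
        return (int) Math.round(text.length() / 3.5);
    }

    static List<String> buildMessages(String systemMessage, List<ChatMessage> history, int budget) {
        List<String> messages = new ArrayList<>();
        messages.add("system: " + systemMessage);
        int usedTokens = estimateTokenSize("system: " + systemMessage);

        for (ChatMessage chatMessage : history) {
            String line = chatMessage.sender().toLowerCase() + ": " + chatMessage.message();
            int messageTokens = estimateTokenSize(line);
            if (usedTokens + messageTokens > budget) {
                break; // stop once the next message would exceed the budget
            }
            messages.add(line);
            usedTokens += messageTokens;
        }
        return messages;
    }

    public static void main(String[] args) {
        List<ChatMessage> history = List.of(
                new ChatMessage("USER", "Hello!"),
                new ChatMessage("ASSISTANT", "Hello, traveler."));
        // A budget of 12138 tokens matches (16385 - 200) * 0.75 from the diff.
        System.out.println(buildMessages("You are a villager.", history, 12138));
    }
}
```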