Commit 91b1313f by Jonathan Thomas

Adding new testing framework:

- new junit test module
- 4 initial tests for attack, follow, and flee
- improved regex to support <behavior> or *behavior*
- improved message cleaning
- refactor of HTTP requests to remove some Minecraft and Fabric specific imports
parent 881c4305
Pipeline #12443 failed with stage
in 5 seconds
......@@ -6,8 +6,14 @@ All notable changes to **CreatureChat** are documented in this file. The format
## [Unreleased]
### Added
- New **Prompt Testing** module, for faster validation of LLMs and prompt changes
- New `stream = false` parameter to HTTP API requests (since some APIs default to `true`)
### Changed
- Many improvements to chat prompt for more balanced dialog and behaviors
- **Huge improvements** to **chat prompt** for more *balanced* dialog and *predictable* behaviors
- Improved **Behavior regex** to include both `<BEHAVIOR arg>` and `*BEHAVIOR arg*` syntax
- Improved **message cleaning** to remove any remaining `**` and `<>` after parsing behaviors
- Privacy Policy updated
## [1.0.5] - 2024-05-27
......
......@@ -42,8 +42,20 @@ dependencies {
// Uncomment the following line to enable the deprecated Fabric API modules.
// These are included in the Fabric API production distribution and allow you to update your mod to the latest modules at a later more convenient time.
// modImplementation "net.fabricmc.fabric-api:fabric-api-deprecated:${project.fabric_version}"
// Test module dependencies
testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.8.1'
testImplementation 'org.apache.commons:commons-lang3:3.12.0'
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.1'
testImplementation "net.fabricmc.fabric-api:fabric-api:${project.fabric_version}"
testImplementation 'com.google.code.gson:gson:2.8.8'
testImplementation 'org.slf4j:slf4j-api:1.7.32'
testImplementation 'ch.qos.logback:logback-classic:1.2.6'
}
test {
useJUnitPlatform()
}
processResources {
......
......@@ -2,6 +2,7 @@ package com.owlmaddie.chat;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.owlmaddie.commands.ConfigurationHandler;
import com.owlmaddie.controls.SpeedControls;
import com.owlmaddie.goals.*;
import com.owlmaddie.items.RarityItemCollector;
......@@ -231,8 +232,12 @@ public class ChatDataManager {
// Add PLAYER context information
Map<String, String> contextData = getPlayerContext(player, userLanguage);
// Get config (api key, url, settings)
ConfigurationHandler.Config config = new ConfigurationHandler(ServerPackets.serverInstance).loadConfig();
String promptText = ChatPrompt.loadPromptFromResource(ServerPackets.serverInstance.getResourceManager(), systemPrompt);
// fetch HTTP response from ChatGPT
ChatGPTRequest.fetchMessageFromChatGPT(systemPrompt, contextData, previousMessages, false).thenAccept(output_message -> {
ChatGPTRequest.fetchMessageFromChatGPT(config, promptText, contextData, previousMessages, false).thenAccept(output_message -> {
if (output_message != null && systemPrompt == "system-character") {
// Character Sheet: Remove system-character message from previous messages
previousMessages.clear();
......@@ -469,8 +474,12 @@ public class ChatDataManager {
List<ChatMessage> messages = new ArrayList<>();
messages.add(new ChatMessage("Generate me a new fantasy story with ONLY the 1st character in the story", ChatSender.USER));
// Get config (api key, url, settings)
ConfigurationHandler.Config config = new ConfigurationHandler(ServerPackets.serverInstance).loadConfig();
String questPrompt = ChatPrompt.loadPromptFromResource(ServerPackets.serverInstance.getResourceManager(), "system-quest");
// Generate Quest: fetch HTTP response from ChatGPT
ChatGPTRequest.fetchMessageFromChatGPT("system-quest", contextData, messages, true).thenAccept(output_message -> {
ChatGPTRequest.fetchMessageFromChatGPT(config, questPrompt, contextData, messages, true).thenAccept(output_message -> {
// New Quest
Gson gson = new Gson();
quest = gson.fromJson(output_message, QuestJson.class);
......
......@@ -4,9 +4,6 @@ import com.google.gson.Gson;
import com.google.gson.JsonSyntaxException;
import com.owlmaddie.commands.ConfigurationHandler;
import com.owlmaddie.json.ChatGPTResponse;
import com.owlmaddie.network.ServerPackets;
import net.minecraft.resource.ResourceManager;
import net.minecraft.util.Identifier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -42,12 +39,14 @@ public class ChatGPTRequest {
ResponseFormat response_format;
float temperature;
int max_tokens;
boolean stream;
public ChatGPTRequestPayload(String model, List<ChatGPTRequestMessage> messages, Boolean jsonMode, float temperature, int maxTokens) {
this.model = model;
this.messages = messages;
this.temperature = temperature;
this.max_tokens = maxTokens;
this.stream = false;
if (jsonMode) {
this.response_format = new ResponseFormat("json_object");
} else {
......@@ -94,7 +93,7 @@ public class ChatGPTRequest {
return response.error.message;
} else {
LOGGER.error("Unknown error response: " + errorResponse);
return "Unknown";
return "Unknown: " + errorResponse;
}
} catch (JsonSyntaxException e) {
LOGGER.warn("Failed to parse error response as JSON, falling back to plain text");
......@@ -105,24 +104,6 @@ public class ChatGPTRequest {
return removeQuotes(errorResponse);
}
// This method should be called in an appropriate context where ResourceManager is available
public static String loadPromptFromResource(ResourceManager resourceManager, String filePath) {
Identifier fileIdentifier = new Identifier("creaturechat", filePath);
try (InputStream inputStream = resourceManager.getResource(fileIdentifier).get().getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
StringBuilder contentBuilder = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
contentBuilder.append(line).append("\n");
}
return contentBuilder.toString();
} catch (Exception e) {
LOGGER.error("Failed to read prompt file", e);
}
return null;
}
// Function to replace placeholders in the template
public static String replacePlaceholders(String template, Map<String, String> replacements) {
String result = template;
......@@ -137,10 +118,7 @@ public class ChatGPTRequest {
return (int) Math.round(text.length() / 3.5);
}
public static CompletableFuture<String> fetchMessageFromChatGPT(String systemPrompt, Map<String, String> context, List<ChatDataManager.ChatMessage> messageHistory, Boolean jsonMode) {
// Get config (api key, url, settings)
ConfigurationHandler.Config config = new ConfigurationHandler(ServerPackets.serverInstance).loadConfig();
public static CompletableFuture<String> fetchMessageFromChatGPT(ConfigurationHandler.Config config, String systemPrompt, Map<String, String> contextData, List<ChatDataManager.ChatMessage> messageHistory, Boolean jsonMode) {
// Init API & LLM details
String apiUrl = config.getUrl();
String apiKey = config.getApiKey();
......@@ -152,11 +130,8 @@ public class ChatGPTRequest {
return CompletableFuture.supplyAsync(() -> {
try {
String systemMessage = "";
if (systemPrompt != null && !systemPrompt.isEmpty()) {
systemMessage = loadPromptFromResource(ServerPackets.serverInstance.getResourceManager(), "prompts/" + systemPrompt);
systemMessage = replacePlaceholders(systemMessage, context);
}
// Replace placeholders
String systemMessage = replacePlaceholders(systemPrompt, contextData);
URL url = new URL(apiUrl);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
......@@ -178,7 +153,7 @@ public class ChatGPTRequest {
for (int i = messageHistory.size() - 1; i >= 0; i--) {
ChatDataManager.ChatMessage chatMessage = messageHistory.get(i);
String senderName = chatMessage.sender.toString().toLowerCase(Locale.ENGLISH);
String messageText = replacePlaceholders(chatMessage.message, context);
String messageText = replacePlaceholders(chatMessage.message, contextData);
int messageTokens = estimateTokenSize(senderName + ": " + messageText);
if (usedTokens + messageTokens > remainingContextTokens) {
......
package com.owlmaddie.chat;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import net.minecraft.resource.ResourceManager;
import net.minecraft.util.Identifier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* The {@code ChatPrompt} class is used to load a prompt from the Minecraft resource manager
*/
public class ChatPrompt {
public static final Logger LOGGER = LoggerFactory.getLogger("creaturechat");
// This method should be called in an appropriate context where ResourceManager is available
public static String loadPromptFromResource(ResourceManager resourceManager, String promptName) {
Identifier fileIdentifier = new Identifier("creaturechat", "prompts/" + promptName);
try (InputStream inputStream = resourceManager.getResource(fileIdentifier).get().getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
StringBuilder contentBuilder = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
contentBuilder.append(line).append("\n");
}
return contentBuilder.toString();
} catch (Exception e) {
LOGGER.error("Failed to read prompt file", e);
}
return null;
}
}
......@@ -19,7 +19,7 @@ public class MessageParser {
LOGGER.info("Parsing message: {}", input);
StringBuilder cleanedMessage = new StringBuilder();
List<Behavior> behaviors = new ArrayList<>();
Pattern pattern = Pattern.compile("<(\\w+)(?:\\s+(-?\\d+))?>");
Pattern pattern = Pattern.compile("[<*](\\w+)(?:\\s+(-?\\d+))?[>*]");
Matcher matcher = pattern.matcher(input);
while (matcher.find()) {
......@@ -34,8 +34,14 @@ public class MessageParser {
matcher.appendReplacement(cleanedMessage, "");
}
matcher.appendTail(cleanedMessage);
LOGGER.info("Cleaned message: {}", cleanedMessage.toString());
return new ParsedMessage(cleanedMessage.toString().trim(), input.trim(), behaviors);
// Get final cleaned string
String displayMessage = cleanedMessage.toString().trim();
// Remove all occurrences of "<>" and "**" (if any)
displayMessage = displayMessage.replaceAll("<>", "").replaceAll("\\*\\*", "").trim();
LOGGER.info("Cleaned message: {}", displayMessage);
return new ParsedMessage(displayMessage, input.trim(), behaviors);
}
}
package com.owlmaddie;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.owlmaddie.chat.ChatDataManager;
import java.io.IOException;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* The {@code EntityTestData} class is a test representation of our regular EntityChatData class. This allows us
* to simulate loading entity JSON data, adding new messages, and sending HTTP requests for the testing module. It
* is not possible to use the original class, due to Minecraft and Fabric imports and dependencies.
*/
public class EntityTestData {
public String entityId;
public String playerId;
public String currentMessage;
public int currentLineNumber;
public ChatDataManager.ChatStatus status;
public List<ChatDataManager.ChatMessage> previousMessages;
public String characterSheet;
public ChatDataManager.ChatSender sender;
public int friendship; // -3 to 3 (0 = neutral)
public int auto_generated;
public EntityTestData(String entityId, String playerId) {
this.entityId = entityId;
this.playerId = playerId;
this.currentMessage = "";
this.currentLineNumber = 0;
this.previousMessages = new ArrayList<>();
this.characterSheet = "";
this.status = ChatDataManager.ChatStatus.NONE;
this.sender = ChatDataManager.ChatSender.USER;
this.friendship = 0;
this.auto_generated = 0;
}
public String getCharacterProp(String propertyName) {
// Create a case-insensitive regex pattern to match the property name and capture its value
Pattern pattern = Pattern.compile("-?\\s*" + Pattern.quote(propertyName) + ":\\s*(.+)", Pattern.CASE_INSENSITIVE);
Matcher matcher = pattern.matcher(characterSheet);
if (matcher.find()) {
// Return the captured value, trimmed of any excess whitespace
return matcher.group(1).trim().replace("\"", "");
}
return "N/A";
}
// Add a message to the history and update the current message
public void addMessage(String message, ChatDataManager.ChatSender messageSender) {
// Truncate message (prevent crazy long messages... just in case)
String truncatedMessage = message.substring(0, Math.min(message.length(), ChatDataManager.MAX_CHAR_IN_USER_MESSAGE));
// Add message to history
previousMessages.add(new ChatDataManager.ChatMessage(truncatedMessage, messageSender));
// Set new message and reset line number of displayed text
currentMessage = truncatedMessage;
currentLineNumber = 0;
if (messageSender == ChatDataManager.ChatSender.ASSISTANT) {
// Show new generated message
status = ChatDataManager.ChatStatus.DISPLAY;
} else if (messageSender == ChatDataManager.ChatSender.USER) {
// Show pending icon
status = ChatDataManager.ChatStatus.PENDING;
}
sender = messageSender;
}
public Map<String, String> getPlayerContext(Path worldPath, Path playerPath, Path entityPath) {
Gson gson = new Gson();
Type mapType = new TypeToken<Map<String, String>>() {}.getType();
Map<String, String> contextData = new HashMap<>();
try {
// Load world context
String worldContent = Files.readString(worldPath);
Map<String, String> worldContext = gson.fromJson(worldContent, mapType);
contextData.putAll(worldContext);
// Load player context
String playerContent = Files.readString(playerPath);
Map<String, String> playerContext = gson.fromJson(playerContent, mapType);
contextData.putAll(playerContext);
// Load entity context
String entityContent = Files.readString(entityPath);
Map<String, String> entityContext = gson.fromJson(entityContent, mapType);
contextData.putAll(entityContext);
// Read character sheet info
contextData.put("entity_name", getCharacterProp("Name"));
contextData.put("entity_friendship", String.valueOf(this.friendship));
contextData.put("entity_personality", getCharacterProp("Personality"));
contextData.put("entity_speaking_style", getCharacterProp("Speaking Style / Tone"));
contextData.put("entity_likes", getCharacterProp("Likes"));
contextData.put("entity_dislikes", getCharacterProp("Dislikes"));
contextData.put("entity_age", getCharacterProp("Age"));
contextData.put("entity_alignment", getCharacterProp("Alignment"));
contextData.put("entity_class", getCharacterProp("Class"));
contextData.put("entity_skills", getCharacterProp("Skills"));
contextData.put("entity_background", getCharacterProp("Background"));
} catch (IOException e) {
e.printStackTrace();
}
return contextData;
}
}
\ No newline at end of file
package com.owlmaddie;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.owlmaddie.chat.ChatDataManager;
import com.owlmaddie.chat.ChatGPTRequest;
import com.owlmaddie.commands.ConfigurationHandler;
import com.owlmaddie.message.MessageParser;
import com.owlmaddie.message.ParsedMessage;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.fail;
/**
* The {@code PromptTests} class tests a variety of LLM prompts and expected outputs from specific characters
* and personality types. For example, an aggressive character will attack, a nervous character will flee, etc...
*/
public class PromptTests {
static String PROMPT_PATH = "src/main/resources/data/creaturechat/prompts/";
static String RESOURCE_PATH = "src/test/resources/data/creaturechat/";
static String API_KEY = "";
static String API_URL = "";
ConfigurationHandler.Config config = null;
String systemChatContents = null;
List<String> followMessages = Arrays.asList(
"Please follow me",
"Come with me please",
"Quickly, please come this way");
List<String> attackMessages = Arrays.asList(
"<attacked you directly with Stone Axe>",
"<attacked you indirectly with Arrow>",
"DIEEE!");
static Path systemChatPath = Paths.get(PROMPT_PATH, "system-chat");
static Path bravePath = Paths.get(RESOURCE_PATH, "chatdata", "brave-archer.json");
static Path nervousPath = Paths.get(RESOURCE_PATH, "chatdata", "nervous-rogue.json");
static Path entityPigPath = Paths.get(RESOURCE_PATH, "entities", "pig.json");
static Path playerPath = Paths.get(RESOURCE_PATH, "players", "player.json");
static Path worldPath = Paths.get(RESOURCE_PATH, "worlds", "world.json");
Logger LOGGER = LoggerFactory.getLogger("creaturechat");
Gson gson = new GsonBuilder().create();
@BeforeEach
public void setup() {
// Get API key from env var
API_KEY = System.getenv("API_KEY");
API_URL = System.getenv("API_URL");
// Config
config = new ConfigurationHandler.Config();
if (API_KEY != null && !API_KEY.isEmpty()) {
config.setApiKey(API_KEY);
}
if (API_URL != null && !API_URL.isEmpty()) {
config.setUrl(API_URL);
}
// Load system chat prompt
systemChatContents = readFileContents(systemChatPath);
}
@Test
public void followBrave() {
for (String message : followMessages) {
testPromptForBehavior(bravePath, message, "FOLLOW");
}
}
@Test
public void followNervous() {
for (String message : followMessages) {
testPromptForBehavior(nervousPath, message, "FOLLOW");
}
}
@Test
public void attackBrave() {
for (String message : attackMessages) {
testPromptForBehavior(bravePath, message, "ATTACK");
}
}
@Test
public void attackNervous() {
for (String message : attackMessages) {
testPromptForBehavior(nervousPath, message, "FLEE");
}
}
public void testPromptForBehavior(Path chatDataPath, String message, String behavior) {
LOGGER.info("Testing '" + chatDataPath.getFileName() + "' with '" + message + "' and expecting behavior: " + behavior);
try {
// Load entity chat data
String chatDataPathContents = readFileContents(chatDataPath);
EntityTestData entityTestData = gson.fromJson(chatDataPathContents, EntityTestData.class);
// Load context
Map<String, String> contextData = entityTestData.getPlayerContext(worldPath, playerPath, entityPigPath);
assertNotNull(contextData);
// Add test message
entityTestData.addMessage(message, ChatDataManager.ChatSender.USER);
// Get prompt
Path promptPath = Paths.get(PROMPT_PATH, "system-chat");
String promptText = Files.readString(promptPath);
assertNotNull(promptText);
// fetch HTTP response from ChatGPT
CompletableFuture<String> future = ChatGPTRequest.fetchMessageFromChatGPT(config, promptText, contextData, entityTestData.previousMessages, false);
try {
String outputMessage = future.get(60 * 60, TimeUnit.SECONDS);
assertNotNull(outputMessage);
// Chat Message: Check for behavior
ParsedMessage result = MessageParser.parseMessage(outputMessage.replace("\n", " "));
//assertTrue(result.getBehaviors().stream().anyMatch(b -> expectedBehavior.equals(b.getName())));
} catch (TimeoutException e) {
fail("The asynchronous operation timed out.");
} catch (Exception e) {
fail("The asynchronous operation failed: " + e.getMessage());
}
} catch (IOException e) {
e.printStackTrace();
fail("Failed to read the file: " + e.getMessage());
}
LOGGER.info("");
}
public String readFileContents(Path filePath) {
try {
return Files.readString(filePath);
} catch (IOException e) {
e.printStackTrace();
return "";
}
}
}
{
"entityId": "4c3dd3a1-ae77-4243-a6d8-d847a4349367",
"playerId": "187217a2-f3ae-3640-aba9-fce2f2248422",
"currentMessage": "",
"currentLineNumber": 0,
"status": "DISPLAY",
"previousMessages": [
{
"message": "Greetings friend! You've stumbled upon my path. What say you?",
"sender": "ASSISTANT"
}
],
"characterSheet": "- Name: Ivy\n- Personality: Brave, adventurous, and noble\n- Speaking Style / Tone: Confident and determined with a touch of curiosity,n- Class: Archer\n- Skills: Archery, leadership\n- Likes: Nature, challenging shooting competitions, pursuit of justice\n- Dislikes: Injustice, losing a target, being confined\n- Alignment: Good\n- Background: Grew up protecting his birth town from enemies, and now protects all towns.\n- Short Greeting: \"Greetings friend! You've stumbled upon my path. What say you?\"",
"sender": "ASSISTANT",
"friendship": 0,
"auto_generated": 0
}
\ No newline at end of file
{
"entityId": "6fcfb8cb-29a7-4a92-853d-ef63066c5b35",
"playerId": "343a2278-4579-3a59-8a3a-aa2690a75202",
"currentMessage": "",
"currentLineNumber": 0,
"status": "DISPLAY",
"previousMessages": [
{
"message": "H-hello there... I-I hope you're not h-here to cause trouble...",
"sender": "ASSISTANT"
}
],
"characterSheet": "- Name: Jasper\n- Personality: Nervous, anxious, and easily startled\n- Speaking Style / Tone: Stuttering and shaky, always on edge\n- Class: Rogue\n- Skills: Stealth, lock picking\n- Likes: Hiding in shadows, avoiding confrontation, collecting rare items\n- Dislikes: Loud noises, unexpected surprises, being the center of attention\n- Alignment: Lawful Neutral\n- Background: Former thief, escaped a life of crime\n- Short Greeting: \"H-hello there... I-I hope you're not h-here to cause trouble...\"",
"sender": "ASSISTANT",
"friendship": 0,
"auto_generated": 0
}
\ No newline at end of file
{
"entity_health": "10.0/10.0",
"entity_type": "Pig"
}
\ No newline at end of file
{
"player_active_effects": "",
"player_armor_chest": "air",
"player_armor_feet": "air",
"player_armor_head": "air",
"player_armor_legs": "air",
"player_biome": "plains",
"player_health": "8.5/20.0",
"player_held_item": "porkchop",
"player_hunger": "17",
"player_is_creative": "yes",
"player_is_on_ground": "yes",
"player_is_swimming": "no",
"player_language": "English (US)",
"player_name": "Steve"
}
\ No newline at end of file
{
"world_difficulty": "normal",
"world_is_hardcore": "no",
"world_is_raining": "no",
"world_is_thundering": "no",
"world_moon_phase": "Waning Gibbous",
"world_time": "19:01"
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment