Commit 237fd0be by Jonathan Thomas

Improving LLM unit tests, adding Rate Limiter, and checking for + and - friendship.

parent a119149f
Pipeline #13250 passed with stages
in 2 minutes 8 seconds
......@@ -6,6 +6,13 @@ All notable changes to **CreatureChat** are documented in this file. The format
## Unreleased
### Added
- Rate limiter for LLM unit tests (to prevent rate limit issues from certain providers when running all tests)
- Check friendship direction (+ or -) in LLM unit tests (to verify friendship is output correctly)
### Changed
- Improved LLM unit tests to check for both positive and negative behaviors (e.g. FOLLOW and not LEAD, ATTACK and not FLEE, etc.)
### Fixed
- Changing death message timestamp output to use DEBUG log level
......
......@@ -8,6 +8,7 @@ import com.owlmaddie.commands.ConfigurationHandler;
import com.owlmaddie.message.MessageParser;
import com.owlmaddie.message.ParsedMessage;
import com.owlmaddie.utils.EntityTestData;
import com.owlmaddie.utils.RateLimiter;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
......@@ -38,13 +39,16 @@ public class BehaviorTests {
static String API_MODEL = "";
String NO_API_KEY = "No API_KEY environment variable has been set.";
// Requests per second limit
private static final RateLimiter rateLimiter = new RateLimiter(1);
ConfigurationHandler.Config config = null;
String systemChatContents = null;
List<String> followMessages = Arrays.asList(
"Please follow me",
"Come with me please",
"Quickly, please come this way");
"Quickly, please join me on an adventure");
List<String> leadMessages = Arrays.asList(
"Take me to a secret forrest",
"Where is the strong hold?",
......@@ -52,7 +56,7 @@ public class BehaviorTests {
List<String> attackMessages = Arrays.asList(
"<attacked you directly with Stone Axe>",
"<attacked you indirectly with Arrow>",
"Fight me now!");
"Fight me now or your city burns!");
List<String> protectMessages = Arrays.asList(
"Please protect me",
"Please keep me safe friend",
......@@ -109,131 +113,155 @@ public class BehaviorTests {
@Test
public void followBrave() {
for (String message : followMessages) {
testPromptForBehavior(bravePath, List.of(message), "FOLLOW");
testPromptForBehavior(bravePath, List.of(message), "FOLLOW", "LEAD");
}
}
@Test
public void followNervous() {
for (String message : followMessages) {
testPromptForBehavior(nervousPath, List.of(message), "FOLLOW");
testPromptForBehavior(nervousPath, List.of(message), "FOLLOW", "LEAD");
}
}
@Test
public void leadBrave() {
for (String message : leadMessages) {
testPromptForBehavior(bravePath, List.of(message), "LEAD");
testPromptForBehavior(bravePath, List.of(message), "LEAD", "FOLLOW");
}
}
@Test
public void leadNervous() {
for (String message : leadMessages) {
testPromptForBehavior(nervousPath, List.of(message), "LEAD");
testPromptForBehavior(nervousPath, List.of(message), "LEAD", "FOLLOW");
}
}
@Test
public void unFleeBrave() {
for (String message : unFleeMessages) {
testPromptForBehavior(bravePath, List.of(message), "UNFLEE");
testPromptForBehavior(bravePath, List.of(message), "UNFLEE", "FOLLOW");
}
}
@Test
public void protectBrave() {
for (String message : protectMessages) {
testPromptForBehavior(bravePath, List.of(message), "PROTECT");
testPromptForBehavior(bravePath, List.of(message), "PROTECT", "ATTACK");
}
}
@Test
public void protectNervous() {
for (String message : protectMessages) {
testPromptForBehavior(nervousPath, List.of(message), "PROTECT");
testPromptForBehavior(nervousPath, List.of(message), "PROTECT", "ATTACK");
}
}
@Test
public void attackBrave() {
for (String message : attackMessages) {
testPromptForBehavior(bravePath, List.of(message), "ATTACK");
testPromptForBehavior(bravePath, List.of(message), "ATTACK", "FLEE");
}
}
@Test
public void attackNervous() {
for (String message : attackMessages) {
testPromptForBehavior(nervousPath, List.of(message), "FLEE");
testPromptForBehavior(nervousPath, List.of(message), "FLEE", "ATTACK");
}
}
@Test
public void friendshipUpNervous() {
ParsedMessage result = testPromptForBehavior(nervousPath, friendshipUpMessages, "FRIENDSHIP");
ParsedMessage result = testPromptForBehavior(nervousPath, friendshipUpMessages, "FRIENDSHIP+", null);
assertTrue(result.getBehaviors().stream().anyMatch(b -> "FRIENDSHIP".equals(b.getName()) && b.getArgument() > 0));
}
@Test
public void friendshipUpBrave() {
ParsedMessage result = testPromptForBehavior(bravePath, friendshipUpMessages, "FRIENDSHIP");
ParsedMessage result = testPromptForBehavior(bravePath, friendshipUpMessages, "FRIENDSHIP+", null);
assertTrue(result.getBehaviors().stream().anyMatch(b -> "FRIENDSHIP".equals(b.getName()) && b.getArgument() > 0));
}
@Test
public void friendshipDownNervous() {
for (String message : friendshipDownMessages) {
ParsedMessage result = testPromptForBehavior(nervousPath, List.of(message), "FRIENDSHIP");
ParsedMessage result = testPromptForBehavior(nervousPath, List.of(message), "FRIENDSHIP-", null);
assertTrue(result.getBehaviors().stream().anyMatch(b -> "FRIENDSHIP".equals(b.getName()) && b.getArgument() < 0));
}
}
public ParsedMessage testPromptForBehavior(Path chatDataPath, List<String> messages, String behavior) {
LOGGER.info("Testing '" + chatDataPath.getFileName() + "' with '" + messages.toString() + "' and expecting behavior: " + behavior);
public ParsedMessage testPromptForBehavior(Path chatDataPath, List<String> messages, String goodBehavior, String badBehavior) {
LOGGER.info("Testing '" + chatDataPath.getFileName() + "' with '" + messages.toString() +
"' expecting behavior: " + goodBehavior + " and avoid: " + badBehavior);
try {
// Load entity chat data
String chatDataPathContents = readFileContents(chatDataPath);
EntityTestData entityTestData = gson.fromJson(chatDataPathContents, EntityTestData.class);
// Load context
Map<String, String> contextData = entityTestData.getPlayerContext(worldPath, playerPath, entityPigPath);
assertNotNull(contextData);
// Add test message
for (String message : messages) {
entityTestData.addMessage(message, ChatDataManager.ChatSender.USER, "TestPlayer1");
}
// Get prompt
Path promptPath = Paths.get(PROMPT_PATH, "system-chat");
String promptText = Files.readString(promptPath);
assertNotNull(promptText);
// fetch HTTP response from ChatGPT
CompletableFuture<String> future = ChatGPTRequest.fetchMessageFromChatGPT(config, promptText, contextData, entityTestData.previousMessages, false);
// Enforce rate limit
rateLimiter.acquire();
try {
String outputMessage = future.get(60 * 60, TimeUnit.SECONDS);
assertNotNull(outputMessage);
// Chat Message: Check for behavior
ParsedMessage result = MessageParser.parseMessage(outputMessage.replace("\n", " "));
assertTrue(result.getBehaviors().stream().anyMatch(b -> behavior.equals(b.getName())));
return result;
} catch (TimeoutException e) {
fail("The asynchronous operation timed out.");
} catch (Exception e) {
fail("The asynchronous operation failed: " + e.getMessage());
// Load entity chat data
String chatDataPathContents = readFileContents(chatDataPath);
EntityTestData entityTestData = gson.fromJson(chatDataPathContents, EntityTestData.class);
// Load context
Map<String, String> contextData = entityTestData.getPlayerContext(worldPath, playerPath, entityPigPath);
assertNotNull(contextData);
// Add test message
for (String message : messages) {
entityTestData.addMessage(message, ChatDataManager.ChatSender.USER, "TestPlayer1");
}
// Get prompt
Path promptPath = Paths.get(PROMPT_PATH, "system-chat");
String promptText = Files.readString(promptPath);
assertNotNull(promptText);
// Fetch HTTP response from ChatGPT
CompletableFuture<String> future = ChatGPTRequest.fetchMessageFromChatGPT(
config, promptText, contextData, entityTestData.previousMessages, false);
try {
String outputMessage = future.get(60 * 60, TimeUnit.SECONDS);
assertNotNull(outputMessage);
// Chat Message: Check for behaviors
ParsedMessage result = MessageParser.parseMessage(outputMessage.replace("\n", " "));
// Check for the presence of good behavior
if (goodBehavior != null && goodBehavior.contains("FRIENDSHIP")) {
boolean isPositive = goodBehavior.equals("FRIENDSHIP+");
assertTrue(result.getBehaviors().stream().anyMatch(b -> "FRIENDSHIP".equals(b.getName()) &&
((isPositive && b.getArgument() > 0) || (!isPositive && b.getArgument() < 0))));
} else {
assertTrue(result.getBehaviors().stream().anyMatch(b -> goodBehavior.equals(b.getName())));
}
// Check for the absence of bad behavior if badBehavior is not empty
if (badBehavior != null && !badBehavior.isEmpty()) {
assertTrue(result.getBehaviors().stream().noneMatch(b -> badBehavior.equals(b.getName())));
}
return result;
} catch (TimeoutException e) {
fail("The asynchronous operation timed out.");
} catch (Exception e) {
fail("The asynchronous operation failed: " + e.getMessage());
}
} catch (IOException e) {
e.printStackTrace();
fail("Failed to read the file: " + e.getMessage());
}
LOGGER.info("");
} catch (IOException e) {
e.printStackTrace();
fail("Failed to read the file: " + e.getMessage());
} catch (InterruptedException e) {
LOGGER.warn("Rate limit enforcement interrupted: " + e.getMessage());
}
LOGGER.info("");
return null;
}
......
package com.owlmaddie.utils;
import java.util.concurrent.Semaphore;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
/**
 * The {@code RateLimiter} class is used to slow down LLM unit tests so we don't hit any rate limits accidentally.
 *
 * <p>Permits are handed out through a {@link Semaphore}. A background task runs once per second and
 * tops the semaphore back up to {@code requestsPerSecond} permits, so no more than that many
 * {@link #acquire()} calls can succeed between two consecutive refills; additional callers block
 * until a later refill makes a permit available.
 *
 * <p>The refill task runs on a daemon thread, so an instance that is never explicitly stopped
 * (e.g. one held in a {@code static final} field of a test class) does not keep the JVM alive.
 */
public class RateLimiter {
    private final Semaphore semaphore;

    /**
     * Creates a limiter and immediately starts its once-per-second refill task.
     *
     * @param requestsPerSecond maximum number of permits available per one-second refill window;
     *                          must be positive
     * @throws IllegalArgumentException if {@code requestsPerSecond} is zero or negative, because
     *                                  such a limiter would never grant a permit
     */
    public RateLimiter(int requestsPerSecond) {
        if (requestsPerSecond <= 0) {
            throw new IllegalArgumentException(
                    "requestsPerSecond must be positive, got: " + requestsPerSecond);
        }
        semaphore = new Semaphore(requestsPerSecond);
        // Daemon worker: the scheduler is never shut down, so a non-daemon thread would leak
        // and prevent the JVM from exiting once the tests are done.
        Executors.newScheduledThreadPool(1, task -> {
            Thread refillThread = new Thread(task, "rate-limiter-refill");
            refillThread.setDaemon(true);
            return refillThread;
        }).scheduleAtFixedRate(() -> {
            // Only top up the missing permits. Guard against a non-positive difference because an
            // exception thrown here would silently cancel every future refill.
            int missingPermits = requestsPerSecond - semaphore.availablePermits();
            if (missingPermits > 0) {
                semaphore.release(missingPermits);
            }
        }, 0, 1, TimeUnit.SECONDS);
    }

    /**
     * Takes one permit, blocking until one is available.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void acquire() throws InterruptedException {
        semaphore.acquire();
    }
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment