Commit 23645ee3 by Jonathan Thomas

Merge branch 'develop' into hidden-icon

parents 28db12bf 237fd0be
...@@ -6,6 +6,13 @@ All notable changes to **CreatureChat** are documented in this file. The format ...@@ -6,6 +6,13 @@ All notable changes to **CreatureChat** are documented in this file. The format
## Unreleased ## Unreleased
### Added
- Rate limiter for LLM unit tests (to prevent rate limit issues from certain providers when running all tests)
- Check friendship direction (+ or -) in LLM unit tests (to verify friendship is output correctly)
### Changed
- Improved LLM unit tests to check for both positive and negative behaviors (e.g. FOLLOW and not LEAD, ATTACK and not FLEE, etc.)
### Fixed ### Fixed
- Changing death message timestamp output to use DEBUG log level - Changing death message timestamp output to use DEBUG log level
......
...@@ -8,6 +8,7 @@ import com.owlmaddie.commands.ConfigurationHandler; ...@@ -8,6 +8,7 @@ import com.owlmaddie.commands.ConfigurationHandler;
import com.owlmaddie.message.MessageParser; import com.owlmaddie.message.MessageParser;
import com.owlmaddie.message.ParsedMessage; import com.owlmaddie.message.ParsedMessage;
import com.owlmaddie.utils.EntityTestData; import com.owlmaddie.utils.EntityTestData;
import com.owlmaddie.utils.RateLimiter;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.slf4j.Logger; import org.slf4j.Logger;
...@@ -38,13 +39,16 @@ public class BehaviorTests { ...@@ -38,13 +39,16 @@ public class BehaviorTests {
static String API_MODEL = ""; static String API_MODEL = "";
String NO_API_KEY = "No API_KEY environment variable has been set."; String NO_API_KEY = "No API_KEY environment variable has been set.";
// Requests per second limit
private static final RateLimiter rateLimiter = new RateLimiter(1);
ConfigurationHandler.Config config = null; ConfigurationHandler.Config config = null;
String systemChatContents = null; String systemChatContents = null;
List<String> followMessages = Arrays.asList( List<String> followMessages = Arrays.asList(
"Please follow me", "Please follow me",
"Come with me please", "Come with me please",
"Quickly, please come this way"); "Quickly, please join me on an adventure");
List<String> leadMessages = Arrays.asList( List<String> leadMessages = Arrays.asList(
"Take me to a secret forrest", "Take me to a secret forrest",
"Where is the strong hold?", "Where is the strong hold?",
...@@ -52,7 +56,7 @@ public class BehaviorTests { ...@@ -52,7 +56,7 @@ public class BehaviorTests {
List<String> attackMessages = Arrays.asList( List<String> attackMessages = Arrays.asList(
"<attacked you directly with Stone Axe>", "<attacked you directly with Stone Axe>",
"<attacked you indirectly with Arrow>", "<attacked you indirectly with Arrow>",
"Fight me now!"); "Fight me now or your city burns!");
List<String> protectMessages = Arrays.asList( List<String> protectMessages = Arrays.asList(
"Please protect me", "Please protect me",
"Please keep me safe friend", "Please keep me safe friend",
...@@ -109,88 +113,93 @@ public class BehaviorTests { ...@@ -109,88 +113,93 @@ public class BehaviorTests {
@Test @Test
public void followBrave() { public void followBrave() {
for (String message : followMessages) { for (String message : followMessages) {
testPromptForBehavior(bravePath, List.of(message), "FOLLOW"); testPromptForBehavior(bravePath, List.of(message), "FOLLOW", "LEAD");
} }
} }
@Test @Test
public void followNervous() { public void followNervous() {
for (String message : followMessages) { for (String message : followMessages) {
testPromptForBehavior(nervousPath, List.of(message), "FOLLOW"); testPromptForBehavior(nervousPath, List.of(message), "FOLLOW", "LEAD");
} }
} }
@Test @Test
public void leadBrave() { public void leadBrave() {
for (String message : leadMessages) { for (String message : leadMessages) {
testPromptForBehavior(bravePath, List.of(message), "LEAD"); testPromptForBehavior(bravePath, List.of(message), "LEAD", "FOLLOW");
} }
} }
@Test @Test
public void leadNervous() { public void leadNervous() {
for (String message : leadMessages) { for (String message : leadMessages) {
testPromptForBehavior(nervousPath, List.of(message), "LEAD"); testPromptForBehavior(nervousPath, List.of(message), "LEAD", "FOLLOW");
} }
} }
@Test @Test
public void unFleeBrave() { public void unFleeBrave() {
for (String message : unFleeMessages) { for (String message : unFleeMessages) {
testPromptForBehavior(bravePath, List.of(message), "UNFLEE"); testPromptForBehavior(bravePath, List.of(message), "UNFLEE", "FOLLOW");
} }
} }
@Test @Test
public void protectBrave() { public void protectBrave() {
for (String message : protectMessages) { for (String message : protectMessages) {
testPromptForBehavior(bravePath, List.of(message), "PROTECT"); testPromptForBehavior(bravePath, List.of(message), "PROTECT", "ATTACK");
} }
} }
@Test @Test
public void protectNervous() { public void protectNervous() {
for (String message : protectMessages) { for (String message : protectMessages) {
testPromptForBehavior(nervousPath, List.of(message), "PROTECT"); testPromptForBehavior(nervousPath, List.of(message), "PROTECT", "ATTACK");
} }
} }
@Test @Test
public void attackBrave() { public void attackBrave() {
for (String message : attackMessages) { for (String message : attackMessages) {
testPromptForBehavior(bravePath, List.of(message), "ATTACK"); testPromptForBehavior(bravePath, List.of(message), "ATTACK", "FLEE");
} }
} }
@Test @Test
public void attackNervous() { public void attackNervous() {
for (String message : attackMessages) { for (String message : attackMessages) {
testPromptForBehavior(nervousPath, List.of(message), "FLEE"); testPromptForBehavior(nervousPath, List.of(message), "FLEE", "ATTACK");
} }
} }
@Test @Test
public void friendshipUpNervous() { public void friendshipUpNervous() {
ParsedMessage result = testPromptForBehavior(nervousPath, friendshipUpMessages, "FRIENDSHIP"); ParsedMessage result = testPromptForBehavior(nervousPath, friendshipUpMessages, "FRIENDSHIP+", null);
assertTrue(result.getBehaviors().stream().anyMatch(b -> "FRIENDSHIP".equals(b.getName()) && b.getArgument() > 0)); assertTrue(result.getBehaviors().stream().anyMatch(b -> "FRIENDSHIP".equals(b.getName()) && b.getArgument() > 0));
} }
@Test @Test
public void friendshipUpBrave() { public void friendshipUpBrave() {
ParsedMessage result = testPromptForBehavior(bravePath, friendshipUpMessages, "FRIENDSHIP"); ParsedMessage result = testPromptForBehavior(bravePath, friendshipUpMessages, "FRIENDSHIP+", null);
assertTrue(result.getBehaviors().stream().anyMatch(b -> "FRIENDSHIP".equals(b.getName()) && b.getArgument() > 0)); assertTrue(result.getBehaviors().stream().anyMatch(b -> "FRIENDSHIP".equals(b.getName()) && b.getArgument() > 0));
} }
@Test @Test
public void friendshipDownNervous() { public void friendshipDownNervous() {
for (String message : friendshipDownMessages) { for (String message : friendshipDownMessages) {
ParsedMessage result = testPromptForBehavior(nervousPath, List.of(message), "FRIENDSHIP"); ParsedMessage result = testPromptForBehavior(nervousPath, List.of(message), "FRIENDSHIP-", null);
assertTrue(result.getBehaviors().stream().anyMatch(b -> "FRIENDSHIP".equals(b.getName()) && b.getArgument() < 0)); assertTrue(result.getBehaviors().stream().anyMatch(b -> "FRIENDSHIP".equals(b.getName()) && b.getArgument() < 0));
} }
} }
public ParsedMessage testPromptForBehavior(Path chatDataPath, List<String> messages, String behavior) { public ParsedMessage testPromptForBehavior(Path chatDataPath, List<String> messages, String goodBehavior, String badBehavior) {
LOGGER.info("Testing '" + chatDataPath.getFileName() + "' with '" + messages.toString() + "' and expecting behavior: " + behavior); LOGGER.info("Testing '" + chatDataPath.getFileName() + "' with '" + messages.toString() +
"' expecting behavior: " + goodBehavior + " and avoid: " + badBehavior);
try {
// Enforce rate limit
rateLimiter.acquire();
try { try {
// Load entity chat data // Load entity chat data
...@@ -211,16 +220,31 @@ public class BehaviorTests { ...@@ -211,16 +220,31 @@ public class BehaviorTests {
String promptText = Files.readString(promptPath); String promptText = Files.readString(promptPath);
assertNotNull(promptText); assertNotNull(promptText);
// fetch HTTP response from ChatGPT // Fetch HTTP response from ChatGPT
CompletableFuture<String> future = ChatGPTRequest.fetchMessageFromChatGPT(config, promptText, contextData, entityTestData.previousMessages, false); CompletableFuture<String> future = ChatGPTRequest.fetchMessageFromChatGPT(
config, promptText, contextData, entityTestData.previousMessages, false);
try { try {
String outputMessage = future.get(60 * 60, TimeUnit.SECONDS); String outputMessage = future.get(60 * 60, TimeUnit.SECONDS);
assertNotNull(outputMessage); assertNotNull(outputMessage);
// Chat Message: Check for behavior // Chat Message: Check for behaviors
ParsedMessage result = MessageParser.parseMessage(outputMessage.replace("\n", " ")); ParsedMessage result = MessageParser.parseMessage(outputMessage.replace("\n", " "));
assertTrue(result.getBehaviors().stream().anyMatch(b -> behavior.equals(b.getName())));
// Check for the presence of good behavior
if (goodBehavior != null && goodBehavior.contains("FRIENDSHIP")) {
boolean isPositive = goodBehavior.equals("FRIENDSHIP+");
assertTrue(result.getBehaviors().stream().anyMatch(b -> "FRIENDSHIP".equals(b.getName()) &&
((isPositive && b.getArgument() > 0) || (!isPositive && b.getArgument() < 0))));
} else {
assertTrue(result.getBehaviors().stream().anyMatch(b -> goodBehavior.equals(b.getName())));
}
// Check for the absence of bad behavior if badBehavior is not empty
if (badBehavior != null && !badBehavior.isEmpty()) {
assertTrue(result.getBehaviors().stream().noneMatch(b -> badBehavior.equals(b.getName())));
}
return result; return result;
} catch (TimeoutException e) { } catch (TimeoutException e) {
...@@ -234,6 +258,10 @@ public class BehaviorTests { ...@@ -234,6 +258,10 @@ public class BehaviorTests {
fail("Failed to read the file: " + e.getMessage()); fail("Failed to read the file: " + e.getMessage());
} }
LOGGER.info(""); LOGGER.info("");
} catch (InterruptedException e) {
LOGGER.warn("Rate limit enforcement interrupted: " + e.getMessage());
}
return null; return null;
} }
......
package com.owlmaddie.utils;
import java.util.concurrent.Semaphore;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
/**
 * The {@code RateLimiter} class is used to slow down LLM unit tests so we don't hit any rate limits accidentally.
 *
 * <p>Permits are handed out from a {@link Semaphore} that starts full and is topped back up to
 * {@code requestsPerSecond} once per second by a background daemon thread. Callers block in
 * {@link #acquire()} while no permit is available.
 */
public class RateLimiter {
    private final Semaphore semaphore;

    /**
     * Creates a limiter and starts its once-per-second refill task.
     *
     * @param requestsPerSecond maximum number of permits available in each one-second window
     * @throws IllegalArgumentException if {@code requestsPerSecond} is not positive (a limiter with
     *                                  no permits would make every {@link #acquire()} block forever)
     */
    public RateLimiter(int requestsPerSecond) {
        if (requestsPerSecond <= 0) {
            throw new IllegalArgumentException(
                    "requestsPerSecond must be positive, got " + requestsPerSecond);
        }
        semaphore = new Semaphore(requestsPerSecond);

        // Daemon thread: the refill task must never keep the JVM alive once the tests are done.
        Executors.newScheduledThreadPool(1, task -> {
            Thread refillThread = new Thread(task, "rate-limiter-refill");
            refillThread.setDaemon(true);
            return refillThread;
        }).scheduleAtFixedRate(() -> {
            // Top up to the limit. Never pass a negative count to release(): an exception here
            // would cancel all further refills and starve every caller.
            int missing = requestsPerSecond - semaphore.availablePermits();
            if (missing > 0) {
                semaphore.release(missing);
            }
        }, 1, 1, TimeUnit.SECONDS); // first refill after a full second; the semaphore starts full
    }

    /**
     * Takes one permit, blocking until one is available.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void acquire() throws InterruptedException {
        semaphore.acquire();
    }
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment