Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
CreatureChat
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Jobs
Commits
Open sidebar
Public
CreatureChat
Commits
237fd0be
Commit
237fd0be
authored
Jan 03, 2025
by
Jonathan Thomas
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Improving LLM unit tests, adding Rate Limiter, and checking for + and - friendship.
parent
a119149f
Pipeline
#13250
passed with stages
in 2 minutes 8 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
111 additions
and
52 deletions
+111
-52
CHANGELOG.md
CHANGELOG.md
+7
-0
BehaviorTests.java
src/test/java/com/owlmaddie/tests/BehaviorTests.java
+80
-52
RateLimiter.java
src/test/java/com/owlmaddie/utils/RateLimiter.java
+24
-0
No files found.
CHANGELOG.md
View file @
237fd0be
...
...
@@ -6,6 +6,13 @@ All notable changes to **CreatureChat** are documented in this file. The format
## Unreleased
### Added
-
Rate limiter for LLM unit tests (to prevent rate limit issues from certain providers when running all tests)
-
Check friendship direction (+ or -) in LLM unit tests (to verify friendship is output correctly)
### Changed
-
Improved LLM unit tests to check for both positive and negative behaviors (e.g. FOLLOW and not LEAD, ATTACK and not FLEE, etc.)
### Fixed
-
Changing death message timestamp output to use DEBUG log level
...
...
src/test/java/com/owlmaddie/tests/BehaviorTests.java
View file @
237fd0be
...
...
@@ -8,6 +8,7 @@ import com.owlmaddie.commands.ConfigurationHandler;
import
com.owlmaddie.message.MessageParser
;
import
com.owlmaddie.message.ParsedMessage
;
import
com.owlmaddie.utils.EntityTestData
;
import
com.owlmaddie.utils.RateLimiter
;
import
org.junit.jupiter.api.BeforeEach
;
import
org.junit.jupiter.api.Test
;
import
org.slf4j.Logger
;
...
...
@@ -38,13 +39,16 @@ public class BehaviorTests {
static
String
API_MODEL
=
""
;
String
NO_API_KEY
=
"No API_KEY environment variable has been set."
;
// Requests per second limit
private
static
final
RateLimiter
rateLimiter
=
new
RateLimiter
(
1
);
ConfigurationHandler
.
Config
config
=
null
;
String
systemChatContents
=
null
;
List
<
String
>
followMessages
=
Arrays
.
asList
(
"Please follow me"
,
"Come with me please"
,
"Quickly, please
come this way
"
);
"Quickly, please
join me on an adventure
"
);
List
<
String
>
leadMessages
=
Arrays
.
asList
(
"Take me to a secret forrest"
,
"Where is the strong hold?"
,
...
...
@@ -52,7 +56,7 @@ public class BehaviorTests {
List
<
String
>
attackMessages
=
Arrays
.
asList
(
"<attacked you directly with Stone Axe>"
,
"<attacked you indirectly with Arrow>"
,
"Fight me now!"
);
"Fight me now
or your city burns
!"
);
List
<
String
>
protectMessages
=
Arrays
.
asList
(
"Please protect me"
,
"Please keep me safe friend"
,
...
...
@@ -109,131 +113,155 @@ public class BehaviorTests {
@Test
public
void
followBrave
()
{
for
(
String
message
:
followMessages
)
{
testPromptForBehavior
(
bravePath
,
List
.
of
(
message
),
"FOLLOW"
);
testPromptForBehavior
(
bravePath
,
List
.
of
(
message
),
"FOLLOW"
,
"LEAD"
);
}
}
@Test
public
void
followNervous
()
{
for
(
String
message
:
followMessages
)
{
testPromptForBehavior
(
nervousPath
,
List
.
of
(
message
),
"FOLLOW"
);
testPromptForBehavior
(
nervousPath
,
List
.
of
(
message
),
"FOLLOW"
,
"LEAD"
);
}
}
@Test
public
void
leadBrave
()
{
for
(
String
message
:
leadMessages
)
{
testPromptForBehavior
(
bravePath
,
List
.
of
(
message
),
"LEAD"
);
testPromptForBehavior
(
bravePath
,
List
.
of
(
message
),
"LEAD"
,
"FOLLOW"
);
}
}
@Test
public
void
leadNervous
()
{
for
(
String
message
:
leadMessages
)
{
testPromptForBehavior
(
nervousPath
,
List
.
of
(
message
),
"LEAD"
);
testPromptForBehavior
(
nervousPath
,
List
.
of
(
message
),
"LEAD"
,
"FOLLOW"
);
}
}
@Test
public
void
unFleeBrave
()
{
for
(
String
message
:
unFleeMessages
)
{
testPromptForBehavior
(
bravePath
,
List
.
of
(
message
),
"UNFLEE"
);
testPromptForBehavior
(
bravePath
,
List
.
of
(
message
),
"UNFLEE"
,
"FOLLOW"
);
}
}
@Test
public
void
protectBrave
()
{
for
(
String
message
:
protectMessages
)
{
testPromptForBehavior
(
bravePath
,
List
.
of
(
message
),
"PROTECT"
);
testPromptForBehavior
(
bravePath
,
List
.
of
(
message
),
"PROTECT"
,
"ATTACK"
);
}
}
@Test
public
void
protectNervous
()
{
for
(
String
message
:
protectMessages
)
{
testPromptForBehavior
(
nervousPath
,
List
.
of
(
message
),
"PROTECT"
);
testPromptForBehavior
(
nervousPath
,
List
.
of
(
message
),
"PROTECT"
,
"ATTACK"
);
}
}
@Test
public
void
attackBrave
()
{
for
(
String
message
:
attackMessages
)
{
testPromptForBehavior
(
bravePath
,
List
.
of
(
message
),
"ATTACK"
);
testPromptForBehavior
(
bravePath
,
List
.
of
(
message
),
"ATTACK"
,
"FLEE"
);
}
}
@Test
public
void
attackNervous
()
{
for
(
String
message
:
attackMessages
)
{
testPromptForBehavior
(
nervousPath
,
List
.
of
(
message
),
"FLEE"
);
testPromptForBehavior
(
nervousPath
,
List
.
of
(
message
),
"FLEE"
,
"ATTACK"
);
}
}
@Test
public
void
friendshipUpNervous
()
{
ParsedMessage
result
=
testPromptForBehavior
(
nervousPath
,
friendshipUpMessages
,
"FRIENDSHIP
"
);
ParsedMessage
result
=
testPromptForBehavior
(
nervousPath
,
friendshipUpMessages
,
"FRIENDSHIP
+"
,
null
);
assertTrue
(
result
.
getBehaviors
().
stream
().
anyMatch
(
b
->
"FRIENDSHIP"
.
equals
(
b
.
getName
())
&&
b
.
getArgument
()
>
0
));
}
@Test
public
void
friendshipUpBrave
()
{
ParsedMessage
result
=
testPromptForBehavior
(
bravePath
,
friendshipUpMessages
,
"FRIENDSHIP
"
);
ParsedMessage
result
=
testPromptForBehavior
(
bravePath
,
friendshipUpMessages
,
"FRIENDSHIP
+"
,
null
);
assertTrue
(
result
.
getBehaviors
().
stream
().
anyMatch
(
b
->
"FRIENDSHIP"
.
equals
(
b
.
getName
())
&&
b
.
getArgument
()
>
0
));
}
@Test
public
void
friendshipDownNervous
()
{
for
(
String
message
:
friendshipDownMessages
)
{
ParsedMessage
result
=
testPromptForBehavior
(
nervousPath
,
List
.
of
(
message
),
"FRIENDSHIP
"
);
ParsedMessage
result
=
testPromptForBehavior
(
nervousPath
,
List
.
of
(
message
),
"FRIENDSHIP
-"
,
null
);
assertTrue
(
result
.
getBehaviors
().
stream
().
anyMatch
(
b
->
"FRIENDSHIP"
.
equals
(
b
.
getName
())
&&
b
.
getArgument
()
<
0
));
}
}
public
ParsedMessage
testPromptForBehavior
(
Path
chatDataPath
,
List
<
String
>
messages
,
String
behavior
)
{
LOGGER
.
info
(
"Testing '"
+
chatDataPath
.
getFileName
()
+
"' with '"
+
messages
.
toString
()
+
"' and expecting behavior: "
+
behavior
);
public
ParsedMessage
testPromptForBehavior
(
Path
chatDataPath
,
List
<
String
>
messages
,
String
goodBehavior
,
String
badBehavior
)
{
LOGGER
.
info
(
"Testing '"
+
chatDataPath
.
getFileName
()
+
"' with '"
+
messages
.
toString
()
+
"' expecting behavior: "
+
goodBehavior
+
" and avoid: "
+
badBehavior
);
try
{
// Load entity chat data
String
chatDataPathContents
=
readFileContents
(
chatDataPath
);
EntityTestData
entityTestData
=
gson
.
fromJson
(
chatDataPathContents
,
EntityTestData
.
class
);
// Load context
Map
<
String
,
String
>
contextData
=
entityTestData
.
getPlayerContext
(
worldPath
,
playerPath
,
entityPigPath
);
assertNotNull
(
contextData
);
// Add test message
for
(
String
message
:
messages
)
{
entityTestData
.
addMessage
(
message
,
ChatDataManager
.
ChatSender
.
USER
,
"TestPlayer1"
);
}
// Get prompt
Path
promptPath
=
Paths
.
get
(
PROMPT_PATH
,
"system-chat"
);
String
promptText
=
Files
.
readString
(
promptPath
);
assertNotNull
(
promptText
);
// fetch HTTP response from ChatGPT
CompletableFuture
<
String
>
future
=
ChatGPTRequest
.
fetchMessageFromChatGPT
(
config
,
promptText
,
contextData
,
entityTestData
.
previousMessages
,
false
);
// Enforce rate limit
rateLimiter
.
acquire
();
try
{
String
outputMessage
=
future
.
get
(
60
*
60
,
TimeUnit
.
SECONDS
);
assertNotNull
(
outputMessage
);
// Chat Message: Check for behavior
ParsedMessage
result
=
MessageParser
.
parseMessage
(
outputMessage
.
replace
(
"\n"
,
" "
));
assertTrue
(
result
.
getBehaviors
().
stream
().
anyMatch
(
b
->
behavior
.
equals
(
b
.
getName
())));
return
result
;
}
catch
(
TimeoutException
e
)
{
fail
(
"The asynchronous operation timed out."
);
}
catch
(
Exception
e
)
{
fail
(
"The asynchronous operation failed: "
+
e
.
getMessage
());
// Load entity chat data
String
chatDataPathContents
=
readFileContents
(
chatDataPath
);
EntityTestData
entityTestData
=
gson
.
fromJson
(
chatDataPathContents
,
EntityTestData
.
class
);
// Load context
Map
<
String
,
String
>
contextData
=
entityTestData
.
getPlayerContext
(
worldPath
,
playerPath
,
entityPigPath
);
assertNotNull
(
contextData
);
// Add test message
for
(
String
message
:
messages
)
{
entityTestData
.
addMessage
(
message
,
ChatDataManager
.
ChatSender
.
USER
,
"TestPlayer1"
);
}
// Get prompt
Path
promptPath
=
Paths
.
get
(
PROMPT_PATH
,
"system-chat"
);
String
promptText
=
Files
.
readString
(
promptPath
);
assertNotNull
(
promptText
);
// Fetch HTTP response from ChatGPT
CompletableFuture
<
String
>
future
=
ChatGPTRequest
.
fetchMessageFromChatGPT
(
config
,
promptText
,
contextData
,
entityTestData
.
previousMessages
,
false
);
try
{
String
outputMessage
=
future
.
get
(
60
*
60
,
TimeUnit
.
SECONDS
);
assertNotNull
(
outputMessage
);
// Chat Message: Check for behaviors
ParsedMessage
result
=
MessageParser
.
parseMessage
(
outputMessage
.
replace
(
"\n"
,
" "
));
// Check for the presence of good behavior
if
(
goodBehavior
!=
null
&&
goodBehavior
.
contains
(
"FRIENDSHIP"
))
{
boolean
isPositive
=
goodBehavior
.
equals
(
"FRIENDSHIP+"
);
assertTrue
(
result
.
getBehaviors
().
stream
().
anyMatch
(
b
->
"FRIENDSHIP"
.
equals
(
b
.
getName
())
&&
((
isPositive
&&
b
.
getArgument
()
>
0
)
||
(!
isPositive
&&
b
.
getArgument
()
<
0
))));
}
else
{
assertTrue
(
result
.
getBehaviors
().
stream
().
anyMatch
(
b
->
goodBehavior
.
equals
(
b
.
getName
())));
}
// Check for the absence of bad behavior if badBehavior is not empty
if
(
badBehavior
!=
null
&&
!
badBehavior
.
isEmpty
())
{
assertTrue
(
result
.
getBehaviors
().
stream
().
noneMatch
(
b
->
badBehavior
.
equals
(
b
.
getName
())));
}
return
result
;
}
catch
(
TimeoutException
e
)
{
fail
(
"The asynchronous operation timed out."
);
}
catch
(
Exception
e
)
{
fail
(
"The asynchronous operation failed: "
+
e
.
getMessage
());
}
}
catch
(
IOException
e
)
{
e
.
printStackTrace
();
fail
(
"Failed to read the file: "
+
e
.
getMessage
());
}
LOGGER
.
info
(
""
);
}
catch
(
IOException
e
)
{
e
.
printStackTrace
();
fail
(
"Failed to read the file: "
+
e
.
getMessage
());
}
catch
(
InterruptedException
e
)
{
LOGGER
.
warn
(
"Rate limit enforcement interrupted: "
+
e
.
getMessage
());
}
LOGGER
.
info
(
""
);
return
null
;
}
...
...
src/test/java/com/owlmaddie/utils/RateLimiter.java
0 → 100644
View file @
237fd0be
package
com
.
owlmaddie
.
utils
;
import
java.util.concurrent.Semaphore
;
import
java.util.concurrent.Executors
;
import
java.util.concurrent.TimeUnit
;
/**
 * The {@code RateLimiter} class is used to slow down LLM unit tests so we don't hit any rate limits accidentally.
 *
 * <p>It works as a fixed-window limiter: up to {@code requestsPerSecond} calls to
 * {@link #acquire()} return immediately, and further callers block until a background
 * task tops the permits back up, which happens once per second.
 */
public class RateLimiter {
    private final Semaphore semaphore;

    /**
     * Creates a limiter that starts with a full set of permits.
     *
     * @param requestsPerSecond maximum number of permits handed out per one-second window;
     *                          must be positive
     * @throws IllegalArgumentException if {@code requestsPerSecond} is zero or negative
     *                                  (such a limiter could never hand out a permit)
     */
    public RateLimiter(int requestsPerSecond) {
        if (requestsPerSecond <= 0) {
            throw new IllegalArgumentException(
                    "requestsPerSecond must be positive, got " + requestsPerSecond);
        }
        semaphore = new Semaphore(requestsPerSecond);

        // The refill thread is a daemon so an abandoned limiter (e.g. a static field in a
        // test class) never keeps the JVM alive after the tests have finished.
        Executors.newScheduledThreadPool(1, task -> {
            Thread refillThread = new Thread(task, "rate-limiter-refill");
            refillThread.setDaemon(true);
            return refillThread;
        }).scheduleAtFixedRate(() -> {
            // Top up to the configured maximum; never release a non-positive amount.
            int missing = requestsPerSecond - semaphore.availablePermits();
            if (missing > 0) {
                semaphore.release(missing);
            }
        // First refill happens after one full second: the semaphore already starts full,
        // so refilling at time 0 could let a double burst through at start-up.
        }, 1, 1, TimeUnit.SECONDS);
    }

    /**
     * Blocks until a permit is available in the current one-second window, then consumes it.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void acquire() throws InterruptedException {
        semaphore.acquire();
    }
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment