Investigation
-
Overview of slippi-ai
-
an ML system designed to train AI agents to play SSB Melee competitively
-
the system implements a two-stage training pipeline
-
that begins with imitation learning from human gameplay data
-
and progresses to RL through self-play
for detailed information on subsystems there's System Architecture, Training Systems, Evaluation Systems, and Data Processing; we'll cover these later…
-
Project purpose and Scope
-
its predecessor relied purely on deep reinforcement learning
-
this system benefits from behavioral cloning from Slippi replay files to create agents that exhibit more human-like gameplay patterns before refining their strategies through self-play
the raw replay data => AI agents pipeline includes data pre-processing, NN training, evaluation frameworks, and interactive applications such as netplay integration and Twitch bot functionality
-
Training Pipeline overview
Stage 1: imitation learning
-
the first stage uses behavioral cloning to train agents on human gameplay data extracted from Slippi replay files
-
orchestrated in `scripts/train.py` and utilizes the `trainlib` module to implement supervised learning on state-action pairs derived from professional and high-level amateur gameplay
-
Stage 2: reinforcement learning
-
the second stage takes the imitation-trained policy and refines it through self-play using proximal policy optimization (PPO)
-
handled by `slippi_ai/rl/run.py` for single-agent training and `slippi_ai/rl/train_two.py` for simultaneous two-agent training scenarios
-
Tech Stack and Dependencies
- DL: TF Probability for NN training and inference
- NN: DeepMind Sonnet for high-level network architecture
- DATA: Pandas + PyArrow + Parquet for replay parsing and dataset manipulation
- TELEMETRY: Wandb for training metrics and model versioning
- DISTRIBUTED: Ray for scalable evaluation and training
- EMULATOR: libmelee for Dolphin emulator communication
- CONFIG: fancyflags for CLI-arg mgmt
Key Entry Points
Training
- scripts/train.py imitation learning from replay data
- slippi_ai/rl/run.py single-agent reinforcement learning
- slippi_ai/rl/train_two.py two-agent simultaneous training
Evaluation
- scripts/eval_two.py local agent evaluation and human play
- scripts/run_evaluator.py batch evaluation with statistical analysis
- scripts/netplay.py online play
Data processing
- slippi_db/parse_local.py
A Walk through the code
Stepping through the processes
parsing local slippi replays
-
Download the compressed ranked replays
-
75-125 GB ≈ 120k-170k replays
-
I sampled 3,300 replays for an initial sanity test
-
-
Now we step through parse_local
-
expects to be supplied an organized "root" dir:
-
Root/ includes:
- Raw/, raw.json, Parsed/, parsed.pkl, meta.json
- Raw contains .zip/.7z archives of .slp files
-
raw.json file contains info about each raw archive
- whether it's been processed; if processed, the archive is removed to save space
-
Parsed dir populated by this script w/ a Parquet file for each .slp file
-
these files are named by the MD5 hash of the .slp file
- and are used by imitation learning
-
- parsed.pkl pickle file contains metadata about each processed .slp in Parsed
-
meta.json is created by scripts/make_local_dataset
- and used by imitation learning to know which files to train on
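As a sketch of the naming scheme for the Parsed dir (the helper name here is hypothetical; assuming the MD5 is taken over the raw .slp bytes):

```python
import hashlib

def parquet_name(slp_bytes: bytes) -> str:
    # Parsed files are keyed by the MD5 of the raw .slp bytes, so
    # re-parsing the same replay overwrites rather than duplicates.
    return hashlib.md5(slp_bytes).hexdigest()

name = parquet_name(b"example .slp contents")
```

This also makes deduplication trivial: two archives containing the same replay produce the same Parquet filename.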
-
Dependencies
concurrent.futures, json, os, pickle
from absl import app, flags
tqdm, peppi_py
from slippi_db import parse_peppi, preprocessing, utils, parsing_utils
functions
-
parse_slp(file, output_dir, tmp_dir, compression, compression_level)
- result = dict(name=file.name)
-
utils.md5
- result.update(slp_md5=md5, slp_size=len(slp_bytes))
-
game = peppi_py.read_slippi
- metadata = preprocessing.get_metadata(game)
-
is_training, reason = preprocessing.is_training_replay(metadata)
- result.update(metadata)
- result.update(valid=True, is_training=is_training, not_training_reason=reason)
-
if is_training:
-
game = parse_peppi.from_peppi(game)
- game_bytes = parsing_utils.convert_game(game, compression=compression, compression_level=compression_level)
- result.update(pq_size=len(game_bytes), compression=compression.value)
-
with open(… , 'wb') as f:
- f.write(game_bytes)
-
- return result
- parse_files
- parse_chunk
- parse_7zs
- run_parsing
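The parse_slp flow above can be sketched end-to-end; the parser and filter callables are stand-ins for peppi_py and slippi_db.preprocessing, and all names here are illustrative, not the repo's actual API:

```python
import hashlib

def parse_slp_sketch(name, slp_bytes, parse_game, get_metadata, is_training_replay):
    """Simplified sketch of parse_slp: hash the raw bytes, extract
    metadata, filter, then convert the game to Parquet bytes."""
    result = dict(name=name,
                  slp_md5=hashlib.md5(slp_bytes).hexdigest(),
                  slp_size=len(slp_bytes))
    metadata = get_metadata(slp_bytes)          # stand-in for preprocessing.get_metadata
    is_training, reason = is_training_replay(metadata)
    result.update(metadata)
    result.update(valid=True, is_training=is_training, not_training_reason=reason)
    if is_training:
        game_bytes = parse_game(slp_bytes)      # stand-in for parse_peppi + convert_game
        result.update(pq_size=len(game_bytes))
    return result
```

The returned dict is what ends up aggregated into parsed.pkl.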
Steps
- standardized directory hierarchy under a root directory
-
multi-threaded in-memory extraction and parsing
-
compute derived qualities and filter candidate replays
- exclude bad AI
- damage threshold
- winner detection
- match deduplication
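A minimal sketch of that filtering pass (field names and the damage threshold are assumptions, not the repo's actual schema):

```python
def filter_replays(metas, min_damage=0.01):
    """Keep candidate replays: drop known-bot games, low-damage games,
    and duplicate matches (same MD5 seen twice)."""
    seen = set()
    kept = []
    for m in metas:
        if m.get("is_bot"):                    # exclude bad AI
            continue
        if m.get("damage", 0.0) < min_damage:  # damage threshold
            continue
        if m["md5"] in seen:                   # match deduplication
            continue
        seen.add(m["md5"])
        kept.append(m)
    return kept
```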
-
Training(s)
imitation learning
-
create experiment dir
- loads/restores checkpoints
- build train/test data sources from replay files
- create policy network and value function
- alternates between training steps and eval
- saves best models based on evaluation loss
-
walk thru code
- trainlib.train requires the Config struct
-
Configuration
struct Config
- runtime::RuntimeConfig
- dataset::DatasetConfig
- data::DataConfig
- observation::ObservationConfig
- learner::LearnerConfig
- network::NetworkConfig
- controllerhead::ControllerHeadConfig
- embed::EmbedConfig
- policy::PolicyConfig
- valuefunc::ValueFunctionConfig
- maxnames::Integer
- exptroot, exptdir, tag, restorepickle, tested
- version::Integer
end
RuntimeConfig
- max runtime in seconds
- interval in seconds between logging
- interval in seconds between saving to disk
- number of training steps between evaluations
- number of batches per evaluation
DatasetConfig
- data directory for parsed peppiDb
- metadata path for chunked data
- test ratio for splitting up training data
- allowed smash characters
- allowed smash opponents
- allowed player names
- banned player names
- yield swapped versions of each replay
- mirror left/right in each replay
- seed
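One way such a seeded train/test split might be done deterministically (an illustrative scheme, not necessarily the repo's): hash each replay's name together with the seed so the assignment is stable across runs and machines.

```python
import hashlib

def is_test_replay(md5_hex: str, test_ratio: float, seed: int = 0) -> bool:
    """Deterministically assign a replay to the test split with
    probability test_ratio, keyed on its MD5 name and the seed."""
    h = hashlib.sha256(f"{seed}:{md5_hex}".encode()).digest()
    # Map the first 8 hash bytes to a float in [0, 1).
    u = int.from_bytes(h[:8], "big") / 2**64
    return u < test_ratio
```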
DataConfig
- training batch size
- unroll length
- damage ratio
- compressed
- number of workers
- balance characters bool
ObservationConfig
-
animation::AnimationConfig
AnimationConfig
- mask::Boolean
LearnerConfig
- learning rate::Float
- compile::Boolean
- jitcompile::Boolean
- decay rate::Float
- value cost::Float
- reward halflife::Float
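The reward half-life maps to a per-frame discount factor; a sketch assuming Melee's 60 fps (the conversion formula is an assumption about how the config is consumed):

```python
def discount_from_halflife(halflife_seconds: float, fps: int = 60) -> float:
    """Per-frame discount gamma such that rewards decay to half their
    weight after halflife_seconds worth of frames."""
    return 0.5 ** (1.0 / (halflife_seconds * fps))

gamma = discount_from_halflife(4.0)  # e.g. a 4-second half-life
```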
NetworkConfig
- name='mlp'
- mlp=MLP.config
- lstm=LSTM.config
- gru=GRU.config
- reslstm=DeepResLSTM.config
- txlike=Transformerlike.config
ControllerHeadConfig
-
independent=Independent
- models each component of the controller independently
-
autoregressive=AutoRegressive
- samples components sequentially conditioned on past samples
EmbedConfig
- playerConfig
- controllerConfig
- randall::Bool
- fountainofdreams::Bool
-
items::ItemsConfig
PlayerConfig
- xy scale
- shield scale
- speed scale
- with speeds::Bool
- with controller::Bool
- with nana::Bool
- legacy jumps left::Bool
- axis spacing
- shoulder spacing
ItemsConfig
- type::ItemsType (SKIP, FLAT, or MLP)
- mlp sizes::Tuple{Int}
PolicyConfig
- train value head::Bool
- delay::Integer
ValueFunctionConfig
- train separate network::Bool
- separate network config::Bool
- network::NetworkConfig
-
Train
-
setup Wandb for logging
-
attempt to restore parameters using our pickle file
-
lots of config validation checks
-
create data sources for training and testing
- setup TrainManager and TestManager
TrainManager
-
Learner
-
DataSource
-
step kwargs
-
- prefetch = 16
- data_profiler(), step_profiler()
- frames_queue = queue.Queue(maxsize=prefetch)
- stop_requested = threading.Event()
- data_thread = threading.Thread(target=self.produce_frames)
"used to produce tensors from frames" produceframes(self): while stop not requested
-
batch, epoch = next(self.data_source)
- frames = batch.frames
-
frames = frames.replace(state_action=self.learner.policy.embed_state_action.from_state(frames.state_action))
- frames = utils.map_nt(tf.convert_to_tensor, frames)
-
data = (batch, epoch, frames)
- self.frames_queue.put(data)
"stop requested" stop(self): self.stoprequested.set() self.datathread.join()
-
"step to get next frames in queue as input for batch training" step(self, compiled): batch, epoch, frames = self.framesqueue.get() stats, self.hiddenstate = self.learner.step( frames, self.hiddenstate, compiled, **kwargs) ) stats.update(epoch)
return stats, batch
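The producer-thread pattern above can be sketched with stdlib threading and queue (simplified: no profilers, no TF tensor conversion; class and method names are illustrative):

```python
import queue
import threading

class Prefetcher:
    """A worker thread pulls batches from a data source into a bounded
    queue; step() blocks until the next batch is ready."""
    def __init__(self, data_source, prefetch=16):
        self.data_source = data_source
        self.frames_queue = queue.Queue(maxsize=prefetch)
        self.stop_requested = threading.Event()
        self.thread = threading.Thread(target=self._produce, daemon=True)
        self.thread.start()

    def _produce(self):
        while not self.stop_requested.is_set():
            try:
                batch = next(self.data_source)
            except StopIteration:
                break
            # Retry put with a timeout so stop() can't deadlock on a full queue.
            while not self.stop_requested.is_set():
                try:
                    self.frames_queue.put(batch, timeout=0.1)
                    break
                except queue.Full:
                    pass

    def step(self):
        return self.frames_queue.get()

    def stop(self):
        self.stop_requested.set()
        self.thread.join()
```

The bounded queue (prefetch = 16 in the notes above) keeps data loading ahead of the training step without unbounded memory growth.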
inline funcs
-
get_tf_state
- set_tf_state
- save
- maybe_log (do a test step and log both train and test stats)
- maybe_eval
-
reinforcement learning
- load imitation-trained policy as teacher
- environment setup for dolphin emulator instances
- actor-learner separates rollout collection from learning
-
performs policy gradient updates
- with KL constraints
- self-play: update opponent with current policy
Q-learning
- create sample policy and Q-policy
- initialize value and Q-value networks
- joint training: alternate between policy imitation and Q-learning
- action sampling: use sample policy to generate action candidates
- Q-policy updates: trains policy to select actions maximizing Q-values
Steps
- raw slippi replays (post meta-extraction)
- data source library/module
- TrainManager for IL and LearnerManager for RL, and Learner for Q-Learning
-
the training system produces trained policy instances
-
it uses a hierarchical configuration approach with dataclasses that can be overridden via CLI flags
- each training system has its own top-level config class that composes various specialized config components
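A minimal sketch of that composition pattern with dataclasses (field names and the override helper are illustrative; in the real system fancyflags handles the CLI side):

```python
import dataclasses

@dataclasses.dataclass
class LearnerConfig:
    learning_rate: float = 1e-4
    compile: bool = True

@dataclasses.dataclass
class Config:
    # Top-level config composes specialized components.
    learner: LearnerConfig = dataclasses.field(default_factory=LearnerConfig)
    tag: str = ""

def override(config, dotted_key, value):
    """Apply a CLI-style 'learner.learning_rate=3e-4' override by
    walking the nested dataclass path."""
    *path, leaf = dotted_key.split(".")
    target = config
    for part in path:
        target = getattr(target, part)
    setattr(target, leaf, value)
    return config

cfg = override(Config(), "learner.learning_rate", 3e-4)
```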
-
Agent(s)
-
the agent system manages multiple types of agents with different execution models
- handling state synchronization and delayed inference
-
provides infra for managing AI agents during evaluation, gameplay, and real-time interaction with the Dolphin emulator
- handles agent instantiation, asynchronous inference, delay simulation, and controller output management
- built around a hierarchy of agent classes that provide different levels of functionality and performance optimization
Basic Agent
provides the fundamental agent functionality by wrapping a `Policy` and tracking recurrent hidden state across timesteps
-
policy integration wraps policy for inference
-
state management tracks hidden_state and prev_controller
-
batching support handles batched inference
-
compilation: TF function JIT compilation
-
Game embedding && needs_reset -> BasicAgent.step()
-
embed state action && hidden_state -> policy.sample() -> updated hidden_state
- SampleOutputs -> controller state & logits
-
Delayed Agent/Async Delayed Agent
implements delay simulation to model realistic timing constraints between input perception and controller output. DelayedAgent uses a PeekableQueue to buffer outputs and simulate processing delay
-
run synchronously
-
game state -> DelayedAgent.push() -> BasicAgent.step() -> output_queue.put() <- initial queue fill <- dummy_sample_outputs
-
DelayedAgent.pop() -> output_queue.get() -> SampleOutputs
-
batch_steps > 0 -> multi-step batching -> input_queue -> BasicAgent.step()
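The delay behavior can be approximated with a deque pre-filled with dummy outputs (a sketch, not the actual PeekableQueue class):

```python
import collections

class DelayBuffer:
    """Pre-fill with `delay` dummy outputs so the first pops return
    neutral actions while real inference catches up; thereafter every
    pop returns an action computed `delay` steps earlier."""
    def __init__(self, delay, dummy_output):
        self.queue = collections.deque(dummy_output for _ in range(delay))

    def push(self, sample_output):
        self.queue.append(sample_output)

    def pop(self):
        return self.queue.popleft()
```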
-
AsyncDelayedAgent runs inference on a separate thread using threading pools and queues
-
runs asynchronously
- worker thread for multi-threading
- state queue for input queue for game states
- output queue for buffered controller outputs
- context manager for lifecycle: start, stop, run
Dolphin Integration Agent
-
the Agent class provides the highest level interface for interacting with dolphin emulator instances
Agent-to-Dolphin integration flow:
melee.GameState -> Agent.step() -> get_game() -> DelayedAgent.step() -> SampleOutputs -> embed_controller.decode() -> send_controller -> melee.Controller
name_codes -> name management -> NameChangeMode -> FIXED/CYCLE/RANDOM
Agent Factory Functions
factory functions for creating appropriately configured agents
- build_delayed_agent() creates delayed agents w/ automatic name resolution and config
-
build_agent() creates fully-configured agent instances for Dolphin interaction
- opponent port, agent nametag, melee controller instance, saved agent state
Evaluation system
Evaluation system uses RolloutWorker and Evaluator classes to orchestrate agent execution across multiple envs
rollout() method coordinates between agents and environments to collect structured trajectory data
- states: game states over time, shape [T+1, B]
- actions: controller outputs, shape [T+1, B]
- rewards: computed rewards, shape [T, B]
- is_resetting: reset flags, shape [T+1, B]
- initial_state: agent's initial hidden state … [B]
- delayed_actions: buffered future actions [D, B]
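A quick check of those shape invariants on nested lists (the checker function is illustrative): T+1 sampled states/actions bracket T reward transitions.

```python
def check_trajectory(states, actions, rewards, is_resetting):
    """Validate the [T+1, B] / [T, B] layout: rewards are transitions
    between consecutive sampled states."""
    t_plus_1 = len(states)
    assert len(actions) == t_plus_1
    assert len(is_resetting) == t_plus_1
    assert len(rewards) == t_plus_1 - 1
    batch = len(states[0])
    assert all(len(row) == batch
               for row in states + actions + rewards + is_resetting)
    return t_plus_1 - 1, batch  # (T, B)
```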
Distributed Evaluation
RayEvaluator extends evaluation capabilities across multiple workers
RayEvaluator -> RayRolloutWorker.remote() -> Worker 1, 2, … N
update_variables() -> Parameter Sync -> Worker 1, 2, … N
rollout() -> ray.get() -> Merge Results -> Aggregated Metrics
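The fan-out/fan-in pattern above can be mimicked with stdlib concurrency; here concurrent.futures stands in for Ray, and all function names are illustrative:

```python
from concurrent.futures import ThreadPoolExecutor

def rollout(worker_id, params):
    # Stand-in for a remote worker's rollout(): run episodes with the
    # synced parameters and return per-worker metrics.
    return {"worker": worker_id, "reward": params["scale"] * worker_id}

def evaluate(num_workers, params):
    """Broadcast parameters, run rollouts in parallel, merge results."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [pool.submit(rollout, i, params) for i in range(num_workers)]
        results = [f.result() for f in futures]   # like ray.get()
    # Merge: mean reward across workers, like the aggregated-metrics step.
    return sum(r["reward"] for r in results) / len(results)

metrics = evaluate(4, {"scale": 2.0})
```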