RL Training Example¶
This example shows how to train a reinforcement learning agent to play Mahjong with pymahjong: first with MaskablePPO from sb3-contrib, then with a hand-rolled PyTorch DQN.
Setup¶
pip install pymahjong torch gymnasium sb3-contrib
PPO with sb3-contrib¶
sb3-contrib extends Stable-Baselines3 with MaskablePPO, which provides the invalid-action masking a Mahjong environment requires.
import gymnasium as gym
import numpy as np
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback
import pymahjong
class MahjongGymWrapper(gym.Env):
"""Gymnasium wrapper for SingleAgentMahjongEnv."""
def __init__(self, opponent_agent="random"):
super().__init__()
self.env = pymahjong.SingleAgentMahjongEnv(opponent_agent=opponent_agent)
self.observation_space = self.env.observation_space
self.action_space = self.env.action_space
def reset(self, **kwargs):
obs, info = self.env.reset()
return obs, info
def step(self, action):
obs, reward, done, info = self.env.step(action)
return obs, reward, done, False, info
def get_action_mask(self):
"""Return boolean mask of valid actions (required by sb3-contrib ActionMasker)."""
        mask = np.zeros(self.action_space.n, dtype=bool)
mask[self.env.get_valid_actions()] = True
return mask
def make_env():
env = MahjongGymWrapper(opponent_agent="random")
    return ActionMasker(env, lambda e: e.get_action_mask())
# Create vectorized environment
env = DummyVecEnv([make_env for _ in range(4)])
# Create MaskablePPO model (supports action masking via sb3-contrib)
model = MaskablePPO(
"MlpPolicy",
env,
learning_rate=3e-4,
n_steps=2048,
batch_size=64,
n_epochs=10,
gamma=0.99,
verbose=1,
)
# Train
model.learn(total_timesteps=1_000_000)
# Save
model.save("ppo_mahjong")
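After training, the saved policy can be reloaded and run greedily against the environment. A minimal sketch, reusing the MahjongGymWrapper defined above:

from sb3_contrib import MaskablePPO

model = MaskablePPO.load("ppo_mahjong")
env = MahjongGymWrapper(opponent_agent="random")

obs, info = env.reset()
done, total_reward = False, 0.0
while not done:
    # MaskablePPO.predict accepts the current action mask directly
    action, _ = model.predict(obs, action_masks=env.get_action_mask(), deterministic=True)
    obs, reward, done, truncated, info = env.step(int(action))
    total_reward += reward
print(f"episode reward: {total_reward:.2f}")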
Custom Policy with Action Masking¶
With MaskablePPO and ActionMasker, action masking is handled automatically, so the standard MlpPolicy works out of the box. To widen the default network, pass policy_kwargs; for a fully custom feature extractor, see the sketch after this example.
# Standard MlpPolicy works; ActionMasker already handles masking at the env level
model = MaskablePPO(
"MlpPolicy",
env,
    policy_kwargs=dict(net_arch=dict(pi=[256, 256], vf=[256, 256])),
verbose=1,
)
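For a genuinely custom policy network, the usual Stable-Baselines3 route is a custom features extractor passed through policy_kwargs. A minimal sketch; the layer sizes here are illustrative, not tuned:

import numpy as np
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class MahjongExtractor(BaseFeaturesExtractor):
    """Flattens the (93, 34) observation and embeds it with a small MLP."""

    def __init__(self, observation_space, features_dim=256):
        super().__init__(observation_space, features_dim)
        n_input = int(np.prod(observation_space.shape))
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_input, 512),
            nn.ReLU(),
            nn.Linear(512, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations):
        return self.net(observations)

model = MaskablePPO(
    "MlpPolicy",
    env,
    policy_kwargs=dict(features_extractor_class=MahjongExtractor),
    verbose=1,
)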
PyTorch DQN Implementation¶
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
import pymahjong
class DQNNetwork(nn.Module):
def __init__(self, obs_dim, action_dim):
super().__init__()
self.net = nn.Sequential(
nn.Flatten(),
nn.Linear(obs_dim, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, action_dim),
)
def forward(self, x):
return self.net(x)
class DQNAgent:
def __init__(self, obs_shape=(93, 34), action_dim=54):
        self.obs_dim = int(np.prod(obs_shape))
self.action_dim = action_dim
self.q_network = DQNNetwork(self.obs_dim, action_dim)
self.target_network = DQNNetwork(self.obs_dim, action_dim)
self.target_network.load_state_dict(self.q_network.state_dict())
self.optimizer = optim.Adam(self.q_network.parameters(), lr=1e-4)
self.memory = deque(maxlen=100000)
self.batch_size = 64
self.gamma = 0.99
self.epsilon = 1.0
self.epsilon_min = 0.1
self.epsilon_decay = 0.9995
def select_action(self, obs, valid_actions):
if random.random() < self.epsilon:
return random.choice(valid_actions)
with torch.no_grad():
x = torch.FloatTensor(obs).flatten().unsqueeze(0)
q_values = self.q_network(x).squeeze()
# Mask invalid actions
mask = torch.full((self.action_dim,), float('-inf'))
mask[valid_actions] = 0
q_values = q_values + mask
return q_values.argmax().item()
def store(self, obs, action, reward, next_obs, done):
self.memory.append((obs, action, reward, next_obs, done))
def train(self):
if len(self.memory) < self.batch_size:
return
batch = random.sample(self.memory, self.batch_size)
obs, actions, rewards, next_obs, dones = zip(*batch)
obs = torch.FloatTensor(np.array([o.flatten() for o in obs]))
actions = torch.LongTensor(actions)
rewards = torch.FloatTensor(rewards)
next_obs = torch.FloatTensor(np.array([o.flatten() for o in next_obs]))
dones = torch.FloatTensor(dones)
current_q = self.q_network(obs).gather(1, actions.unsqueeze(1))
with torch.no_grad():
next_q = self.target_network(next_obs).max(1)[0]
target_q = rewards + self.gamma * next_q * (1 - dones)
loss = nn.MSELoss()(current_q.squeeze(), target_q)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
def update_target(self):
self.target_network.load_state_dict(self.q_network.state_dict())
def train_dqn():
env = pymahjong.SingleAgentMahjongEnv(opponent_agent="random")
agent = DQNAgent()
num_episodes = 10000
target_update_freq = 100
for episode in range(num_episodes):
obs, _ = env.reset()
total_reward = 0
while True:
valid_actions = env.get_valid_actions()
action = agent.select_action(obs, valid_actions)
next_obs, reward, done, info = env.step(action)
agent.store(obs, action, reward, next_obs, done)
obs = next_obs
total_reward += reward
agent.train()
if done:
break
if episode % target_update_freq == 0:
agent.update_target()
if episode % 100 == 0:
print(f"Episode {episode}, Reward: {total_reward:.2f}, Epsilon: {agent.epsilon:.3f}")
return agent
if __name__ == "__main__":
agent = train_dqn()
torch.save(agent.q_network.state_dict(), "dqn_mahjong.pth")
Evaluation¶
def evaluate_agent(agent, num_games=100):
"""Evaluate agent performance."""
env = pymahjong.SingleAgentMahjongEnv(opponent_agent="random")
rewards = []
wins = 0
for _ in range(num_games):
obs, _ = env.reset()
total_reward = 0
while True:
valid_actions = env.get_valid_actions()
            # Greedy only if agent.epsilon is 0; otherwise still epsilon-greedy
action = agent.select_action(obs, valid_actions)
obs, reward, done, info = env.step(action)
total_reward += reward
if done:
rewards.append(total_reward)
if total_reward > 0:
wins += 1
break
print(f"Average reward: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}")
print(f"Win rate: {wins / num_games * 100:.1f}%")
return rewards
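To evaluate a saved DQN checkpoint, reload the weights and zero out epsilon so the agent acts greedily rather than epsilon-greedily:

agent = DQNAgent()
agent.q_network.load_state_dict(torch.load("dqn_mahjong.pth"))
agent.epsilon = 0.0  # disable exploration for evaluation
evaluate_agent(agent, num_games=100)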
Tips for Training¶
Sparse rewards: Mahjong has sparse rewards (a nonzero reward arrives only at the end of a game). Consider:
- using a high discount factor (gamma of 0.99 or above)
- adding small intermediate rewards for reaching tenpai, declaring riichi, etc., as in the sketch below
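A minimal reward-shaping sketch. The is_tenpai() call is a hypothetical helper, not a documented pymahjong API; substitute whatever hand-state query your version exposes:

class TenpaiBonusWrapper:
    """Adds a one-time bonus the first time the hand reaches tenpai."""

    def __init__(self, env, bonus=0.1):
        self.env = env
        self.bonus = bonus
        self._paid = False

    def reset(self):
        self._paid = False
        return self.env.reset()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        if not self._paid and self.env.is_tenpai():  # hypothetical query
            reward += self.bonus
            self._paid = True
        return obs, reward, done, info

    def __getattr__(self, name):
        # Delegate everything else (get_valid_actions, ...) to the base env
        return getattr(self.env, name)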
Long episodes: games typically run 50-150 steps. Consider:
- using n-step returns in place of one-step TD targets (sketched below)
- normalizing returns by episode length
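A minimal n-step target sketch; it would replace the one-step target in DQNAgent.train, with each stored transition carrying the next n rewards:

def n_step_target(rewards, bootstrap_value, done, gamma=0.99):
    """r_0 + gamma*r_1 + ... + gamma^n * V(s_n), folded backwards.

    rewards: the n rewards observed after the stored state.
    bootstrap_value: target-network value of the state n steps later,
        ignored when the episode ended inside the window.
    """
    g = 0.0 if done else bootstrap_value
    for r in reversed(rewards):
        g = r + gamma * g
    return g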
Action masking: always mask invalid actions; most of the 54 actions are illegal at any given step, so unmasked exploration wastes nearly every sample.
Opponent strength: start with random opponents and increase difficulty as the agent improves; a curriculum sketch with placeholder opponent names follows.
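A minimal curriculum sketch. Only "random" is taken from this page; the other opponent identifiers are placeholders for whatever agents your pymahjong installation provides:

def make_curriculum_env(episode,
                        schedule=((0, "random"), (5000, "greedy"), (20000, "strong"))):
    """Pick the hardest opponent whose start episode has been reached.

    NOTE: "greedy" and "strong" are placeholder names, not documented
    pymahjong opponents; substitute the agents you actually have.
    """
    opponent = "random"
    for start, name in schedule:
        if episode >= start:
            opponent = name
    return pymahjong.SingleAgentMahjongEnv(opponent_agent=opponent)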