Basic Usage¶
This example demonstrates the basic usage of pymahjong.
Single-agent Environment¶
import pymahjong
import numpy as np
# Create environment with random opponents
env = pymahjong.SingleAgentMahjongEnv(opponent_agent="random")
# Reset and get initial observation
obs, info = env.reset()
total_reward = 0
step = 0
while True:
# Get valid actions
valid_actions = env.get_valid_actions()
# Random action selection
action = np.random.choice(valid_actions)
# Step the environment
obs, reward, done, truncated, info = env.step(action)
total_reward += reward
step += 1
if done:
print(f"Game finished in {step} steps")
print(f"Final payoff: {total_reward}")
break
print(f"Average reward per step: {total_reward / step:.4f}")
Multi-agent Environment¶
import pymahjong
import numpy as np
env = pymahjong.MahjongEnv()
num_games = 10
results = []
for game in range(num_games):
env.reset()
steps = 0
while not env.is_over():
# Get current player
player_id = env.get_curr_player_id()
# Get observation and valid actions
obs = env.get_obs(player_id)
valid_actions = env.get_valid_actions()
# Random decision
action = np.random.choice(valid_actions)
# Execute action
env.step(player_id, action)
steps += 1
# Record results
payoffs = env.get_payoffs()
results.append(payoffs)
print(f"Game {game + 1}: steps={steps}, payoffs={payoffs}")
# Statistics
results = np.array(results)
print(f"\nAverage payoffs over {num_games} games:")
for i in range(4):
print(f" Player {i}: {results[:, i].mean():.2f} ± {results[:, i].std():.2f}")
Using Oracle Observation¶
import pymahjong
import numpy as np
env = pymahjong.MahjongEnv()
env.reset()
# Get different observation types
player_id = env.get_curr_player_id()
# Executor observation (visible game state)
executor_obs = env.get_obs(player_id)
print(f"Executor observation shape: {executor_obs.shape}") # (93, 34)
# Oracle observation (hidden information)
oracle_obs = env.get_oracle_obs(player_id)
print(f"Oracle observation shape: {oracle_obs.shape}") # (18, 34)
# Full observation
full_obs = env.get_full_obs(player_id)
print(f"Full observation shape: {full_obs.shape}") # (111, 34)
# Verify concatenation
assert np.array_equal(full_obs, np.concatenate([executor_obs, oracle_obs], axis=0))
Action Masking¶
import pymahjong
import numpy as np
env = pymahjong.SingleAgentMahjongEnv(opponent_agent="random")
obs, info = env.reset()
# Get valid actions as indices
valid_indices = env.get_valid_actions()
print(f"Valid action indices: {valid_indices}")
# Get valid actions as one-hot mask
valid_mask = env.get_valid_actions(nhot=True)
print(f"Valid action mask shape: {valid_mask.shape}") # (54,)
print(f"Number of valid actions: {valid_mask.sum()}")
# Example: softmax with masking
def masked_softmax(logits, mask):
"""Apply softmax only to valid actions."""
logits = logits.copy()
logits[~mask] = -np.inf # Set invalid actions to -inf
exp_logits = np.exp(logits - logits.max()) # Numerical stability
return exp_logits / exp_logits.sum()
# Simulate policy network output
logits = np.random.randn(54)
probs = masked_softmax(logits, valid_mask)
print(f"Action probabilities sum: {probs.sum():.4f}") # Should be 1.0
# Sample from distribution
action = np.random.choice(54, p=probs)
print(f"Sampled action: {action}")
Game State Inspection¶
import pymahjong as pm
env = pm.MahjongEnv()
env.reset()
# Access the underlying C++ Table
table = env.t
print(f"Current turn: {table.turn}")
print(f"Game wind: {table.game_wind}")
print(f"Parent (oya): {table.oya}")
print(f"Honba: {table.honba}")
print(f"Kyoutaku (riichi sticks): {table.kyoutaku}")
# Player information
for i, player in enumerate(table.players):
print(f"\nPlayer {i}:")
print(f" Score: {player.score}")
print(f" Riichi: {player.riichi}")
print(f" Tenpai: {player.is_tenpai()}")
# Dora information
dora = table.get_dora()
print(f"\nDora indicators: {[str(d) for d in dora]}")
# Remaining tiles
print(f"Remaining tiles: {table.get_remain_tile()}")
Custom Game Initialization¶
import pymahjong as pm
env = pm.MahjongEnv()
# Custom initialization
env.reset(
oya=2, # Player 2 is parent
game_wind="south", # South round
seed=42 # Fixed random seed
)
print(f"Parent: {env.t.oya}")
print(f"Game wind: {env.t.game_wind}")
# With specific initial scores
env.reset(
scores=[30000, 25000, 20000, 25000]
)
for i, player in enumerate(env.t.players):
print(f"Player {i} score: {player.score}")
Rendering¶
import pymahjong as pm
env = pm.MahjongEnv()
env.reset()
# Text-based rendering
env.render()
# Output example:
# ----------- current player: 0 -----------------
# [Player 0]
# Hand: 1m 2m 3m 4m 5m 6m 7m 8m 9m 1p 2p 3p 4p
# River: 5p(1) 6p(2)
# ...