---
license: apache-2.0
metrics:
- accuracy
tags:
- supervised-learning
- reinforcement-learning
---

```python
import os

import gymnasium as gym
import torch
from tqdm import tqdm


class CartPole(torch.nn.Module):
    """MLP policy network for CartPole: 4 inputs -> 64 hidden (ReLU) -> 2 outputs.

    The layers live in ``self.model`` so the state-dict keys
    (``model.0.*``, ``model.2.*``) match the saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(4, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 2),
        )

    def forward(self, x):
        """Return one score per action for each row of ``x`` (last dim must be 4)."""
        return self.model(x)


def run(model, episodes, device=None, render_mode="human"):
    """Play CartPole-v1 greedily with ``model`` and print the summed reward.

    Despite its name, ``episodes`` is the number of environment *steps* to
    take. The environment is reset whenever an episode terminates or is
    truncated, so the total may span several episodes.

    Args:
        model: policy network mapping a ``(1, 4)`` float tensor to ``(1, 2)``
            action scores; the highest-scoring action is taken.
        episodes: number of environment steps to run.
        device: device for observation tensors; defaults to the device of the
            model's parameters, so the model and its inputs always agree.
        render_mode: Gymnasium render mode, e.g. ``"human"`` or ``"rgb_array"``.

    Returns:
        The total reward collected over all steps (it is also printed).
    """
    if device is None:
        device = next(model.parameters()).device
    env = gym.make("CartPole-v1", render_mode=render_mode)
    total_reward = 0.0
    try:
        obs, _ = env.reset()
        with torch.no_grad():
            # Exactly `episodes` steps (the original ran one extra step).
            for _ in tqdm(range(episodes)):
                x = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
                action = model(x).argmax(dim=-1).item()
                obs, reward, terminated, truncated, _ = env.step(action)
                total_reward += reward
                if terminated or truncated:
                    obs, _ = env.reset()
    finally:
        # Release the environment (and its render window) even if the rollout raises.
        env.close()
    print(f"total reward : {total_reward}")
    return total_reward


if __name__ == "__main__":
    # Fall back to the CPU so the demo also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint_path = os.path.join(os.getcwd(), "99.61_99_policy_net.pth")
    # map_location remaps tensors that were saved on a GPU onto `device`.
    # NOTE(review): torch.load unpickles the file, which can execute code;
    # only load checkpoints you trust, or pass weights_only=True (the default
    # from PyTorch 2.6) once confirmed the checkpoint holds only tensors/primitives.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = CartPole()
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    run(model=model, episodes=500)
```