{"id":595,"date":"2026-03-23T05:54:00","date_gmt":"2026-03-22T21:54:00","guid":{"rendered":"https:\/\/connectword.dpdns.org\/?p=595"},"modified":"2026-03-23T05:54:00","modified_gmt":"2026-03-22T21:54:00","slug":"implementing-deep-q-learning-dqn-from-scratch-using-rlax-jax-haiku-and-optax-to-train-a-cartpole-reinforcement-learning-agent","status":"publish","type":"post","link":"https:\/\/connectword.dpdns.org\/?p=595","title":{"rendered":"Implementing Deep Q-Learning (DQN) from Scratch Using RLax JAX Haiku and Optax to Train a CartPole Reinforcement Learning Agent"},"content":{"rendered":"<p>In this tutorial, we implement a reinforcement learning agent using <a href=\"https:\/\/github.com\/google-deepmind\/rlax\"><strong>RLax<\/strong><\/a>, a research-oriented library developed by Google DeepMind for building reinforcement learning algorithms with JAX. We combine RLax with JAX, Haiku, and Optax to construct a Deep Q-Learning (DQN) agent that learns to solve the CartPole environment. Instead of using a fully packaged RL framework, we assemble the training pipeline ourselves so we can clearly understand how the core components of reinforcement learning interact. We define the neural network, build a replay buffer, compute temporal difference errors with RLax, and train the agent using gradient-based optimization. Also, we focus on understanding how RLax provides reusable RL primitives that can be integrated into custom reinforcement learning pipelines. We use JAX for efficient numerical computation, Haiku for neural network modeling, and Optax for optimization.<\/p>\n<div class=\"dm-code-snippet dark dm-normal-version default no-background-mobile\">\n<div class=\"control-language\">\n<div class=\"dm-buttons\">\n<div class=\"dm-buttons-left\">\n<div class=\"dm-button-snippet red-button\"><\/div>\n<div class=\"dm-button-snippet orange-button\"><\/div>\n<div class=\"dm-button-snippet green-button\"><\/div>\n<\/div>\n<div class=\"dm-buttons-right\"><a><span class=\"dm-copy-text\">Copy Code<\/span><span class=\"dm-copy-confirmed\">Copied<\/span><span class=\"dm-error-message\">Use a different Browser<\/span><\/a><\/div>\n<\/div>\n<pre class=\" no-line-numbers\"><code class=\" no-wrap language-php\">!pip -q install \"jax[cpu]\" dm-haiku optax rlax gymnasium matplotlib numpy\n\n\nimport os\nos.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n\n\nimport random\nimport time\nfrom dataclasses import dataclass\nfrom collections import deque\n\n\nimport gymnasium as gym\nimport haiku as hk\nimport jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport optax\nimport rlax\n\n\nseed = 42\nrandom.seed(seed)\nnp.random.seed(seed)\n\n\nenv = gym.make(\"CartPole-v1\")\neval_env = gym.make(\"CartPole-v1\")\n\n\nobs_dim = env.observation_space.shape[0]\nnum_actions = env.action_space.n\n\n\ndef q_network(x):\n   mlp = hk.Sequential([\n       hk.Linear(128), jax.nn.relu,\n       hk.Linear(128), jax.nn.relu,\n       hk.Linear(num_actions),\n   ])\n   return mlp(x)\n\n\nq_net = hk.without_apply_rng(hk.transform(q_network))\n\n\ndummy_obs = jnp.zeros((1, obs_dim), dtype=jnp.float32)\nrng = jax.random.PRNGKey(seed)\nparams = q_net.init(rng, dummy_obs)\ntarget_params = params\n\n\noptimizer = optax.chain(\n   optax.clip_by_global_norm(10.0),\n   optax.adam(3e-4),\n)\nopt_state = optimizer.init(params)<\/code><\/pre>\n<\/div>\n<\/div>\n<p>We install the required libraries and import all the modules needed for the reinforcement learning pipeline. 
<pre class=\" no-line-numbers\"><code class=\" no-wrap language-python\">@dataclass\nclass Transition:\n   obs: np.ndarray\n   action: int\n   reward: float\n   discount: float\n   next_obs: np.ndarray\n   done: float\n\n\nclass ReplayBuffer:\n   def __init__(self, capacity):\n       self.buffer = deque(maxlen=capacity)\n\n\n   def add(self, *args):\n       self.buffer.append(Transition(*args))\n\n\n   def sample(self, batch_size):\n       batch = random.sample(self.buffer, batch_size)\n       obs = np.stack([t.obs for t in batch]).astype(np.float32)\n       action = np.array([t.action for t in batch], dtype=np.int32)\n       reward = np.array([t.reward for t in batch], dtype=np.float32)\n       discount = np.array([t.discount for t in batch], dtype=np.float32)\n       next_obs = np.stack([t.next_obs for t in batch]).astype(np.float32)\n       done = np.array([t.done for t in batch], dtype=np.float32)\n       return {\n           \"obs\": obs,\n           \"action\": action,\n           \"reward\": reward,\n           \"discount\": discount,\n           \"next_obs\": next_obs,\n           \"done\": done,\n       }\n\n\n   def __len__(self):\n       return len(self.buffer)\n\n\nreplay = ReplayBuffer(capacity=50000)\n\n\ndef epsilon_by_frame(frame_idx, eps_start=1.0, eps_end=0.05, decay_frames=20000):\n   mix = min(frame_idx \/ decay_frames, 1.0)\n   return eps_start + mix * (eps_end - eps_start)\n\n\ndef select_action(params, obs, epsilon):\n   if random.random() &lt; epsilon:\n       return env.action_space.sample()\n   q_values = q_net.apply(params, obs[None, :])\n   return int(jnp.argmax(q_values[0]))<\/code><\/pre>\n
<p>We define the transition structure and implement a replay buffer to store past experiences from the environment. We create functions to add transitions and sample mini-batches that will later be used to train the agent. We also implement the epsilon-greedy exploration strategy.<\/p>
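<p>Before training, it helps to see how the exploration schedule behaves. The short check below evaluates epsilon_by_frame at a few frame indices; it is purely illustrative and uses the default arguments defined above (a linear decay from 1.0 to 0.05 over 20000 frames).<\/p>\n
<pre class=\" no-line-numbers\"><code class=\" no-wrap language-python\"># Spot-check the linear epsilon decay schedule\nfor f in [0, 5000, 10000, 20000, 40000]:\n   print(f, round(epsilon_by_frame(f), 4))\n# expected: 1.0, 0.7625, 0.525, 0.05, 0.05 (clamped after decay_frames)<\/code><\/pre>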
<pre class=\" no-line-numbers\"><code class=\" no-wrap language-python\">@jax.jit\ndef soft_update(target_params, online_params, tau):\n   return jax.tree_util.tree_map(lambda t, s: (1.0 - tau) * t + tau * s, target_params, online_params)\n\n\ndef batch_td_errors(params, target_params, batch):\n   q_tm1 = q_net.apply(params, batch[\"obs\"])\n   q_t = q_net.apply(target_params, batch[\"next_obs\"])\n   td_errors = jax.vmap(\n       lambda q1, a, r, d, q2: rlax.q_learning(q1, a, r, d, q2)\n   )(q_tm1, batch[\"action\"], batch[\"reward\"], batch[\"discount\"], q_t)\n   return td_errors\n\n\n@jax.jit\ndef train_step(params, target_params, opt_state, batch):\n   def loss_fn(p):\n       td_errors = batch_td_errors(p, target_params, batch)\n       loss = jnp.mean(rlax.huber_loss(td_errors, delta=1.0))\n       metrics = {\n           \"loss\": loss,\n           \"td_abs_mean\": jnp.mean(jnp.abs(td_errors)),\n           \"q_mean\": jnp.mean(q_net.apply(p, batch[\"obs\"])),\n       }\n       return loss, metrics\n\n\n   (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)\n   updates, opt_state = optimizer.update(grads, opt_state, params)\n   params = optax.apply_updates(params, updates)\n   return params, opt_state, metrics<\/code><\/pre>\n
<p>We define the core learning functions used during training. We compute temporal difference errors using RLax\u2019s Q-learning primitive and calculate the loss using the Huber loss function. We then implement the training step that computes gradients, applies optimizer updates, and returns training metrics.<\/p>
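<p>To make the RLax primitive concrete, the snippet below applies rlax.q_learning to a single hand-built transition and compares it with the textbook Q-learning TD error, r + gamma * max Q(s', a') - Q(s, a). The values are made up purely for illustration and are separate from the training code above.<\/p>\n
<pre class=\" no-line-numbers\"><code class=\" no-wrap language-python\"># One-transition illustration of the TD error computed by rlax.q_learning\nq_tm1 = jnp.array([1.0, 2.0])      # online-network Q(s, a) for both actions\na_tm1 = jnp.int32(0)               # action actually taken\nr_t = jnp.float32(1.0)             # reward received\ndiscount_t = jnp.float32(0.99)     # gamma (set to 0.0 on terminal steps)\nq_t = jnp.array([3.0, 0.5])        # target-network Q(s', a')\n\ntd = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)\n# target = r_t + discount_t * max(q_t) = 1.0 + 0.99 * 3.0 = 3.97\n# td     = target - q_tm1[a_tm1]       = 3.97 - 1.0      = 2.97\nprint(float(td))<\/code><\/pre>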
<pre class=\" no-line-numbers\"><code class=\" no-wrap language-python\">def evaluate_agent(params, episodes=5):\n   returns = []\n   for ep in range(episodes):\n       obs, _ = eval_env.reset(seed=seed + 1000 + ep)\n       done = False\n       truncated = False\n       total_reward = 0.0\n       while not (done or truncated):\n           q_values = q_net.apply(params, obs[None, :])\n           action = int(jnp.argmax(q_values[0]))\n           next_obs, reward, done, truncated, _ = eval_env.step(action)\n           total_reward += reward\n           obs = next_obs\n       returns.append(total_reward)\n   return float(np.mean(returns))\n\n\nnum_frames = 40000\nbatch_size = 128\nwarmup_steps = 1000\ntrain_every = 4\neval_every = 2000\ngamma = 0.99\ntau = 0.01\nmax_grad_updates_per_step = 1\n\n\nobs, _ = env.reset(seed=seed)\nepisode_return = 0.0\nepisode_returns = []\neval_returns = []\nlosses = []\ntd_means = []\nq_means = []\neval_steps = []\n\n\nstart_time = time.time()<\/code><\/pre>\n
<p>We define the evaluation function that measures the agent\u2019s performance. We configure the training hyperparameters, including the number of frames, batch size, discount factor, and target network update rate. We also initialize variables that track training statistics, including episode returns, losses, and evaluation metrics.<\/p>
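<p>The target network update rate tau deserves a quick illustration. soft_update performs Polyak averaging, target = (1 - tau) * target + tau * online, so with tau = 0.01 the target network tracks an exponential moving average of the online parameters with an effective horizon of roughly 1 \/ tau = 100 gradient updates. The toy check below uses a tiny hand-made pytree rather than the real parameters.<\/p>\n
<pre class=\" no-line-numbers\"><code class=\" no-wrap language-python\"># Toy check of the Polyak (soft) target update on a hand-made pytree\ntoy_target = {\"w\": jnp.zeros(3)}\ntoy_online = {\"w\": jnp.ones(3)}\nblended = soft_update(toy_target, toy_online, tau)\nprint(blended[\"w\"])   # [0.01 0.01 0.01] = (1 - tau) * 0 + tau * 1<\/code><\/pre>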
<pre class=\" no-line-numbers\"><code class=\" no-wrap language-python\">for frame_idx in range(1, num_frames + 1):\n   epsilon = epsilon_by_frame(frame_idx)\n   action = select_action(params, obs.astype(np.float32), epsilon)\n   next_obs, reward, done, truncated, _ = env.step(action)\n   terminal = done or truncated\n   discount = 0.0 if terminal else gamma\n\n\n   replay.add(\n       obs.astype(np.float32),\n       action,\n       float(reward),\n       float(discount),\n       next_obs.astype(np.float32),\n       float(terminal),\n   )\n\n\n   obs = next_obs\n   episode_return += reward\n\n\n   if terminal:\n       episode_returns.append(episode_return)\n       obs, _ = env.reset()\n       episode_return = 0.0\n\n\n   if len(replay) &gt;= warmup_steps and frame_idx % train_every == 0:\n       for _ in range(max_grad_updates_per_step):\n           batch_np = replay.sample(batch_size)\n           batch = {k: jnp.asarray(v) for k, v in batch_np.items()}\n           params, opt_state, metrics = train_step(params, target_params, opt_state, batch)\n           target_params = soft_update(target_params, params, tau)\n           losses.append(float(metrics[\"loss\"]))\n           td_means.append(float(metrics[\"td_abs_mean\"]))\n           q_means.append(float(metrics[\"q_mean\"]))\n\n\n   if frame_idx % eval_every == 0:\n       avg_eval_return = evaluate_agent(params, episodes=5)\n       eval_returns.append(avg_eval_return)\n       eval_steps.append(frame_idx)\n       recent_train = np.mean(episode_returns[-10:]) if episode_returns else 0.0\n       recent_loss = np.mean(losses[-100:]) if losses else 0.0\n       print(\n           f\"step={frame_idx:6d} | epsilon={epsilon:.3f} | \"\n           f\"recent_train_return={recent_train:7.2f} | \"\n           f\"eval_return={avg_eval_return:7.2f} | \"\n           f\"recent_loss={recent_loss:.5f} | buffer={len(replay)}\"\n       )\n\n\nelapsed = time.time() - start_time\nfinal_eval = evaluate_agent(params, episodes=10)\n\n\nprint(\"\\nTraining complete\")\nprint(f\"Elapsed time: {elapsed:.1f} seconds\")\nprint(f\"Final 10-episode evaluation return: {final_eval:.2f}\")\n\n\nplt.figure(figsize=(14, 4))\nplt.subplot(1, 3, 1)\nplt.plot(episode_returns)\nplt.title(\"Training Episode Returns\")\nplt.xlabel(\"Episode\")\nplt.ylabel(\"Return\")\n\n\nplt.subplot(1, 3, 2)\nplt.plot(eval_steps, eval_returns)\nplt.title(\"Evaluation Returns\")\nplt.xlabel(\"Environment Steps\")\nplt.ylabel(\"Avg Return\")\n\n\nplt.subplot(1, 3, 3)\nplt.plot(losses, label=\"Loss\")\nplt.plot(td_means, label=\"|TD Error| Mean\")\nplt.title(\"Optimization Metrics\")\nplt.xlabel(\"Gradient Updates\")\nplt.legend()\n\n\nplt.tight_layout()\nplt.show()\n\n\nobs, _ = eval_env.reset(seed=999)\nframes = []\ndone = False\ntruncated = False\ntotal_reward = 0.0\n\n\nrender_env = gym.make(\"CartPole-v1\", render_mode=\"rgb_array\")\nobs, _ = render_env.reset(seed=999)\n\n\nwhile not (done or truncated):\n   frame = render_env.render()\n   frames.append(frame)\n   q_values = q_net.apply(params, obs[None, :])\n   action = int(jnp.argmax(q_values[0]))\n   obs, reward, done, truncated, _ = render_env.step(action)\n   total_reward += reward\n\n\nrender_env.close()\n\n\nprint(f\"Demo episode return: {total_reward:.2f}\")\n\n\ntry:\n   import matplotlib.animation as animation\n   from IPython.display import HTML, display\n\n\n   fig = plt.figure(figsize=(6, 4))\n   patch = plt.imshow(frames[0])\n   plt.axis(\"off\")\n\n\n   def animate(i):\n       patch.set_data(frames[i])\n       return (patch,)\n\n\n   anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=30, blit=True)\n   display(HTML(anim.to_jshtml()))\n   plt.close(fig)\nexcept Exception as e:\n   print(\"Animation display skipped:\", e)<\/code><\/pre>\n
<p>We run the full reinforcement learning training loop. We periodically update the network parameters, evaluate the agent\u2019s performance, and record metrics for visualization. Also, we plot the training results and render a demonstration episode to observe how the trained agent behaves.<\/p>
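<p>Because RLax exposes the learning rule as a standalone primitive, swapping in a different rule is straightforward. As one example of the extensions mentioned below, here is a sketch of a Double DQN variant of batch_td_errors, in which the online network selects the greedy next action and the target network evaluates it. This sketch assumes rlax.double_q_learning takes (q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector); it is shown for illustration only and is not wired into the training loop above.<\/p>\n
<pre class=\" no-line-numbers\"><code class=\" no-wrap language-python\"># Illustrative Double DQN variant of batch_td_errors (not used by the training loop above)\ndef batch_double_td_errors(params, target_params, batch):\n   q_tm1 = q_net.apply(params, batch[\"obs\"])                   # online net at s\n   q_t_value = q_net.apply(target_params, batch[\"next_obs\"])   # target net evaluates s'\n   q_t_selector = q_net.apply(params, batch[\"next_obs\"])       # online net picks argmax at s'\n   return jax.vmap(rlax.double_q_learning)(\n       q_tm1, batch[\"action\"], batch[\"reward\"], batch[\"discount\"],\n       q_t_value, q_t_selector,\n   )<\/code><\/pre>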
render_mode=\"rgb_array\")\nobs, _ = render_env.reset(seed=999)\n\n\nwhile not (done or truncated):\n   frame = render_env.render()\n   frames.append(frame)\n   q_values = q_net.apply(params, obs[None, :])\n   action = int(jnp.argmax(q_values[0]))\n   obs, reward, done, truncated, _ = render_env.step(action)\n   total_reward += reward\n\n\nrender_env.close()\n\n\nprint(f\"Demo episode return: {total_reward:.2f}\")\n\n\ntry:\n   import matplotlib.animation as animation\n   from IPython.display import HTML, display\n\n\n   fig = plt.figure(figsize=(6, 4))\n   patch = plt.imshow(frames[0])\n   plt.axis(\"off\")\n\n\n   def animate(i):\n       patch.set_data(frames[i])\n       return (patch,)\n\n\n   anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=30, blit=True)\n   display(HTML(anim.to_jshtml()))\n   plt.close(fig)\nexcept Exception as e:\n   print(\"Animation display skipped:\", e)<\/code><\/pre>\n<\/div>\n<\/div>\n<p>We run the full reinforcement learning training loop. We periodically update the network parameters, evaluate the agent\u2019s performance, and record metrics for visualization. Also, we plot the training results and render a demonstration episode to observe how the trained agent behaves.<\/p>\n<p>In conclusion, we built a complete Deep Q-Learning agent by combining RLax with the modern JAX-based machine learning ecosystem. We designed a neural network to estimate action values, implement experience replay to stabilize learning, and compute TD errors using RLax\u2019s Q-learning primitive. During training, we updated the network parameters using gradient-based optimization and periodically evaluated the agent to track performance improvements. Also, we saw how RLax enables a modular approach to reinforcement learning by providing reusable algorithmic components rather than full algorithms. This flexibility allows us to easily experiment with different architectures, learning rules, and optimization strategies. By extending this foundation, we can build more advanced agents, such as Double DQN, distributional reinforcement learning models, and actor\u2013critic methods, using the same RLax primitives.<\/p>\n<hr class=\"wp-block-separator has-alpha-channel-opacity\" \/>\n<p>Check out\u00a0the\u00a0<strong><a href=\"https:\/\/github.com\/Marktechpost\/AI-Tutorial-Codes-Included\/blob\/main\/Reinforcement%20learning\/rlax_dqn_cartpole_jax_tutorial_Marktechpost.ipynb\" target=\"_blank\" rel=\"noreferrer noopener\">Full Notebook here<\/a>.\u00a0<\/strong>Also,\u00a0feel free to follow us on\u00a0<strong><a href=\"https:\/\/x.com\/intent\/follow?screen_name=marktechpost\" target=\"_blank\" rel=\"noreferrer noopener\"><mark>Twitter<\/mark><\/a><\/strong>\u00a0and don\u2019t forget to join our\u00a0<strong><a href=\"https:\/\/www.reddit.com\/r\/machinelearningnews\/\" target=\"_blank\" rel=\"noreferrer noopener\">120k+ ML SubReddit<\/a><\/strong>\u00a0and Subscribe to\u00a0<strong><a href=\"https:\/\/www.aidevsignals.com\/\" target=\"_blank\" rel=\"noreferrer noopener\">our Newsletter<\/a><\/strong>. Wait! 
are you on telegram?\u00a0<strong><a href=\"https:\/\/t.me\/machinelearningresearchnews\" target=\"_blank\" rel=\"noreferrer noopener\">now you can join us on telegram as well.<\/a><\/strong><\/p>\n<p>The post <a href=\"https:\/\/www.marktechpost.com\/2026\/03\/22\/implementing-deep-q-learning-dqn-from-scratch-using-rlax-jax-haiku-and-optax-to-train-a-cartpole-reinforcement-learning-agent\/\">Implementing Deep Q-Learning (DQN) from Scratch Using RLax JAX Haiku and Optax to Train a CartPole Reinforcement Learning Agent<\/a> appeared first on <a href=\"https:\/\/www.marktechpost.com\/\">MarkTechPost<\/a>.<\/p>","protected":false},"excerpt":{"rendered":"<p>In this tutorial, we implement&hellip;<\/p>\n","protected":false},"author":1,"featured_media":29,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[1],"tags":[],"class_list":["post-595","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-uncategorized"],"_links":{"self":[{"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=\/wp\/v2\/posts\/595","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=595"}],"version-history":[{"count":0,"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=\/wp\/v2\/posts\/595\/revisions"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=\/wp\/v2\/media\/29"}],"wp:attachment":[{"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=595"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=595"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=595"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}