{"id":806,"date":"2026-04-28T02:58:20","date_gmt":"2026-04-27T18:58:20","guid":{"rendered":"https:\/\/connectword.dpdns.org\/?p=806"},"modified":"2026-04-28T02:58:20","modified_gmt":"2026-04-27T18:58:20","slug":"build-a-reinforcement-learning-powered-agent-that-learns-to-retrieve-relevant-long-term-memories-for-accurate-llm-question-answering","status":"publish","type":"post","link":"https:\/\/connectword.dpdns.org\/?p=806","title":{"rendered":"Build a Reinforcement Learning Powered Agent that Learns to Retrieve Relevant Long-Term Memories for Accurate LLM Question Answering"},"content":{"rendered":"<p>In this tutorial, we build a Reinforcement Learning\u2013driven agent that learns how to retrieve relevant memories from a long-term memory bank. We start by constructing a synthetic memory dataset and generating queries that require the agent to recall specific information. Using OpenAI embeddings, we convert both memories and queries into vector representations, enabling similarity signals to guide candidate retrieval. We then design a custom RL environment in which the agent observes features of candidate memories and learns a policy to select the most useful one. By training the agent with the PPO algorithm, we enable it to improve retrieval decisions beyond simple similarity search. 
Finally, we evaluate the system by comparing the RL-based retriever with a baseline approach and demonstrate how an LLM can use retrieved memories to generate accurate answers.<\/p>\n<div class=\"dm-code-snippet dark dm-normal-version default no-background-mobile\">\n<div class=\"control-language\">\n<div class=\"dm-buttons\">\n<div class=\"dm-buttons-left\">\n<div class=\"dm-button-snippet red-button\"><\/div>\n<div class=\"dm-button-snippet orange-button\"><\/div>\n<div class=\"dm-button-snippet green-button\"><\/div>\n<\/div>\n<div class=\"dm-buttons-right\"><a><span class=\"dm-copy-text\">Copy Code<\/span><span class=\"dm-copy-confirmed\">Copied<\/span><span class=\"dm-error-message\">Use a different Browser<\/span><\/a><\/div>\n<\/div>\n<pre class=\" no-line-numbers\"><code class=\" no-wrap language-php\">import sys\nimport subprocess\nimport importlib.util\nimport os\nimport json\nimport math\nimport random\nimport textwrap\nimport getpass\nfrom dataclasses import dataclass\nfrom typing import List, Dict, Any, Tuple\n\n\ndef _install_if_missing(packages):\n   missing = []\n   for package_name, import_name in packages:\n       # pkgutil.find_loader is deprecated; importlib.util.find_spec returns None when the module is absent\n       if importlib.util.find_spec(import_name) is None:\n           missing.append(package_name)\n   if missing:\n       subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"-q\"] + missing)\n\n\n_install_if_missing([\n   (\"openai&gt;=1.40.0\", \"openai\"),\n   (\"gymnasium&gt;=0.29.1\", \"gymnasium\"),\n   (\"stable-baselines3&gt;=2.3.2\", \"stable_baselines3\"),\n   (\"numpy&gt;=1.26.4\", \"numpy\"),\n   (\"pandas&gt;=2.2.2\", \"pandas\"),\n   (\"scikit-learn&gt;=1.5.1\", \"sklearn\"),\n   (\"matplotlib&gt;=3.9.0\", \"matplotlib\"),\n   (\"tqdm&gt;=4.66.4\", \"tqdm\"),\n])\n\n\nimport numpy as np\nimport pandas as pd\nimport gymnasium as gym\nfrom gymnasium import spaces\nfrom tqdm.auto import tqdm\nimport matplotlib.pyplot as plt\nfrom sklearn.metrics.pairwise import cosine_similarity\nfrom stable_baselines3 import PPO\nfrom 
stable_baselines3.common.vec_env import DummyVecEnv\nfrom openai import OpenAI\n\n\nSEED = 42\nrandom.seed(SEED)\nnp.random.seed(SEED)\n\n\ntry:\n   from google.colab import userdata\n   OPENAI_API_KEY = userdata.get(\"OPENAI_API_KEY\")\nexcept Exception:\n   OPENAI_API_KEY = None\n\n\nif not OPENAI_API_KEY:\n   OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n\n\nif not OPENAI_API_KEY:\n   OPENAI_API_KEY = getpass.getpass(\"Enter OPENAI_API_KEY: \").strip()\n\n\nclient = OpenAI(api_key=OPENAI_API_KEY)\n\n\nEMBED_MODEL = \"text-embedding-3-small\"\nCHAT_MODEL = \"gpt-4o-mini\"\n\n\ndef chunked(xs, n):\n   for i in range(0, len(xs), n):\n       yield xs[i:i+n]\n\n\ndef embed_texts(texts: List[str], model: str = EMBED_MODEL, batch_size: int = 64) -&gt; np.ndarray:\n   outputs = []\n   for batch in tqdm(list(chunked(texts, batch_size)), desc=\"Embedding\"):\n       resp = client.embeddings.create(model=model, input=batch)\n       batch_vecs = [d.embedding for d in resp.data]\n       outputs.extend(batch_vecs)\n   arr = np.array(outputs, dtype=np.float32)\n   norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12\n   arr = arr \/ norms\n   return arr\n\n\ndef chat_answer(question: str, retrieved_memories: List[Dict[str, Any]], model: str = CHAT_MODEL) -&gt; str:\n   memory_block = \"\\n\".join([f\"[Memory {i+1}] {m['text']}\" for i, m in enumerate(retrieved_memories)])\n   system = \"You are a precise QA assistant. Answer the question using only the provided memories. 
If the memories do not contain the answer, say 'I do not know from the provided memories.'\"\n   user = f\"Question: {question}\\n\\nRetrieved memories:\\n{memory_block}\\n\\nAnswer:\"\n   resp = client.chat.completions.create(\n       model=model,\n       temperature=0,\n       messages=[\n           {\"role\": \"system\", \"content\": system},\n           {\"role\": \"user\", \"content\": user},\n       ],\n   )\n   return resp.choices[0].message.content.strip()\n\n\ndef llm_judge_exact(question: str, gold_answer: str, predicted_answer: str, model: str = CHAT_MODEL) -&gt; float:\n   system = \"You are a strict evaluator. Return only JSON with a single key 'score'. Use 1.0 if the predicted answer is semantically correct, 0.0 otherwise.\"\n   user = json.dumps({\n       \"question\": question,\n       \"gold_answer\": gold_answer,\n       \"predicted_answer\": predicted_answer,\n   }, ensure_ascii=False)\n   resp = client.chat.completions.create(\n       model=model,\n       temperature=0,\n       response_format={\"type\": \"json_object\"},\n       messages=[\n           {\"role\": \"system\", \"content\": system},\n           {\"role\": \"user\", \"content\": user},\n       ],\n   )\n   txt = resp.choices[0].message.content.strip()\n   try:\n       obj = json.loads(txt)\n       score = float(obj[\"score\"])\n       return 1.0 if score &gt;= 0.5 else 0.0\n   except Exception:\n       return 0.0<\/code><\/pre>\n<\/div>\n<\/div>\n<p>We set up the environment required for our reinforcement learning\u2013based memory retrieval system. We install all required libraries, import the necessary modules, and securely load the OpenAI API key for embedding and language model interactions. 
We also define helper functions that generate embeddings, produce answers from retrieved memories, and evaluate answers using an LLM-based judging mechanism.<\/p>\n<div class=\"dm-code-snippet dark dm-normal-version default no-background-mobile\">\n<div class=\"control-language\">\n<div class=\"dm-buttons\">\n<div class=\"dm-buttons-left\">\n<div class=\"dm-button-snippet red-button\"><\/div>\n<div class=\"dm-button-snippet orange-button\"><\/div>\n<div class=\"dm-button-snippet green-button\"><\/div>\n<\/div>\n<div class=\"dm-buttons-right\"><a><span class=\"dm-copy-text\">Copy Code<\/span><span class=\"dm-copy-confirmed\">Copied<\/span><span class=\"dm-error-message\">Use a different Browser<\/span><\/a><\/div>\n<\/div>\n<pre class=\" no-line-numbers\"><code class=\" no-wrap language-php\">@dataclass\nclass MemoryItem:\n   memory_id: int\n   topic: str\n   entity: str\n   slot: str\n   value: str\n   text: str\n\n\ndef build_memory_bank() -&gt; List[MemoryItem]:\n   entities = [\n       {\n           \"entity\": \"Astra\",\n           \"topic\": \"robotics\",\n           \"facts\": {\n               \"battery\": \"18 hours\",\n               \"sensor\": \"LiDAR\",\n               \"country\": \"Japan\",\n               \"release_year\": \"2023\",\n               \"specialty\": \"warehouse navigation\",\n           },\n       },\n       {\n           \"entity\": \"Orion\",\n           \"topic\": \"astronomy\",\n           \"facts\": {\n               \"telescope\": \"infrared array\",\n               \"country\": \"Chile\",\n               \"discovery_year\": \"2019\",\n               \"target\": \"exoplanet atmospheres\",\n               \"aperture\": \"8 meters\",\n           },\n       },\n       {\n           \"entity\": \"Vita\",\n           \"topic\": \"biomedicine\",\n           \"facts\": {\n               \"compound\": \"VX-17\",\n               \"trial_phase\": \"Phase II\",\n               \"country\": \"Canada\",\n               \"target\": 
\"inflammatory markers\",\n               \"delivery\": \"oral capsule\",\n           },\n       },\n       {\n           \"entity\": \"Nimbus\",\n           \"topic\": \"climate\",\n           \"facts\": {\n               \"satellite\": \"polar orbiter\",\n               \"country\": \"Norway\",\n               \"launch_year\": \"2022\",\n               \"instrument\": \"microwave radiometer\",\n               \"mission\": \"sea ice monitoring\",\n           },\n       },\n       {\n           \"entity\": \"Atlas\",\n           \"topic\": \"logistics\",\n           \"facts\": {\n               \"fleet_size\": \"240 trucks\",\n               \"hub\": \"Muscat\",\n               \"software\": \"predictive routing\",\n               \"fuel_policy\": \"hybrid-first\",\n               \"region\": \"GCC\",\n           },\n       },\n       {\n           \"entity\": \"Lumos\",\n           \"topic\": \"materials\",\n           \"facts\": {\n               \"alloy\": \"Ti-6Al-4V\",\n               \"process\": \"laser sintering\",\n               \"density\": \"4.43 g\/cm3\",\n               \"country\": \"Germany\",\n               \"use_case\": \"aerospace brackets\",\n           },\n       },\n       {\n           \"entity\": \"Cedar\",\n           \"topic\": \"agriculture\",\n           \"facts\": {\n               \"crop\": \"wheat\",\n               \"irrigation\": \"drip control\",\n               \"country\": \"India\",\n               \"yield_gain\": \"12 percent\",\n               \"soil_sensor\": \"capacitive probe\",\n           },\n       },\n       {\n           \"entity\": \"Pulse\",\n           \"topic\": \"healthcare\",\n           \"facts\": {\n               \"device\": \"ECG patch\",\n               \"battery\": \"7 days\",\n               \"country\": \"USA\",\n               \"connectivity\": \"Bluetooth Low Energy\",\n               \"use_case\": \"arrhythmia screening\",\n           },\n       },\n   ]\n\n\n   phrasing_templates = [\n       
\"{entity} in {topic} uses {value} for {slot}.\",\n       \"The {slot} associated with {entity} is {value}.\",\n       \"{entity} has {slot}: {value}.\",\n       \"For {entity}, the recorded {slot} is {value}.\",\n       \"Reference note: {entity} -&gt; {slot} = {value}.\",\n   ]\n\n\n   distractor_templates = [\n       \"{entity} was discussed in a briefing about cross-domain innovation.\",\n       \"{entity} has been compared with several other projects in recent reports.\",\n       \"A summary note mentions {entity} among notable initiatives.\",\n       \"{entity} appears in a high-level update without technical details.\",\n       \"Stakeholders reviewed {entity} in a strategic planning session.\",\n   ]\n\n\n   memory_bank = []\n   memory_id = 0\n\n\n   for item in entities:\n       entity = item[\"entity\"]\n       topic = item[\"topic\"]\n       for slot, value in item[\"facts\"].items():\n           for t in phrasing_templates:\n               text = t.format(entity=entity, topic=topic, slot=slot, value=value)\n               memory_bank.append(MemoryItem(\n                   memory_id=memory_id,\n                   topic=topic,\n                   entity=entity,\n                   slot=slot,\n                   value=value,\n                   text=text\n               ))\n               memory_id += 1\n\n\n       for t in distractor_templates:\n           text = t.format(entity=entity)\n           memory_bank.append(MemoryItem(\n               memory_id=memory_id,\n               topic=topic,\n               entity=entity,\n               slot=\"distractor\",\n               value=\"n\/a\",\n               text=text\n           ))\n           memory_id += 1\n\n\n   extra_noise = [\n       \"General note: system maintenance occurred on Tuesday.\",\n       \"A committee discussed budget timelines and operational readiness.\",\n       \"The archive includes summaries of projects across multiple departments.\",\n       \"No relevant technical value is stated 
in this memory.\",\n       \"A status update mentioned partnerships and future opportunities.\",\n       \"An unrelated note references shipping delays and staffing changes.\",\n       \"Background memo: the team reviewed dashboards and reporting cadence.\",\n       \"This memory contains no answer-bearing facts.\",\n   ]\n\n\n   for text in extra_noise:\n       memory_bank.append(MemoryItem(\n           memory_id=memory_id,\n           topic=\"noise\",\n           entity=\"none\",\n           slot=\"distractor\",\n           value=\"n\/a\",\n           text=text\n       ))\n       memory_id += 1\n\n\n   return memory_bank\n\n\nmemory_bank = build_memory_bank()\nmemory_texts = [m.text for m in memory_bank]\nmemory_embeddings = embed_texts(memory_texts)\n\n\ndef build_queries(memory_bank: List[MemoryItem]) -&gt; List[Dict[str, Any]]:\n   patterns = [\n       \"What is the {slot} of {entity}?\",\n       \"Which {slot} does {entity} have?\",\n       \"Tell me the {slot} for {entity}.\",\n       \"Can you recall the {slot} associated with {entity}?\",\n       \"What was recorded as the {slot} of {entity}?\",\n   ]\n   queries = []\n   qid = 0\n   for m in memory_bank:\n       if m.slot == \"distractor\":\n           continue\n       q = random.choice(patterns).format(slot=m.slot.replace(\"_\", \" \"), entity=m.entity)\n       queries.append({\n           \"query_id\": qid,\n           \"query\": q,\n           \"entity\": m.entity,\n           \"slot\": m.slot,\n           \"gold_value\": m.value,\n           \"gold_memory_id\": m.memory_id,\n           \"gold_text\": m.text,\n           \"topic\": m.topic,\n       })\n       qid += 1\n   random.shuffle(queries)\n   return queries\n\n\nqueries = build_queries(memory_bank)\nquery_texts = [q[\"query\"] for q in queries]\nquery_embeddings = embed_texts(query_texts)<\/code><\/pre>\n<\/div>\n<\/div>\n<p>We construct a synthetic long-term memory bank that simulates stored knowledge across multiple domains. 
We generate structured memory items and convert them into textual memories that can later be embedded for semantic retrieval. We also create query datasets from these memories and embed them so the agent can compare queries with stored knowledge.<\/p>\n<div class=\"dm-code-snippet dark dm-normal-version default no-background-mobile\">\n<div class=\"control-language\">\n<div class=\"dm-buttons\">\n<div class=\"dm-buttons-left\">\n<div class=\"dm-button-snippet red-button\"><\/div>\n<div class=\"dm-button-snippet orange-button\"><\/div>\n<div class=\"dm-button-snippet green-button\"><\/div>\n<\/div>\n<div class=\"dm-buttons-right\"><a><span class=\"dm-copy-text\">Copy Code<\/span><span class=\"dm-copy-confirmed\">Copied<\/span><span class=\"dm-error-message\">Use a different Browser<\/span><\/a><\/div>\n<\/div>\n<pre class=\" no-line-numbers\"><code class=\" no-wrap language-php\">MEM_BY_ID = {m.memory_id: m for m in memory_bank}\nQUERY_BY_ID = {q[\"query_id\"]: q for q in queries}\n\n\ndef keyword_overlap(a: str, b: str) -&gt; float:\n   ta = set(a.lower().replace(\"?\", \"\").replace(\".\", \"\").split())\n   tb = set(b.lower().replace(\"?\", \"\").replace(\".\", \"\").split())\n   if not ta or not tb:\n       return 0.0\n   return len(ta &amp; tb) \/ max(1, len(ta | tb))\n\n\ndef get_top_k_candidates(query_idx: int, k: int = 8) -&gt; Dict[str, Any]:\n   qv = query_embeddings[query_idx:query_idx+1]\n   sims = cosine_similarity(qv, memory_embeddings)[0]\n   top_idx = np.argsort(-sims)[:k]\n   candidates = []\n   q = queries[query_idx]\n   for rank, midx in enumerate(top_idx):\n       mem = memory_bank[midx]\n       sim = float(sims[midx])\n       overlap = keyword_overlap(q[\"query\"], mem.text)\n       entity_match = 1.0 if q[\"entity\"].lower() in mem.text.lower() else 0.0\n       slot_match = 1.0 if q[\"slot\"].replace(\"_\", \" \").lower() in mem.text.lower() else 0.0\n       is_gold = 1.0 if mem.memory_id == q[\"gold_memory_id\"] else 0.0\n       
candidates.append({\n           \"rank\": rank,\n           \"memory_index\": midx,\n           \"memory_id\": mem.memory_id,\n           \"text\": mem.text,\n           \"sim\": sim,\n           \"overlap\": overlap,\n           \"entity_match\": entity_match,\n           \"slot_match\": slot_match,\n           \"is_gold\": is_gold,\n       })\n   return {\"query\": q, \"candidates\": candidates}\n\n\nALL_CANDIDATES = [get_top_k_candidates(i, k=8) for i in range(len(queries))]\n\n\ndef build_state_features(item: Dict[str, Any]) -&gt; np.ndarray:\n   q = item[\"query\"]\n   feats = []\n   for c in item[\"candidates\"]:\n       feats.extend([\n           c[\"sim\"],\n           c[\"overlap\"],\n           c[\"entity_match\"],\n           c[\"slot_match\"],\n           1.0 \/ (1.0 + c[\"rank\"]),\n       ])\n   unique_topic_bonus = 1.0 if q[\"topic\"] in q[\"query\"].lower() else 0.0\n   query_len = min(len(q[\"query\"].split()) \/ 20.0, 1.0)\n   feats.extend([unique_topic_bonus, query_len])\n   return np.array(feats, dtype=np.float32)\n\n\nSTATE_DIM = len(build_state_features(ALL_CANDIDATES[0]))\nNUM_ACTIONS = len(ALL_CANDIDATES[0][\"candidates\"])\n\n\nclass MemoryRetrievalEnv(gym.Env):\n   metadata = {\"render_modes\": [\"human\"]}\n\n\n   def __init__(self, candidate_items: List[Dict[str, Any]], seed: int = 42):\n       super().__init__()\n       self.candidate_items = candidate_items\n       self.rng = np.random.default_rng(seed)\n       self.observation_space = spaces.Box(low=-10, high=10, shape=(STATE_DIM,), dtype=np.float32)\n       self.action_space = spaces.Discrete(NUM_ACTIONS)\n       self.current = None\n\n\n   def reset(self, seed=None, options=None):\n       if seed is not None:\n           self.rng = np.random.default_rng(seed)\n       idx = int(self.rng.integers(0, len(self.candidate_items)))\n       self.current = self.candidate_items[idx]\n       obs = build_state_features(self.current)\n       info = {\"query_id\": 
self.current[\"query\"][\"query_id\"]}\n       return obs, info\n\n\n   def step(self, action):\n       chosen = self.current[\"candidates\"][int(action)]\n       q = self.current[\"query\"]\n\n\n       reward = 0.0\n       reward += 2.0 * chosen[\"is_gold\"]\n       reward += 0.8 * chosen[\"entity_match\"]\n       reward += 0.6 * chosen[\"slot_match\"]\n       reward += 0.5 * chosen[\"sim\"]\n       reward += 0.3 * chosen[\"overlap\"]\n       reward -= 0.15 * chosen[\"rank\"]\n\n\n       # Each episode is a single retrieval decision, so it terminates after one step\n       done = True\n       truncated = False\n       info = {\n           \"query_id\": q[\"query_id\"],\n           \"chosen_memory_id\": chosen[\"memory_id\"],\n           \"gold_memory_id\": q[\"gold_memory_id\"],\n           \"chosen_text\": chosen[\"text\"],\n           \"gold_text\": q[\"gold_text\"],\n           \"is_correct\": bool(chosen[\"memory_id\"] == q[\"gold_memory_id\"]),\n           \"gold_value\": q[\"gold_value\"],\n           \"query\": q[\"query\"],\n       }\n       next_obs = np.zeros(self.observation_space.shape, dtype=np.float32)\n       return next_obs, float(reward), done, truncated, info<\/code><\/pre>\n<\/div>\n<\/div>\n<p>We prepare candidate memories for each query by computing similarity scores between query embeddings and memory embeddings. We then construct feature vectors that describe each candidate memory using similarity, keyword overlap, entity and slot matching, and rank signals. 
Finally, we define a custom reinforcement learning environment in which the agent learns to select the best memory to answer each query.<\/p>\n<div class=\"dm-code-snippet dark dm-normal-version default no-background-mobile\">\n<div class=\"control-language\">\n<div class=\"dm-buttons\">\n<div class=\"dm-buttons-left\">\n<div class=\"dm-button-snippet red-button\"><\/div>\n<div class=\"dm-button-snippet orange-button\"><\/div>\n<div class=\"dm-button-snippet green-button\"><\/div>\n<\/div>\n<div class=\"dm-buttons-right\"><a><span class=\"dm-copy-text\">Copy Code<\/span><span class=\"dm-copy-confirmed\">Copied<\/span><span class=\"dm-error-message\">Use a different Browser<\/span><\/a><\/div>\n<\/div>\n<pre class=\" no-line-numbers\"><code class=\" no-wrap language-php\">split_1 = int(0.7 * len(ALL_CANDIDATES))\nsplit_2 = int(0.85 * len(ALL_CANDIDATES))\ntrain_items = ALL_CANDIDATES[:split_1]\nval_items = ALL_CANDIDATES[split_1:split_2]\ntest_items = ALL_CANDIDATES[split_2:]\n\n\ntrain_env = DummyVecEnv([lambda: MemoryRetrievalEnv(train_items, seed=SEED)])\nmodel = PPO(\n   \"MlpPolicy\",\n   train_env,\n   learning_rate=3e-4,\n   n_steps=256,\n   batch_size=64,\n   gamma=0.99,\n   gae_lambda=0.95,\n   ent_coef=0.01,\n   clip_range=0.2,\n   verbose=0,\n   seed=SEED,\n)\n\n\nmodel.learn(total_timesteps=12000)\n\n\ndef baseline_retrieve(item: Dict[str, Any]) -&gt; Dict[str, Any]:\n   best = max(item[\"candidates\"], key=lambda x: x[\"sim\"])\n   return best\n\n\ndef rl_retrieve(item: Dict[str, Any]) -&gt; Dict[str, Any]:\n   obs = build_state_features(item)\n   action, _ = model.predict(obs, deterministic=True)\n   return item[\"candidates\"][int(action)]\n\n\ndef evaluate_retriever(items: List[Dict[str, Any]], retriever_fn) -&gt; Dict[str, Any]:\n   rows = []\n   for item in items:\n       chosen = retriever_fn(item)\n       q = item[\"query\"]\n       rows.append({\n           \"query_id\": q[\"query_id\"],\n           \"query\": q[\"query\"],\n           
\"gold_value\": q[\"gold_value\"],\n           \"gold_memory_id\": q[\"gold_memory_id\"],\n           \"chosen_memory_id\": chosen[\"memory_id\"],\n           \"correct_retrieval\": int(chosen[\"memory_id\"] == q[\"gold_memory_id\"]),\n           \"chosen_text\": chosen[\"text\"],\n       })\n   df = pd.DataFrame(rows)\n   return {\n       \"df\": df,\n       \"retrieval_accuracy\": df[\"correct_retrieval\"].mean(),\n   }\n\n\nbaseline_val = evaluate_retriever(val_items, baseline_retrieve)\nrl_val = evaluate_retriever(val_items, rl_retrieve)\nbaseline_test = evaluate_retriever(test_items, baseline_retrieve)\nrl_test = evaluate_retriever(test_items, rl_retrieve)\n\n\nprint(\"Validation Retrieval Accuracy\")\nprint(\"Baseline:\", round(float(baseline_val[\"retrieval_accuracy\"]), 4))\nprint(\"RL      :\", round(float(rl_val[\"retrieval_accuracy\"]), 4))\nprint()\nprint(\"Test Retrieval Accuracy\")\nprint(\"Baseline:\", round(float(baseline_test[\"retrieval_accuracy\"]), 4))\nprint(\"RL      :\", round(float(rl_test[\"retrieval_accuracy\"]), 4))\n\n\nresults_df = pd.DataFrame([\n   {\"split\": \"validation\", \"method\": \"baseline_cosine\", \"retrieval_accuracy\": float(baseline_val[\"retrieval_accuracy\"])},\n   {\"split\": \"validation\", \"method\": \"rl_agent\", \"retrieval_accuracy\": float(rl_val[\"retrieval_accuracy\"])},\n   {\"split\": \"test\", \"method\": \"baseline_cosine\", \"retrieval_accuracy\": float(baseline_test[\"retrieval_accuracy\"])},\n   {\"split\": \"test\", \"method\": \"rl_agent\", \"retrieval_accuracy\": float(rl_test[\"retrieval_accuracy\"])},\n])\ndisplay(results_df)\n\n\nplot_df = results_df.copy()\nfor split_name in [\"validation\", \"test\"]:\n   sub = plot_df[plot_df[\"split\"] == split_name]\n   plt.figure(figsize=(6, 4))\n   plt.bar(sub[\"method\"], sub[\"retrieval_accuracy\"])\n   plt.title(f\"Retrieval Accuracy on {split_name.title()}\")\n   plt.ylim(0, 1)\n   plt.ylabel(\"Accuracy\")\n   
plt.show()<\/code><\/pre>\n<\/div>\n<\/div>\n<p>We split the datasets and initialize the reinforcement learning model. We train a PPO agent to learn a policy for selecting the most relevant memory from a set of candidates. After training, we evaluate the agent\u2019s retrieval performance and compare it with a baseline embedding-similarity approach.<\/p>\n<div class=\"dm-code-snippet dark dm-normal-version default no-background-mobile\">\n<div class=\"control-language\">\n<div class=\"dm-buttons\">\n<div class=\"dm-buttons-left\">\n<div class=\"dm-button-snippet red-button\"><\/div>\n<div class=\"dm-button-snippet orange-button\"><\/div>\n<div class=\"dm-button-snippet green-button\"><\/div>\n<\/div>\n<div class=\"dm-buttons-right\"><a><span class=\"dm-copy-text\">Copy Code<\/span><span class=\"dm-copy-confirmed\">Copied<\/span><span class=\"dm-error-message\">Use a different Browser<\/span><\/a><\/div>\n<\/div>\n<pre class=\" no-line-numbers\"><code class=\" no-wrap language-php\">def answer_with_retriever(item: Dict[str, Any], retriever_fn) -&gt; Dict[str, Any]:\n   q = item[\"query\"]\n   chosen = retriever_fn(item)\n   retrieved_memories = [{\n       \"memory_id\": chosen[\"memory_id\"],\n       \"text\": chosen[\"text\"],\n   }]\n   answer = chat_answer(q[\"query\"], retrieved_memories)\n   judged = llm_judge_exact(q[\"query\"], q[\"gold_value\"], answer)\n   return {\n       \"query\": q[\"query\"],\n       \"gold_value\": q[\"gold_value\"],\n       \"retrieved_text\": chosen[\"text\"],\n       \"predicted_answer\": answer,\n       \"answer_score\": judged,\n       \"retrieval_correct\": int(chosen[\"memory_id\"] == q[\"gold_memory_id\"]),\n   }\n\n\nsample_test_items = random.sample(test_items, min(12, len(test_items)))\nbaseline_answers = [answer_with_retriever(item, baseline_retrieve) for item in tqdm(sample_test_items, desc=\"Baseline QA\")]\nrl_answers = [answer_with_retriever(item, rl_retrieve) for item in tqdm(sample_test_items, desc=\"RL 
QA\")]\n\n\nbaseline_answer_df = pd.DataFrame(baseline_answers)\nrl_answer_df = pd.DataFrame(rl_answers)\n\n\nprint(\"Sample Downstream QA Accuracy\")\nprint(\"Baseline:\", round(float(baseline_answer_df[\"answer_score\"].mean()), 4))\nprint(\"RL      :\", round(float(rl_answer_df[\"answer_score\"].mean()), 4))\n\n\ncomparison = pd.DataFrame([\n   {\"method\": \"baseline_cosine\", \"qa_accuracy\": float(baseline_answer_df[\"answer_score\"].mean())},\n   {\"method\": \"rl_agent\", \"qa_accuracy\": float(rl_answer_df[\"answer_score\"].mean())},\n])\ndisplay(comparison)\n\n\nplt.figure(figsize=(6, 4))\nplt.bar(comparison[\"method\"], comparison[\"qa_accuracy\"])\nplt.title(\"Downstream QA Accuracy from Retrieved Memories\")\nplt.ylim(0, 1)\nplt.ylabel(\"Accuracy\")\nplt.show()\n\n\ndef inspect_examples(items: List[Dict[str, Any]], n: int = 5):\n   chosen_items = random.sample(items, min(n, len(items)))\n   rows = []\n   for item in chosen_items:\n       q = item[\"query\"]\n       base = baseline_retrieve(item)\n       rlm = rl_retrieve(item)\n       rows.append({\n           \"query\": q[\"query\"],\n           \"gold_value\": q[\"gold_value\"],\n           \"baseline_text\": base[\"text\"],\n           \"baseline_correct\": int(base[\"memory_id\"] == q[\"gold_memory_id\"]),\n           \"rl_text\": rlm[\"text\"],\n           \"rl_correct\": int(rlm[\"memory_id\"] == q[\"gold_memory_id\"]),\n       })\n   return pd.DataFrame(rows)\n\n\nexamples_df = inspect_examples(test_items, n=8)\npd.set_option(\"display.max_colwidth\", 200)\ndisplay(examples_df)<\/code><\/pre>\n<\/div>\n<\/div>\n<p>We evaluate how well the retrieved memories support downstream question answering. We generate answers using the retrieved memory context and assess the answers with an LLM-based judge to determine correctness. 
We also inspect example queries to compare, side by side, which memories the baseline retriever and the RL agent select.<\/p>\n<div class=\"dm-code-snippet dark dm-normal-version default no-background-mobile\">\n<div class=\"control-language\">\n<div class=\"dm-buttons\">\n<div class=\"dm-buttons-left\">\n<div class=\"dm-button-snippet red-button\"><\/div>\n<div class=\"dm-button-snippet orange-button\"><\/div>\n<div class=\"dm-button-snippet green-button\"><\/div>\n<\/div>\n<div class=\"dm-buttons-right\"><a><span class=\"dm-copy-text\">Copy Code<\/span><span class=\"dm-copy-confirmed\">Copied<\/span><span class=\"dm-error-message\">Use a different Browser<\/span><\/a><\/div>\n<\/div>\n<pre class=\" no-line-numbers\"><code class=\" no-wrap language-php\">def interactive_demo(question: str, top_k: int = 8):\n   qv = embed_texts([question])  # embed_texts expects a list of strings\n   sims = cosine_similarity(qv, memory_embeddings)[0]\n   top_idx = np.argsort(-sims)[:top_k]\n\n\n   candidates = []\n   for rank, midx in enumerate(top_idx):\n       mem = memory_bank[midx]\n       candidates.append({\n           \"rank\": rank,\n           \"memory_index\": int(midx),\n           \"memory_id\": int(mem.memory_id),\n           \"text\": mem.text,\n           \"sim\": float(sims[midx]),\n           \"overlap\": keyword_overlap(question, mem.text),\n           \"entity_match\": 0.0,\n           \"slot_match\": 0.0,\n           \"is_gold\": 0.0,\n       })\n\n\n   pseudo_item = {\n       \"query\": {\n           \"query_id\": -1,\n           \"query\": question,\n           \"entity\": \"unknown\",\n           \"slot\": \"unknown\",\n           \"gold_value\": \"unknown\",\n           \"gold_memory_id\": -1,\n           \"gold_text\": \"\",\n           \"topic\": \"unknown\",\n       },\n       \"candidates\": candidates,\n   }\n\n\n   obs = build_state_features(pseudo_item)\n   action, _ = model.predict(obs, deterministic=True)\n   selected = pseudo_item[\"candidates\"][int(action)]\n   answer = chat_answer(question, 
[{\"memory_id\": selected[\"memory_id\"], \"text\": selected[\"text\"]}])\n\n\n   print(\"=\" * 100)\n   print(\"QUESTION\")\n   print(question)\n   print(\"=\" * 100)\n   print(\"TOP CANDIDATES\")\n   for c in candidates:\n       print(f\"[Rank {c['rank']}] sim={c['sim']:.4f} | {c['text']}\")\n   print(\"=\" * 100)\n   print(\"RL-SELECTED MEMORY\")\n   print(selected[\"text\"])\n   print(\"=\" * 100)\n   print(\"ANSWER\")\n   print(answer)\n   print(\"=\" * 100)\n\n\ninteractive_demo(\"What is the battery of Pulse?\")\ninteractive_demo(\"Which hub does Atlas have?\")\ninteractive_demo(\"Tell me the country for Cedar.\")\n\n\nartifact_dir = \"\/content\/rl_agent_memory_retrieval_artifacts\"\nos.makedirs(artifact_dir, exist_ok=True)\n\n\nresults_df.to_csv(f\"{artifact_dir}\/retrieval_results.csv\", index=False)\nbaseline_val[\"df\"].to_csv(f\"{artifact_dir}\/baseline_val.csv\", index=False)\nrl_val[\"df\"].to_csv(f\"{artifact_dir}\/rl_val.csv\", index=False)\nbaseline_test[\"df\"].to_csv(f\"{artifact_dir}\/baseline_test.csv\", index=False)\nrl_test[\"df\"].to_csv(f\"{artifact_dir}\/rl_test.csv\", index=False)\nbaseline_answer_df.to_csv(f\"{artifact_dir}\/baseline_qa_sample.csv\", index=False)\nrl_answer_df.to_csv(f\"{artifact_dir}\/rl_qa_sample.csv\", index=False)\nexamples_df.to_csv(f\"{artifact_dir}\/example_comparisons.csv\", index=False)\n\n\nnp.save(f\"{artifact_dir}\/memory_embeddings.npy\", memory_embeddings)\nnp.save(f\"{artifact_dir}\/query_embeddings.npy\", query_embeddings)\nmodel.save(f\"{artifact_dir}\/ppo_memory_retriever\")\n\n\nwith open(f\"{artifact_dir}\/memory_bank.json\", \"w\") as f:\n   json.dump([m.__dict__ for m in memory_bank], f, indent=2)\n\n\nwith open(f\"{artifact_dir}\/queries.json\", \"w\") as f:\n   json.dump(queries, f, indent=2)\n\n\nprint(f\"Saved artifacts to: {artifact_dir}\")\nprint(\"Tutorial complete.\")<\/code><\/pre>\n<\/div>\n<\/div>\n<p>We build an interactive demonstration that lets us test the trained retrieval agent on 
new questions. We show the candidate memories, highlight the memory selected by the RL agent, and generate an answer using the selected context. Also, we save all artifacts, including embeddings, results, datasets, and the trained RL model, so that the system can be reused or further analyzed.<\/p>\n<p>In conclusion, we demonstrated how reinforcement learning can enhance memory retrieval in agentic AI systems. We trained an RL agent to select relevant memories from a set of candidates using signals such as semantic similarity, keyword overlap, and entity matching. We then evaluated the retriever and observed how the learned policy compares with traditional embedding-based retrieval methods. By integrating the retriever with an LLM, we also showed how better memory selection improves downstream question-answering performance. Through experiments, visualizations, and interactive demos, we explored how RL can optimize long-term memory access in intelligent agents.<\/p>\n<hr class=\"wp-block-separator aligncenter has-alpha-channel-opacity is-style-wide\" \/>\n<p>Check out\u00a0the\u00a0<strong><a href=\"https:\/\/github.com\/Marktechpost\/AI-Agents-Projects-Tutorials\/blob\/main\/Agentic%20AI%20Memory\/rl_agent_memory_retrieval_marktechpost.py\" target=\"_blank\" rel=\"noreferrer noopener\">FULL CODES here<\/a><\/strong>.<\/p>\n<p>The post <a href=\"https:\/\/www.marktechpost.com\/2026\/04\/27\/build-a-reinforcement-learning-powered-agent-that-learns-to-retrieve-relevant-long-term-memories\/\">Build a Reinforcement Learning Powered Agent that Learns to Retrieve Relevant Long-Term Memories for Accurate LLM Question Answering<\/a> appeared first on <a href=\"https:\/\/www.marktechpost.com\/\">MarkTechPost<\/a>.<\/p>","protected":false},"excerpt":{"rendered":"<p>In this tutorial, we build a R&hellip;<\/p>\n","protected":false},"author":1,"featured_media":29,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[1],"tags":[],"class_list":["post-806","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-uncategorized"],"_links":{"self":[{"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=\/wp\/v2\/posts\/806","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=806"}],"version-history":[{"count":0,"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=\/wp\/v2\/posts\/806\/revisions"}],"wp:featuredm
edia":[{"embeddable":true,"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=\/wp\/v2\/media\/29"}],"wp:attachment":[{"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=806"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=806"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/connectword.dpdns.org\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=806"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}